Initial commit

This commit is contained in:
bain 2022-08-21 23:12:25 +02:00
commit 0354f21ea8
No known key found for this signature in database
GPG key ID: A708F07AF3D92C02
12 changed files with 637 additions and 0 deletions

6
.gitignore vendored Normal file
View file

@ -0,0 +1,6 @@
/venv/
/pyrightconfig.json
/.env
/.vimspector.json
*__pycache__*
/cache/

4
ari/__main__.py Normal file
View file

@ -0,0 +1,4 @@
from .bot import bot
from .constants import BOT_TOKEN
bot.run(BOT_TOKEN, log_handler=None)

79
ari/bot.py Normal file
View file

@ -0,0 +1,79 @@
import logging
from typing import Dict
import discord
from discord import app_commands
from .player import Player
from .constants import BAINS_POSIT_NOTE, MUSIC_CACHE, Emoji
from .messages import MessageHandler
from .cache import Cache
discord.utils.setup_logging()
logging.getLogger("ari").setLevel(logging.DEBUG)
logger = logging.getLogger(__name__)
intents = discord.Intents.default()
intents.message_content = True
class Bot(discord.Client):
def __init__(self, intents: discord.Intents) -> None:
super().__init__(intents=intents)
self.tree = app_commands.CommandTree(self)
self.message_handler = MessageHandler(self)
self.music_cache = Cache(self.http, MUSIC_CACHE, max_size=30000000)
self.players: Dict[int, Player] = {}
async def setup_hook(self) -> None:
return
await self.tree.sync()
async def on_message(self, message: discord.Message):
await self.message_handler.handle_message(message)
async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent):
await self.message_handler.handle_reaction_add(payload)
bot = Bot(intents)
@bot.tree.command(name="skip", description="Skip specified number of songs")
@app_commands.describe(number="Number of songs to skip")
@app_commands.guild_only()
async def skip(interaction: discord.Interaction, number: int = 1):
assert interaction.guild is not None
if number < 1:
await interaction.response.send_message(
f"{Emoji.error} The number of songs to skip must be bigger or equal to 1",
ephemeral=True,
)
return
player = bot.players.get(interaction.guild.id)
if player is None or not player.is_running():
await interaction.response.send_message(
f"{Emoji.error} I am not currently playing anything", ephemeral=True
)
else:
player.skip(number)
await interaction.response.send_message(
f"Skipping {number} song{'s' if number > 1 else ''}..."
)
@bot.tree.command(name="stop", description="Stop playing and disconnect")
@app_commands.guild_only()
async def stop(interaction: discord.Interaction):
assert interaction.guild is not None
player = bot.players.get(interaction.guild.id)
if player is None or not player.is_running():
await interaction.response.send_message(
f"{Emoji.error} I am not currently playing anything", ephemeral=True
)
else:
player.stop()
await interaction.response.send_message(f"Disconnecting...")

158
ari/cache.py Normal file
View file

@ -0,0 +1,158 @@
import asyncio
import hashlib
import time
from typing import Dict, Tuple, List
import aiohttp
import yt_dlp as youtube_dl
import os
import logging
logger = logging.getLogger(__name__)
class CacheError(Exception):
"""Error while making sure a file exists"""
pass
def link_hash(link: str) -> str:
return hashlib.md5(link.encode()).hexdigest()
def extract_info_yt(link):
with youtube_dl.YoutubeDL({"format": "bestaudio"}) as ydl:
return ydl.extract_info(link, download=False)
async def get_size(link: str):
try:
if link.startswith("https://youtu.be/"):
return (await asyncio.to_thread(extract_info_yt, link)).get( # type: ignore
"filesize", -1
)
else:
async with aiohttp.ClientSession() as session:
async with session.head(link, allow_redirects=True) as response:
return response.content_length or -1
except youtube_dl.DownloadError:
logger.info("failed to fetch video size")
return -1
class Cache:
def __init__(self, http, directory: str, max_size: int = 500 * 1000**2):
self._http = http
self._directory: str = directory
self._max_size: int = max_size
self._current_size: int = 0
self._links: Dict[str, Tuple[int, int, str]] = {}
self.locked_files: List[str] = []
self.cache_lock = asyncio.Lock()
os.makedirs(self._directory, exist_ok=True)
async def _try_free_space(self, size: int) -> bool:
"""NEEDS TO BE GUARDED WITH CACHE LOCK"""
logger.debug(f"freeing {size} bytes of space")
if size > self._max_size:
return False
to_remove = []
successful = False
for link in sorted(self._links, key=lambda x: x[0]):
if self._links[link][2] not in self.locked_files:
os.remove(f"{self._directory}/{self._links[link][2]}")
self._current_size -= self._links[link][1]
to_remove.append(link)
if size + self._current_size <= self._max_size:
successful = True
break
if to_remove:
logger.info(f"removed {len(to_remove)} files")
for link in to_remove:
del self._links[link]
return successful
async def _download_youtube_file(self, link: str):
success = False
with youtube_dl.YoutubeDL(
{
"outtmpl": f"{self._directory}/{link_hash(link)}",
"format": "bestaudio",
"updatetime": False,
"ratelimit": 5000000,
}
) as ydl:
for _ in range(3):
ex = asyncio.to_thread(ydl.download, (link,))
try:
await asyncio.wait_for(ex, 10)
except (
asyncio.TimeoutError,
youtube_dl.utils.DownloadError,
):
pass
else:
success = True
break
finally:
# clean up potential leftovers from ytdl
if os.path.exists(f"{self._directory}/{link_hash(link)}.part"):
os.remove(f"{self._directory}/{link_hash(link)}.part")
if not success:
raise CacheError("Youtube download failed")
async def _download_file(self, link: str):
"""NEEDS TO BE GUARDED WITH CACHE LOCK"""
size = await get_size(link)
if size < 0:
raise CacheError("Music file size unknown")
logger.info(f"Downloading video of size {size / 1000:.2f}kb")
if (size < self._max_size - self._current_size) or (
await self._try_free_space(size)
):
logger.debug(f"size: {self._max_size - self._current_size}")
if link.startswith("https://youtu.be/"):
await self._download_youtube_file(link)
else:
try:
data = await asyncio.wait_for(self._http.get_from_cdn(link), 10)
except asyncio.TimeoutError:
raise CacheError("Discord download is taking too long")
with open(f"{self._directory}/{link_hash(link)}", "wb+") as f:
f.write(data)
self._current_size += size # reserve space
self._links[link] = (time.time_ns(), size, link_hash(link))
return True
raise CacheError("Music file size too large or unknown")
async def ensure_existence(self, link: str):
async with self.cache_lock:
fp = f"{self._directory}/{link_hash(link)}"
if self._links.get(link):
return CacheContextManager(self, fp)
else:
await self._download_file(link)
return CacheContextManager(self, fp)
class CacheContextManager:
def __init__(self, cache: Cache, file: str):
self._cache: Cache = cache
self._file: str = file
def __enter__(self) -> str:
self._cache.locked_files.append(self._file)
if not os.path.exists(self._file):
self._cache.locked_files.remove(self._file)
raise CacheError("cannot lock file; it no longer exists")
return self._file
def __exit__(self, exc_type, exc_val, exc_tb):
self._cache.locked_files.remove(self._file)

17
ari/constants.py Normal file
View file

@ -0,0 +1,17 @@
import os
from typing import NamedTuple
import discord
BAINS_POSIT_NOTE = discord.Object(id=630144683359862814)
class Emoji(NamedTuple):
play = "<:play:1010298926550749247>"
skip_to = "<:push_to_front:1010299358174007396>"
download_error = "<:download_error:1010299866246807662>"
error = "<:error:755487487807324230>"
PRELOAD = bool(int(os.getenv("PRELOAD", "1")))
MUSIC_CACHE = os.getenv("MUSIC_CACHE", "cache")
BOT_TOKEN = os.getenv("BOT_TOKEN", "invalid")

167
ari/messages.py Normal file
View file

@ -0,0 +1,167 @@
import time
import asyncio
import logging
import re
from typing import TYPE_CHECKING, Dict, Set
from bidict import bidict, MutableBidirectionalMapping
from .constants import Emoji
from .queue import Content, QueueItem
from .player import Player
if TYPE_CHECKING:
from typing import Dict, Set
from .bot import Bot
_youtube_regex = re.compile(
r"http(?:s?):\/\/(?:www\.)?youtu(?:be\.com\/watch\?v=|\.be\/)([\w\-\_]*)(&(amp;)?[\w\?=]*)?"
)
import discord
logger = logging.getLogger(__name__)
class MessageHandler:
def __init__(self, client) -> None:
self.client: "Bot" = client
self.skippables: MutableBidirectionalMapping[int, QueueItem] = bidict()
self.requests = {
Emoji.play: self.handle_play_request,
Emoji.download_error: self.show_errors,
}
async def handle_message(self, message: discord.Message) -> None:
if message.guild is None:
return
if _youtube_regex.search(message.content) is not None or any(
[
"audio" in a.content_type
for a in message.attachments
if a.content_type is not None
]
):
logger.debug("message %s has a youtuble link", message.id)
await message.add_reaction(Emoji.play)
async def handle_reaction_add(
self, payload: discord.RawReactionActionEvent
) -> None:
assert self.client.user is not None
if payload.guild_id is None:
return # the message must be in a guild
if payload.user_id == self.client.user.id:
return # ignore self
req = self.requests.get(str(payload.emoji))
if req is not None:
await req(payload)
async def get_or_fetch_message(
self, message_id: int, channel_id: int
) -> discord.Message:
message = next(
filter(lambda m: m.id == message_id, self.client.cached_messages),
None,
)
if message is None:
logger.debug("message %s was not cached, fetching...", message_id)
channel = self.client.get_channel(
channel_id
) or await self.client.fetch_channel(channel_id)
assert isinstance(channel, discord.TextChannel)
message = await channel.fetch_message(message_id)
# manually add message to the client's message cache.
# this way the client can refresh the message when it is edited
cache = self.client._connection._messages
if cache is not None:
cache.append(message)
return message
async def handle_play_request(self, payload: discord.RawReactionActionEvent):
logger.info("play request on message %s", payload.message_id)
# check cache
message = await self.get_or_fetch_message(
payload.message_id, payload.channel_id
)
assert message.guild is not None
user = message.guild.get_member(payload.user_id)
if user is None:
user = await message.guild.fetch_member(payload.user_id)
# get all videos and attachments from the message
playable = []
for attachment in message.attachments:
if attachment.content_type in ("audio/mpeg", "audio/ogg", "audio/wave"):
playable.append(attachment.url)
pos = 0
while pos < len(message.content):
match = _youtube_regex.search(message.content, pos)
if match is None:
break
pos = match.span()[1]
playable.append("https://youtu.be/" + match.groups()[0])
logger.debug("adding %s songs to queue", len(playable))
if not playable:
await user.send(
f"{Emoji.error} There are no playable videos/music in the message"
)
await message.clear_reaction(Emoji.play)
return
if (
next(
filter(
lambda x: str(x.emoji) == Emoji.download_error, message.reactions
),
None,
)
is not None
):
# clear any hanging errors on the message
# caused only when the message is played again before
# the reaction is automatically removed after a timeout
logger.debug("clearing download error reaction")
await message.clear_reaction(Emoji.download_error)
await message.remove_reaction(Emoji.play, discord.Object(id=payload.user_id))
# push video ids to the player queue
player = self.client.players.get(user.guild.id)
if player is None or not player.is_running():
player = await Player.create(self.client, user)
if player is None:
await user.send(
f"{Emoji.error} Failed to connect to voice. Are you in a voice channel?",
)
return
self.client.players[user.guild.id] = player
for id in playable:
player.queue.push(Content(message, id))
if not player.is_running():
asyncio.create_task(player.run())
async def show_errors(self, payload: discord.RawReactionActionEvent):
assert payload.guild_id is not None
dm = await self.client.create_dm(discord.Object(id=payload.user_id))
player = self.client.players.get(payload.guild_id)
if player is None or not player.is_running():
# error emojis should be only visible when a player is running
# this was probably fired in the split second when the player was
# turning off
return
message = f"{Emoji.download_error} ***Sorry, I was not able to play the following songs:***\n```\n"
for error in player.errored_songs:
message += f" - {error}\n"
message += "```"
await dm.send(message)
message = await self.get_or_fetch_message(
payload.message_id, payload.channel_id
)
await message.remove_reaction(
Emoji.download_error, discord.Object(id=payload.user_id)
)

131
ari/player.py Normal file
View file

@ -0,0 +1,131 @@
from yt_dlp import os
from .cache import CacheError
from .constants import Emoji, PRELOAD
from .queue import Queue, QueueItem
import discord
import logging
import asyncio
from typing import TYPE_CHECKING, Set
if TYPE_CHECKING:
from .bot import Bot
logger = logging.getLogger(__name__)
class Player:
def __init__(self, client: "Bot", voice_client: discord.VoiceClient) -> None:
self.queue = Queue()
self.voice = voice_client
self.client = client
self.errored: Set[discord.Message] = set()
self.errored_songs: Set[str] = set()
self._running = False
self._skip = 0
@classmethod
async def create(cls, client: "Bot", user: discord.Member):
if not user.voice or not user.voice.channel:
return None
logger.debug("creating player")
voice_client = user.guild.voice_client
assert voice_client is None or isinstance(voice_client, discord.VoiceClient)
if not voice_client:
voice_client = await user.voice.channel.connect() # type: ignore
else:
if not voice_client.is_connected():
try:
voice_client.stop()
logger.debug("reusing existing voice client")
await voice_client.connect(reconnect=True, timeout=10)
except asyncio.TimeoutError or discord.ConnectionClosed:
logger.warning(f"failed to connect")
return None
return cls(client, voice_client)
async def run(self):
logger.info(f"{hash(self)} running player")
self._running = True
while not self.queue.empty() and self.voice.is_connected() and self._running:
music = self.queue.pop()
if self._skip > 0:
logger.debug(f"skips remaining {self._skip-1}")
self._skip -= 1
continue
assert music is not None
logger.debug("playing %s", music)
try:
with await self.client.music_cache.ensure_existence(
music.content.video_id
) as file:
# ensuring existence can take a long time
if self.voice.is_connected():
self.voice.play(discord.FFmpegPCMAudio(file))
tried_preload = False
while (
self.voice.is_connected()
and self.voice.is_playing()
and self._running
):
if PRELOAD and not tried_preload:
tried_preload = await self.preload()
if self._skip > 0:
logger.debug("skipping currently playing")
self.voice.stop()
self._skip -= 1
break
await asyncio.sleep(1)
except CacheError:
await self.add_error(music)
self.errored.add(music.content.message)
self.errored_songs.add(music.content.video_id)
except Exception:
logger.exception(f"{hash(self)}: exception while playing")
break
self._running = False
await self.voice.disconnect()
await self.cleanup_errors()
logger.info(f"{hash(self)} player shutdown")
async def cleanup_errors(self):
for message in self.errored:
try:
await message.clear_reaction(Emoji.download_error)
except discord.HTTPException:
# the message could already be gone
pass
async def add_error(self, item: QueueItem):
try:
await item.content.message.add_reaction(Emoji.download_error)
except discord.HTTPException as e:
# the message could be already gone
logger.warning(
"could not add error to message %s: %s", item.content.message.id, e
)
def is_running(self):
return self._running
def skip(self, num: int):
self._skip += num
def stop(self):
self._running = False
async def preload(self):
preload = self.queue.peek()
if preload is not None:
# we ignore the context manager, thus we're not actually
# locking the file, just preloading it
try:
await self.client.music_cache.ensure_existence(preload.content.video_id)
return True
except CacheError:
# error silently, maybe we can free up space by letting
# go of the currently playing song
pass
return False

71
ari/queue.py Normal file
View file

@ -0,0 +1,71 @@
from collections import deque
from dataclasses import dataclass
from threading import Lock
from typing import Deque, NamedTuple, Optional
import logging
import discord
from .constants import Emoji
logger = logging.getLogger(__name__)
class LimitReached(Exception):
pass
class Content(NamedTuple):
message: discord.Message
video_id: str
@dataclass(eq=False)
class QueueItem:
content: Content
next: Optional["QueueItem"]
queue: Optional[int]
class Queue:
def __init__(self, max_length: int = 50) -> None:
self._lock = Lock()
self._front = None
self._back = None
self.length = 0
def push(self, item: Content) -> QueueItem:
with self._lock:
if self._front is None:
self._front = QueueItem(content=item, queue=id(self), next=None)
self._back = self._front
else:
assert self._back is not None
self._back.next = QueueItem(content=item, queue=id(self), next=None)
self._back = self._back.next
self.length += 1
return self._back
def pop(self) -> Optional[QueueItem]:
with self._lock:
item = self._front
if item is not None:
self._front = item.next
item.queue = None # item is no longer a part of the queue
self.length -= 1
return item
def skip_to(self, item: QueueItem):
with self._lock:
if item.queue != id(self):
raise ValueError("item must be from this queue")
self._front = item
def peek(self) -> Optional[QueueItem]:
return self._front
def __contains__(self, item: QueueItem) -> bool:
return item.queue == id(self)
def empty(self) -> bool:
return self.length == 0

BIN
images/emoji_play.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 979 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

BIN
images/emoji_warning.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 KiB

4
requirements.txt Normal file
View file

@ -0,0 +1,4 @@
discord.py[voice]
bidict
yt_dlp