Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of unexpected DMs #282

Merged
merged 6 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions botto/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def parse(config):
"support": {
"channel_id": "890978723451523086",
"user_ids": ["328674204780068864"],
"dm_log_channel": "935122779643191347",
},
"watching_statūs": ["for food", "for snails", "for apologies", "for love"],
"disabled_features": {},
Expand Down Expand Up @@ -362,6 +363,9 @@ def parse(config):
if user_ids := decode_base64_env("TLDBOTTO_SUPPORT_USER_IDS"):
defaults["support"]["user_ids"] = user_ids

if channel_id := os.getenv("TLDBOTTO_SUPPORT_DM_LOG_CHANNEL_ID"):
defaults["support"]["dm_log_channel"] = channel_id

defaults["clickup_enabled_guilds"] = set(defaults["clickup_enabled_guilds"])

if id := os.getenv("TLDBOTTO_ID"):
Expand Down
8 changes: 5 additions & 3 deletions botto/mixins/clickup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ async def clickup_task(self, message: Message, **kwargs):
task_embed = (
discord.Embed(
title=truncate_string(task.name, 256),
description=truncate_string(task.description, 4096)
if task.description
else discord.Embed.Empty,
description=(
truncate_string(task.description, 4096)
if task.description
else None
),
colour=discord.Colour.from_rgb(*hex_to_rgb(task.status.colour)),
timestamp=task.date_created,
url=task.url,
Expand Down
17 changes: 11 additions & 6 deletions botto/mixins/reaction_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def get_or_fetch_message(

processing_emoji = "⏳"


class ReactionRoles(ExtendedClient):
def __init__(
self,
Expand All @@ -79,7 +80,7 @@ def __init__(
trigger="cron",
minute="*/30",
coalesce=True,
next_run_time=datetime.now() + timedelta(seconds=5),
next_run_time=datetime.utcnow() + timedelta(seconds=5),
misfire_grace_time=10,
)
self.tester_locks = WeakValueDictionary()
Expand Down Expand Up @@ -509,11 +510,15 @@ async def handle_role_approval(self, payload: discord.RawReactionActionEvent):
raise

try:
other_messages = [
channel.get_partial_message(message_id)
for message_id in testing_request.further_notification_message_ids
if int(message_id) != payload.message_id
]
other_messages = (
[
channel.get_partial_message(message_id)
for message_id in testing_request.further_notification_message_ids
if int(message_id) != payload.message_id
]
if testing_request.further_notification_message_ids
else []
)
if int(testing_request.notification_message_id) != payload.message_id:
other_messages.append(
channel.get_partial_message(
Expand Down
15 changes: 14 additions & 1 deletion botto/mixins/remote_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import json
from datetime import datetime, timedelta
from typing import Literal, Optional, Union

import discord
from apscheduler.schedulers.asyncio import AsyncIOScheduler

from botto.storage import ConfigStorage
Expand All @@ -23,7 +25,7 @@ def __init__(
trigger="cron",
minute="*/40",
coalesce=True,
next_run_time=datetime.now() + timedelta(seconds=5),
next_run_time=datetime.utcnow() + timedelta(seconds=5),
)
super().__init__(scheduler=scheduler, **kwargs)

Expand All @@ -41,3 +43,14 @@ async def is_feature_disabled(
str(server_id), "disabled_features"
)
return feature_name in disabled_features_for_server.parsed_value

async def should_respond_dms(self, member: discord.User) -> bool:
config_entries = await asyncio.gather(
*[
self.config_storage.get_config(guild.id, "respond_member_dms")
for guild in member.mutual_guilds
]
)
return any(
guild_config.parsed_value for guild_config in config_entries if guild_config
)
8 changes: 4 additions & 4 deletions botto/reminder_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
self.missed_job_ids = []
self.get_channel_func = None

initial_refresh_run = datetime.now() + timedelta(seconds=5)
initial_refresh_run = datetime.utcnow() + timedelta(seconds=5)
scheduler.add_job(
self.refresh_reminders,
name="Refresh reminders",
Expand Down Expand Up @@ -264,9 +264,9 @@ async def list_reminders(
) -> list[Reminder]:
reminders_for_guild: list[Reminder] = []
async for reminder in self.storage.retrieve_reminders():
reminder_channel: Optional[
discord.TextChannel
] = await self.get_channel_func(reminder.channel_id)
reminder_channel: Optional[discord.TextChannel] = (
await self.get_channel_func(reminder.channel_id)
)
if not reminder_channel or reminder_channel.guild.id != guild.id:
continue
if channel is not None and reminder_channel.id != channel.id:
Expand Down
47 changes: 37 additions & 10 deletions botto/storage/config_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
log = logging.getLogger(__name__)

ConfigCache = dict[str, dict[str, ConfigEntry]]
NegativeConfigKeyCache = dict[str, set[str]]


class ConfigStorage(Storage):
Expand All @@ -19,6 +20,8 @@ def __init__(self, airtable_base: str, airtable_key: str):
)
self.config_lock = asyncio.Lock()
self.config_cache: ConfigCache = {}
self.config_key_negative_cache_lock = asyncio.Lock()
self.config_key_negative_cache: NegativeConfigKeyCache = {}
self.auth_header = {"Authorization": f"Bearer {self.airtable_key}"}

async def clear_server_cache(self, server_id: str):
Expand All @@ -42,8 +45,8 @@ async def list_config_by_server(self) -> ConfigCache:
return self.config_cache

async def retrieve_config(
self, server_id: str, key: Optional[Union[str, int]]
) -> Optional[ConfigEntry]:
self, server_id: str | int, key: Optional[Union[str, int]]
) -> Optional[ConfigEntry | dict[str, ConfigEntry]]:
log.debug(f"Fetching {key or 'config'} for {server_id}")
filter_by_formula = f"AND({{Server ID}}='{server_id}'"
if key := key:
Expand All @@ -56,17 +59,26 @@ async def retrieve_config(
)
config_iterator = (ConfigEntry.from_airtable(x) async for x in result_iterator)
try:
config = await config_iterator.__anext__()
async with self.config_lock:
server_config = self.config_cache.get(str(server_id), {})
server_config[config.config_key] = config
self.config_cache[config.server_id] = server_config
return config
except (StopIteration, StopAsyncIteration):
async for config in config_iterator:
server_config[config.config_key] = config
self.config_cache[str(server_id)] = server_config
return server_config if not key else server_config[key]
except (StopIteration, StopAsyncIteration, KeyError):
log.info(f"No config found for Key {key} with Server ID {server_id}")
async with self.config_key_negative_cache_lock:
self.config_key_negative_cache.setdefault(str(server_id), set()).add(
key
)
return None

async def get_config(self, server_id: str, key: str) -> Optional[ConfigEntry]:
async def get_config(self, server_id: str | int, key: str) -> Optional[ConfigEntry]:
if (
not self.config_key_negative_cache_lock.locked()
and self.config_key_negative_cache.get(str(server_id), {}).get(key)
):
return None
await self.config_lock.acquire()
if (server_config := self.config_cache.get(str(server_id))) and (
config := server_config.get(key)
Expand All @@ -82,6 +94,21 @@ async def refresh_cache(self):
await self.config_lock.acquire()
current_cache = self.config_cache
self.config_lock.release()
for key in current_cache:
for key, entries in current_cache.items():
log.debug(f"Refreshing config for server {key}")
await self.retrieve_config(server_id=key, key=None)
for entry in entries.values():
await self.retrieve_config(server_id=key, key=entry.config_key)

log.info("Refreshing negative config key cache")
for server_id, keys in self.config_key_negative_cache.items():
keys_to_remove = []
for key in keys:
async with self.config_key_negative_cache_lock:
if await self.retrieve_config(server_id, key):
keys_to_remove.append(key)
log.debug(
f"Previously non-existent key {key} now exists for {server_id}"
)
for key in keys_to_remove:
self.config_key_negative_cache[server_id].remove(key)
log.info("Config cache refreshed")
98 changes: 65 additions & 33 deletions botto/tld_botto.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
hour="*/12",
coalesce=True,
)
initial_refresh_run = datetime.now() + timedelta(seconds=5)
initial_refresh_run = datetime.utcnow() + timedelta(seconds=5)
scheduler.add_job(
self.storage.update_meals_cache,
name="Refresh meals cache",
Expand Down Expand Up @@ -564,7 +564,8 @@ async def react(self, message: Message):
await reaction[1](message)
has_matched = True
if (
not await self.is_feature_disabled(
message.guild
and not await self.is_feature_disabled(
"apology_reaction", str(message.guild.id)
)
and not self.regexes.sorry.search(message.content)
Expand Down Expand Up @@ -718,36 +719,10 @@ async def process_dm(self, message: Message):
message_content = message.content.lower().strip()
dm_channel = await get_dm_channel(message.author)
if message_content in ("!help", "help", "help!", "halp", "halp!", "!halp"):
help_message = f"""
I am a multi-function bot providing assistance and jokes.
""".strip()

support_config = self.config["support"]
support_channel_id = support_config.get("channel_id")
support_user_ids = support_config.get("user_ids")

if support_user_ids or support_channel_id:
message_add = "\nIf you need assistance with my operation"
if support_channel := self.get_channel(int(support_channel_id)):
support_guild = self.get_guild(support_channel.guild.id)
message_add = (
f"{message_add} and are a member of `{support_guild.name}`, "
f"please ask for help in {support_channel.mention}"
)
if support_user_ids:
message_add = f"{message_add}. Otherwise"
if support_user_ids:
users = [
self.get_user(int(user_id)).mention
for user_id in support_user_ids
]
if len(support_user_ids) > 1:
message_add = f"{message_add}, please DM one of the following users: {', '.join(users)}"
else:
message_add = f"{message_add}, please DM {', '.join(users)}"
help_message = f"{help_message}\n{message_add}."

await dm_channel.send(help_message)
async with dm_channel.typing():
help_message = await self.make_help_message(message)
logging.info(f"Sending help message in response to {message}")
await dm_channel.send(help_message)
return

if message_content == "!version":
Expand All @@ -774,7 +749,64 @@ async def process_dm(self, message: Message):
return

if not await self.react(message):
await self.reactions.unknown_dm(message)
react_task = asyncio.create_task(self.reactions.unknown_dm(message))
log_task = asyncio.create_task(self.log_dm(message))
if await self.should_respond_dms(message.author):
async with dm_channel.typing():
help_message = await self.make_help_message(message)
logging.info(f"Sending help message in response to {message}")
await dm_channel.send(
"Sorry, I am not currently capable of extended conversation, but I have "
"forwarded your message to my operators. " + help_message
)
await asyncio.gather(react_task, log_task)

async def log_dm(self, message: Message):
support_config = self.config["support"]
dm_log_channel_id = support_config.get("dm_log_channel")
if not dm_log_channel_id:
log.warning("No DM log channel configured")
return
dm_log_channel = self.get_channel(int(dm_log_channel_id))
embed = discord.Embed(
title=truncate_string(f"New DM", 256),
description=(
truncate_string(message.content, 4096) if message.content else None
),
timestamp=message.created_at,
).set_author(name=f"{message.author.name} ({message.author.mention})")
await dm_log_channel.send(embed=embed)

async def make_help_message(self, responding_to: Message):
help_message = f"""
I am a multi-function bot providing assistance and jokes.
""".strip()
support_config = self.config["support"]
support_channel_id = support_config.get("channel_id")
support_user_ids = support_config.get("user_ids")
if support_user_ids or support_channel_id:
message_add = "\nIf you need assistance with my operation"
if (
(support_channel := self.get_channel(int(support_channel_id)))
and support_channel.guild
and (support_channel.guild in responding_to.author.mutual_guilds)
):
message_add = (
f"{message_add} and are a member of `{support_channel.guild.name}`, "
f"please ask for help in {support_channel.mention}"
)
if support_user_ids:
message_add = f"{message_add}. Otherwise"
if support_user_ids:
users = [
self.get_user(int(user_id)).mention for user_id in support_user_ids
]
if len(support_user_ids) > 1:
message_add = f"{message_add}, please DM one of the following users: {', '.join(users)}"
else:
message_add = f"{message_add}, please DM {', '.join(users)}"
help_message = f"{help_message}\n{message_add}."
return help_message

@property
def local_times(self) -> list[datetime]:
Expand Down
Loading