diff --git a/src/testing.py b/src/testing.py index e21ecef..001afd7 100644 --- a/src/testing.py +++ b/src/testing.py @@ -18,6 +18,20 @@ from .bot import MILBot +@dataclass +class TestingMember: + member: discord.Member + swimming: bool + + def embed_str(self) -> str: + return f"* {self.member.mention}{' (swimming)' if self.swimming else ''}" + + def __str__(self) -> str: + return f"TestingMember(member={self.member}, swimming={self.swimming})" + + __repr__ = __str__ + + class MemberTestingAttendance(Enum): CANNOT = "cannot" CANNOTDRIVE = "cannotdrive" @@ -45,25 +59,57 @@ def emoji(self) -> str: @dataclass class TestingAttendance: - cannot: list[discord.Member] - cannotdrive: list[discord.Member] - candrive: list[discord.Member] - candrivesub: list[discord.Member] + cannot: list[TestingMember] + cannotdrive: list[TestingMember] + candrive: list[TestingMember] + candrivesub: list[TestingMember] + + def __post_init__(self): + self.associations: dict[MemberTestingAttendance, list[TestingMember]] = { + MemberTestingAttendance.CANNOT: self.cannot, + MemberTestingAttendance.CANNOTDRIVE: self.cannotdrive, + MemberTestingAttendance.CANDRIVE: self.candrive, + MemberTestingAttendance.CANDRIVESUB: self.candrivesub, + } @property - def attending(self) -> list[discord.Member]: + def attending(self) -> list[TestingMember]: return self.candrive + self.candrivesub + self.cannotdrive def members_with_state( self, state: MemberTestingAttendance, - ) -> list[discord.Member]: - return { - MemberTestingAttendance.CANNOT: self.cannot, - MemberTestingAttendance.CANNOTDRIVE: self.cannotdrive, - MemberTestingAttendance.CANDRIVE: self.candrive, - MemberTestingAttendance.CANDRIVESUB: self.candrivesub, - }[state] + ) -> list[TestingMember]: + return self.associations[state] + + def update_members( + self, + state: MemberTestingAttendance, + members: list[TestingMember], + ): + self.associations[state] = members + + +class SwimmingView(MILBotView): + + swimming: bool | None + + def __init__(self, bot: MILBot): + self.bot = bot + self.swimming = None + super().__init__() + + @discord.ui.button(label="Yes", style=discord.ButtonStyle.green) + async def yes(self, interaction: discord.Interaction, _): + self.swimming = True + await interaction.response.defer() + self.stop() + + @discord.ui.button(label="No", style=discord.ButtonStyle.red) + async def no(self, interaction: discord.Interaction, _): + self.swimming = False + await interaction.response.defer() + self.stop() class TestingSignUpSelect(discord.ui.Select): @@ -92,7 +138,7 @@ async def parse_embed_field( self, embed: discord.Embed, field_name: str, - ) -> list[discord.Member]: + ) -> list[TestingMember]: """ Parses the value of the embed, assuming that it is a new-line Markdown bulleted list of member mentions. @@ -101,12 +147,13 @@ async def parse_embed_field( if value == "": return [] raw_mentions = value.split("\n") - ids = [] + members: list[TestingMember] = [] for mention_str in raw_mentions: id = re.findall(r"<@!?(\d+)>", mention_str) + swimming = "(swimming)" in mention_str if id: - ids.append(int(id[0])) - members = [await self.bot.get_member(id) for id in ids] + member = await self.bot.get_member(int(id[0])) + members.append(TestingMember(member, swimming)) return members def get_field_named(self, embed: discord.Embed, field_name: str) -> str: @@ -143,10 +190,10 @@ def replace_embed_value( embed.set_field_at(i, name=field.name, value=new_value, inline=False) break - def format_members(self, members: list[discord.Member]) -> str: + def format_members(self, members: list[TestingMember]) -> str: if not members: return "_No members yet._" - return "\n".join([f"* {member.mention}" for member in members]) + return "\n".join([member.embed_str() for member in members]) async def callback(self, interaction: discord.Interaction): message = interaction.message @@ -171,38 +218,55 @@ async def callback(self, interaction: discord.Interaction): embed_field_name = state.english_title members_with_state = attendance.members_with_state(state) - # If member already in list, remove them. Otherwise, add them. - if interaction.user in members_with_state: - members_with_state.remove(interaction.user) - self.replace_embed_value( - embed, - embed_field_name, - self.format_members(members_with_state), - ) + swimming_view = SwimmingView(self.bot) + if state is not MemberTestingAttendance.CANNOT: await interaction.response.send_message( - "You have been removed from the list.", + "Are you planning on swimming?", + view=swimming_view, ephemeral=True, ) + await swimming_view.wait() + + # Remove the member from every list + members_previous_state = None + was_swimming = None + for available_state in MemberTestingAttendance: + cur_state = attendance.members_with_state(available_state) + if interaction.user in [m.member for m in cur_state]: + was_swimming = next( + m.swimming for m in cur_state if m.member == interaction.user + ) + cur_state = [m for m in cur_state if m.member != interaction.user] + members_previous_state = available_state + attendance.update_members(available_state, cur_state) + self.replace_embed_value( + embed, + available_state.english_title, + self.format_members(cur_state), + ) + + response = None + if state == members_previous_state and was_swimming == ( + swimming_view.swimming or False + ): + response = "You have been removed from the list." else: - # If the user is already in another list, remove them from that list. - for state in MemberTestingAttendance: - cur_state = attendance.members_with_state(state) - if interaction.user in cur_state: - cur_state.remove(interaction.user) - self.replace_embed_value( - embed, - state.english_title, - self.format_members(cur_state), - ) - - members_with_state.append(interaction.user) + members_with_state = attendance.members_with_state(state) + members_with_state.append( + TestingMember(interaction.user, swimming_view.swimming or False), + ) self.replace_embed_value( embed, embed_field_name, self.format_members(members_with_state), ) + response = "Your response was recorded!" + + if state is not MemberTestingAttendance.CANNOT: + await interaction.edit_original_response(content=response, view=None) + else: await interaction.response.send_message( - "Your response was recorded!", + response, ephemeral=True, )