diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py index 9cfd75df8b..fe53a5c895 100644 --- a/tests/bot/exts/utils/snekbox/test_snekbox.py +++ b/tests/bot/exts/utils/snekbox/test_snekbox.py @@ -3,7 +3,7 @@ from base64 import b64encode from unittest.mock import AsyncMock, MagicMock, Mock, call, create_autospec, patch -from discord import AllowedMentions +from discord import AllowedMentions, ui from discord.ext import commands from pydis_core.utils.paste_service import MAX_PASTE_SIZE @@ -12,7 +12,14 @@ from bot.exts.utils import snekbox from bot.exts.utils.snekbox import EvalJob, EvalResult, Snekbox from bot.exts.utils.snekbox._io import FileAttachment -from tests.helpers import MockBot, MockContext, MockMember, MockMessage, MockReaction, MockUser +from tests.helpers import ( + MockBot, + MockContext, + MockMember, + MockMessage, + MockReaction, + MockUser, +) class SnekboxTests(unittest.IsolatedAsyncioTestCase): @@ -25,12 +32,14 @@ def setUp(self): @staticmethod def code_args(code: str) -> tuple[EvalJob]: """Converts code to a tuple of arguments expected.""" - return EvalJob.from_code(code), + return (EvalJob.from_code(code),) async def test_post_job(self): """Post the eval code to the URLs.snekbox_eval_api endpoint.""" resp = MagicMock() - resp.json = AsyncMock(return_value={"stdout": "Hi", "returncode": 137, "files": []}) + resp.json = AsyncMock( + return_value={"stdout": "Hi", "returncode": 137, "files": []} + ) context_manager = MagicMock() context_manager.__aenter__.return_value = resp @@ -50,9 +59,7 @@ async def test_post_job(self): "executable_path": f"/snekbin/python/{py_version}/bin/python", } self.bot.http_session.post.assert_called_with( - constants.URLs.snekbox_eval_api, - json=expected, - raise_for_status=True + constants.URLs.snekbox_eval_api, json=expected, raise_for_status=True ) resp.json.assert_awaited_once() @@ -70,20 +77,42 @@ async def test_codeblock_converter(self): cases = ( ('print("Hello world!")', 'print("Hello world!")', "non-formatted"), ('`print("Hello world!")`', 'print("Hello world!")', "one line code block"), - ('```\nprint("Hello world!")```', 'print("Hello world!")', "multiline code block"), - ('```py\nprint("Hello world!")```', 'print("Hello world!")', "multiline python code block"), - ('text```print("Hello world!")```text', 'print("Hello world!")', "code block surrounded by text"), - ('```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', - 'print("Hello world!")\nprint("Hello world!")', "two code blocks with text in-between"), - ('`print("Hello world!")`\ntext\n```print("How\'s it going?")```', - 'print("How\'s it going?")', "code block preceded by inline code"), - ('`print("Hello world!")`\ntext\n`print("Hello world!")`', - 'print("Hello world!")', "one inline code block of two") + ( + '```\nprint("Hello world!")```', + 'print("Hello world!")', + "multiline code block", + ), + ( + '```py\nprint("Hello world!")```', + 'print("Hello world!")', + "multiline python code block", + ), + ( + 'text```print("Hello world!")```text', + 'print("Hello world!")', + "code block surrounded by text", + ), + ( + '```print("Hello world!")```\ntext\n```py\nprint("Hello world!")```', + 'print("Hello world!")\nprint("Hello world!")', + "two code blocks with text in-between", + ), + ( + '`print("Hello world!")`\ntext\n```print("How\'s it going?")```', + 'print("How\'s it going?")', + "code block preceded by inline code", + ), + ( + '`print("Hello world!")`\ntext\n`print("Hello world!")`', + 'print("Hello world!")', + "one inline code block of two", + ), ) for case, expected, testname in cases: with self.subTest(msg=f"Extract code from {testname}."): self.assertEqual( - "\n".join(await snekbox.CodeblockConverter.convert(ctx, case)), expected + "\n".join(await snekbox.CodeblockConverter.convert(ctx, case)), + expected, ) def test_prepare_timeit_input(self): @@ -92,21 +121,35 @@ def test_prepare_timeit_input(self): cases = ( (['print("Hello World")'], "", "single block of code"), (["x = 1", "print(x)"], "x = 1", "two blocks of code"), - (["x = 1", "print(x)", 'print("Some other code.")'], "x = 1", "three blocks of code") + ( + ["x = 1", "print(x)", 'print("Some other code.")'], + "x = 1", + "three blocks of code", + ), ) for case, setup_code, test_name in cases: setup = snekbox._cog.TIMEIT_SETUP_WRAPPER.format(setup=setup_code) expected = [*base_args, setup, "\n".join(case[1:] if setup_code else case)] - with self.subTest(msg=f"Test with {test_name} and expected return {expected}"): + with self.subTest( + msg=f"Test with {test_name} and expected return {expected}" + ): self.assertEqual(self.cog.prepare_timeit_input(case), expected) def test_eval_result_message(self): """EvalResult.get_message(), should return message.""" cases = ( ("ERROR", None, ("Your 3.12 eval job has failed", "ERROR", "")), - ("", 128 + snekbox._eval.SIGKILL, ("Your 3.12 eval job timed out or ran out of memory", "", "")), - ("", 255, ("Your 3.12 eval job has failed", "A fatal NsJail error occurred", "")) + ( + "", + 128 + snekbox._eval.SIGKILL, + ("Your 3.12 eval job timed out or ran out of memory", "", ""), + ), + ( + "", + 255, + ("Your 3.12 eval job has failed", "A fatal NsJail error occurred", ""), + ), ) for stdout, returncode, expected in cases: exp_msg, exp_err, exp_files_err = expected @@ -125,21 +168,33 @@ def test_eval_result_message(self): def test_eval_result_files_error_message(self): """EvalResult.files_error_message, should return files error message.""" cases = [ - ([], ["abc"], ( - "1 file upload (abc) failed because its file size exceeds 8 MiB." - )), - ([], ["file1.bin", "f2.bin"], ( - "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." - )), - (["a", "b"], ["c"], ( - "1 file upload (c) failed as it exceeded the 2 file limit." - )), - (["a"], ["b", "c"], ( - "2 file uploads (b, c) failed as they exceeded the 2 file limit." - )), + ( + [], + ["abc"], + ("1 file upload (abc) failed because its file size exceeds 8 MiB."), + ), + ( + [], + ["file1.bin", "f2.bin"], + ( + "2 file uploads (file1.bin, f2.bin) failed because each file's size exceeds 8 MiB." + ), + ), + ( + ["a", "b"], + ["c"], + ("1 file upload (c) failed as it exceeded the 2 file limit."), + ), + ( + ["a"], + ["b", "c"], + ("2 file uploads (b, c) failed as they exceeded the 2 file limit."), + ), ] for files, failed_files, expected_msg in cases: - with self.subTest(files=files, failed_files=failed_files, expected_msg=expected_msg): + with self.subTest( + files=files, failed_files=failed_files, expected_msg=expected_msg + ): result = EvalResult("", 0, files, failed_files) msg = result.files_error_message self.assertIn(expected_msg, msg) @@ -168,7 +223,7 @@ def test_eval_result_message_invalid_signal(self, _mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( result.get_status_message(EvalJob([], version="3.10")), - "Your 3.10 eval job has completed with return code 127" + "Your 3.10 eval job has completed with return code 127", ) self.assertEqual(result.error_message, "") self.assertEqual(result.files_error_message, "") @@ -179,7 +234,7 @@ def test_eval_result_message_valid_signal(self, mock_signals: Mock): result = EvalResult(stdout="", returncode=127) self.assertEqual( result.get_status_message(EvalJob([], version="3.12")), - "Your 3.12 eval job has completed with return code 127 (SIGTEST)" + "Your 3.12 eval job has completed with return code 127 (SIGTEST)", ) def test_eval_result_status_emoji(self): @@ -187,7 +242,7 @@ def test_eval_result_status_emoji(self): cases = ( (" ", -1, ":warning:"), ("Hello world!", 0, ":white_check_mark:"), - ("Invalid beard size", -1, ":x:") + ("Invalid beard size", -1, ":x:"), ) for stdout, returncode, expected in cases: with self.subTest(stdout=stdout, returncode=returncode, expected=expected): @@ -204,8 +259,10 @@ async def test_format_output(self): ) too_long_too_many_lines = ( "\n".join( - f"{i:03d} | {line}" for i, line in enumerate(["verylongbeard" * 10] * 15, 1) - )[:1000] + "\n... (truncated - too long, too many lines)" + f"{i:03d} | {line}" + for i, line in enumerate(["verylongbeard" * 10] * 15, 1) + )[:1000] + + "\n... (truncated - too long, too many lines)" ) cases = ( @@ -215,29 +272,38 @@ async def test_format_output(self): ("