From 5b3b07c718b7ab12f770e61098b9900997b61630 Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:02:44 +0100 Subject: [PATCH 1/7] Update unittests --- tests/test_singleton_decorator.py | 491 ++++++++++++++++++++++++++++++ tests/test_sjob_manager.py | 322 ++++++++++++++++---- 2 files changed, 749 insertions(+), 64 deletions(-) create mode 100644 tests/test_singleton_decorator.py diff --git a/tests/test_singleton_decorator.py b/tests/test_singleton_decorator.py new file mode 100644 index 0000000..97bc241 --- /dev/null +++ b/tests/test_singleton_decorator.py @@ -0,0 +1,491 @@ +import unittest +from typing import Generic, TypeVar + +T = TypeVar("T") + +from lib.core_utils.singleton_decorator import SingletonMeta, singleton + + +class TestSingletonDecorator(unittest.TestCase): + + def test_singleton_basic(self): + @singleton + class MyClass: + pass + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + self.assertEqual(id(instance1), id(instance2)) + + def test_singleton_with_args(self): + @singleton + class MyClass: + def __init__(self, value): + self.value = value + + instance1 = MyClass(10) + instance2 = MyClass(20) + + self.assertIs(instance1, instance2) + self.assertEqual(instance1.value, 10) + self.assertEqual(instance2.value, 10) + + def test_singleton_different_classes(self): + @singleton + class ClassA: + pass + + @singleton + class ClassB: + pass + + instance_a1 = ClassA() + instance_a2 = ClassA() + instance_b1 = ClassB() + instance_b2 = ClassB() + + self.assertIs(instance_a1, instance_a2) + self.assertIs(instance_b1, instance_b2) + self.assertIsNot(instance_a1, instance_b1) + + def test_singleton_inheritance(self): + @singleton + class BaseClass: + pass + + class SubClass(BaseClass): + pass + + base_instance1 = BaseClass() + base_instance2 = BaseClass() + sub_instance1 = SubClass() + sub_instance2 = SubClass() + + self.assertIs(base_instance1, base_instance2) + self.assertIs(sub_instance1, sub_instance2) + self.assertIsNot(base_instance1, sub_instance1) + + def test_singleton_with_kwargs(self): + @singleton + class MyClass: + def __init__(self, **kwargs): + self.kwargs = kwargs + + instance1 = MyClass(a=1, b=2) + instance2 = MyClass(a=3, b=4) + + self.assertIs(instance1, instance2) + self.assertEqual(instance1.kwargs, {"a": 1, "b": 2}) + self.assertEqual(instance2.kwargs, {"a": 1, "b": 2}) + + def test_singleton_reset_instance(self): + @singleton + class MyClass: + pass + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + + # Reset the singleton instance + SingletonMeta._instances.pop(MyClass, None) + instance3 = MyClass() + + self.assertIsNot(instance1, instance3) + + def test_singleton_thread_safety(self): + # Note: The current singleton implementation is not thread-safe. + # This test demonstrates that, and in a real-world scenario, + # you should use threading locks to make it thread-safe. + + import threading + + @singleton + class MyClass: + def __init__(self, value): + self.value = value + + instances = [] + + def create_instance(value): + instances.append(MyClass(value)) + + threads = [] + for i in range(10): + thread = threading.Thread(target=create_instance, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Check that all instances are the same + for instance in instances: + self.assertIs(instance, instances[0]) + + def test_singleton_with_multiple_instances(self): + # Ensure that singleton instances are maintained separately for different classes + @singleton + class MyClassA: + def __init__(self, value): + self.value = value + + @singleton + class MyClassB: + def __init__(self, value): + self.value = value + + instance_a1 = MyClassA(1) + instance_a2 = MyClassA(2) + instance_b1 = MyClassB(3) + instance_b2 = MyClassB(4) + + self.assertIs(instance_a1, instance_a2) + self.assertIs(instance_b1, instance_b2) + self.assertIsNot(instance_a1, instance_b1) + self.assertEqual(instance_a1.value, 1) + self.assertEqual(instance_b1.value, 3) + + def test_singleton_with_classmethod(self): + @singleton + class MyClass: + @classmethod + def cls_method(cls): + return "cls_method called" + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + self.assertEqual(MyClass.cls_method(), "cls_method called") + + def test_singleton_with_staticmethod(self): + @singleton + class MyClass: + @staticmethod + def static_method(): + return "static_method called" + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + self.assertEqual(MyClass.static_method(), "static_method called") + + def test_singleton_with_property(self): + @singleton + class MyClass: + def __init__(self, value): + self._value = value + + @property + def value(self): + return self._value + + instance1 = MyClass(10) + instance2 = MyClass(20) + + self.assertIs(instance1, instance2) + self.assertEqual(instance1.value, 10) + + def test_singleton_decorator_without_parentheses(self): + # Ensure that the singleton decorator can be used without parentheses + @singleton + class MyClass: + pass + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + + def test_singleton_repr(self): + @singleton + class MyClass: + pass + + instance = MyClass() + self.assertEqual(repr(instance), repr(MyClass())) + + def test_singleton_str(self): + @singleton + class MyClass: + pass + + instance = MyClass() + self.assertEqual(str(instance), str(MyClass())) + + def test_singleton_isinstance(self): + @singleton + class MyClass: + pass + + instance = MyClass() + self.assertIsInstance(instance, MyClass) + + def test_singleton_pickle_not_supported(self): + """Test that pickling a singleton instance is not supported + and raises an exception. + """ + import pickle + + @singleton + class MyClass: + def __init__(self, value): + self.value = value + + instance = MyClass(10) + + with self.assertRaises((TypeError, AttributeError, pickle.PicklingError)): + pickle.dumps(instance) + + def test_singleton_subclassing_singleton(self): + @singleton + class BaseClass: + pass + + @singleton + class SubClass(BaseClass): + pass + + base_instance = BaseClass() + sub_instance = SubClass() + + self.assertIsNot(base_instance, sub_instance) + self.assertIsInstance(sub_instance, SubClass) + self.assertIsInstance(sub_instance, BaseClass) + + def test_singleton_metaclass_conflict(self): + """Test that applying the singleton decorator toa class + with a custom metaclass raises a TypeError. + """ + + class Meta(type): + pass + + with self.assertRaises(TypeError): + + @singleton + class MyClass(metaclass=Meta): + pass + + def test_singleton_with_decorated_class(self): + def decorator(cls): + cls.decorated = True + return cls + + @singleton + @decorator + class MyClass: + pass + + instance = MyClass() + self.assertTrue(hasattr(instance, "decorated")) + self.assertTrue(instance.decorated) + + def test_singleton_with_exceptions_in_init(self): + @singleton + class MyClass: + def __init__(self, value): + if value < 0: + raise ValueError("Negative value not allowed") + self.value = value + + with self.assertRaises(ValueError): + MyClass(-1) + + # Instance should not be created due to exception + self.assertFalse(MyClass in SingletonMeta._instances) + + # Creating with valid value + instance = MyClass(10) + self.assertEqual(instance.value, 10) + + def test_singleton_docstring_preserved(self): + @singleton + class MyClass: + """This is MyClass docstring.""" + + pass + + self.assertEqual(MyClass.__doc__, "This is MyClass docstring.") + + def test_singleton_name_preserved(self): + @singleton + class MyClass: + pass + + self.assertEqual(MyClass.__name__, "MyClass") + + def test_singleton_module_preserved(self): + @singleton + class MyClass: + pass + + self.assertEqual(MyClass.__module__, __name__) + + def test_singleton_annotations_preserved(self): + @singleton + class MyClass: + x: int + + def __init__(self, x: int): + self.x = x + + instance = MyClass(10) + self.assertEqual(instance.x, 10) + self.assertEqual(MyClass.__annotations__, {"x": int}) + + def test_singleton_with_slots(self): + @singleton + class MyClass: + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + instance1 = MyClass(10) + instance2 = MyClass(20) + + self.assertIs(instance1, instance2) + self.assertEqual(instance1.value, 10) + + def test_singleton_with_weakref(self): + import weakref + + @singleton + class MyClass: + pass + + instance = MyClass() + weak_instance = weakref.ref(instance) + self.assertIs(weak_instance(), instance) + + def test_singleton_with_del(self): + @singleton + class MyClass: + pass + + instance1 = MyClass() + del instance1 + + instance2 = MyClass() + self.assertIsNotNone(instance2) + + def test_singleton_reset_between_tests(self): + @singleton + class MyClass: + pass + + instance1 = MyClass() + instance2 = MyClass() + self.assertIs(instance1, instance2) + + # Reset the instance (for testing purposes) + SingletonMeta._instances.pop(MyClass, None) + + instance3 = MyClass() + self.assertIsNot(instance1, instance3) + + def test_singleton_no_args(self): + @singleton + class MyClass: + def __init__(self): + self.value = 42 + + instance = MyClass() + self.assertEqual(instance.value, 42) + + def test_singleton_calling_class_directly(self): + @singleton + class MyClass: + pass + + instance = MyClass() + # Since MyClass is a class, calling it directly is the correct way + direct_instance = MyClass() + + self.assertIs(instance, direct_instance) + + def test_singleton_calling_get_instance_directly(self): + @singleton + class MyClass: + pass + + # Access the get_instance function directly + get_instance = MyClass + instance1 = get_instance() + instance2 = get_instance() + + self.assertIs(instance1, instance2) + + def test_singleton_multiple_arguments(self): + @singleton + class MyClass: + def __init__(self, a, b, c): + self.total = a + b + c + + instance1 = MyClass(1, 2, 3) + instance2 = MyClass(4, 5, 6) + + self.assertIs(instance1, instance2) + self.assertEqual(instance1.total, 6) + + def test_singleton_class_variables(self): + @singleton + class MyClass: + count = 0 + + def __init__(self): + MyClass.count += 1 + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + self.assertEqual(MyClass.count, 1) + + def test_singleton_with_already_existing_instance(self): + @singleton + class MyClass: + pass + + # Manually add an instance to the instances dict + SingletonMeta._instances[MyClass] = "ExistingInstance" + + instance = MyClass() + self.assertEqual(instance, "ExistingInstance") + + def test_singleton_with_different_classes_same_name(self): + @singleton + class MyClass: # type: ignore + pass + + # Define another class with the same name + @singleton + class MyClass: # noqa: F811 + pass + + instance1 = MyClass() + instance2 = MyClass() + + self.assertIs(instance1, instance2) + + def test_singleton_with_type_var(self): + @singleton + class MyClass(Generic[T]): + def __init__(self, value: T): + self.value = value + + instance1 = MyClass(10) + instance2 = MyClass(20) + + self.assertIs(instance1, instance2) + self.assertEqual(instance1.value, 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sjob_manager.py b/tests/test_sjob_manager.py index 804b053..a51221c 100644 --- a/tests/test_sjob_manager.py +++ b/tests/test_sjob_manager.py @@ -1,90 +1,284 @@ import asyncio -import subprocess import unittest from unittest.mock import AsyncMock, MagicMock, patch from lib.module_utils.sjob_manager import SlurmJobManager -class MockSample: - def __init__(self, id): - self.id = id +class Sample: + """Mock sample object with id and status attributes.""" + + def __init__(self, sample_id): + self.id = sample_id + self.status = None def post_process(self): - pass # Add your mock implementation here if needed + pass # Mock method to simulate post-processing + +class TestSlurmJobManager(unittest.IsolatedAsyncioTestCase): -class TestSlurmJobManager(unittest.TestCase): def setUp(self): self.manager = SlurmJobManager() - self.sample = MagicMock() - self.sample.id = "sample1" - self.sample.post_process = AsyncMock() - - async def test_monitor_job(self): - with unittest.mock.patch.object( - self.manager, "_job_status", new_callable=AsyncMock - ) as mock_job_status, unittest.mock.patch.object( - self.manager, "check_status" - ) as mock_check_status: - - for status in ["COMPLETED", "FAILED", "CANCELLED"]: - mock_job_status.return_value = status - await self.manager.monitor_job("job1", self.sample) - mock_check_status.assert_called_with("job1", status, self.sample) - - @patch( - "lib.utils.sjob_manager.asyncio.create_subprocess_exec", new_callable=AsyncMock - ) - @patch("lib.utils.sjob_manager.asyncio.wait_for", new_callable=AsyncMock) - def test_submit_job(self, mock_wait_for, mock_create_subprocess_exec): - # Set up the mocks - mock_create_subprocess_exec.return_value.communicate.return_value = ( - b"1234", - b"", - ) - mock_create_subprocess_exec.return_value.returncode = 0 - mock_wait_for.return_value = (b"1234", b"") + self.script_path = "test_script.sh" + self.job_id = "12345" + self.sample = Sample("sample1") - # Call the submit_job method - job_id = asyncio.run(self.manager.submit_job("script.sh")) + @patch("lib.module_utils.sjob_manager.Path") + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_exec") + async def test_submit_job_success(self, mock_create_subprocess_exec, mock_path): + # Mock Path.is_file() to return True + mock_path.return_value.is_file.return_value = True - # Assert the mocks were called correctly - mock_create_subprocess_exec.assert_called_once_with( - "sbatch", "script.sh", stdout=subprocess.PIPE, stderr=subprocess.PIPE + # Mock the subprocess + process_mock = MagicMock() + process_mock.communicate = AsyncMock( + return_value=(b"Submitted batch job 12345\n", b"") ) - mock_wait_for.assert_called_once() - - # Assert the correct job ID was returned - self.assertEqual(job_id, "1234") - - @patch( - "lib.utils.sjob_manager.asyncio.create_subprocess_shell", new_callable=AsyncMock - ) - @patch("lib.utils.sjob_manager.asyncio.wait_for", new_callable=AsyncMock) - def test__job_status(self, mock_wait_for, mock_create_subprocess_shell): - # Set up the mocks - mock_create_subprocess_shell.return_value.communicate.return_value = ( - b"COMPLETED", - b"", + process_mock.returncode = 0 + mock_create_subprocess_exec.return_value = process_mock + + job_id = await self.manager.submit_job(self.script_path) + self.assertEqual(job_id, "12345") + + @patch("lib.module_utils.sjob_manager.Path") + async def test_submit_job_script_not_found(self, mock_path): + # Mock Path.is_file() to return False + mock_path.return_value.is_file.return_value = False + + job_id = await self.manager.submit_job(self.script_path) + self.assertIsNone(job_id) + + @patch("lib.module_utils.sjob_manager.Path") + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_exec") + async def test_submit_job_sbatch_error( + self, mock_create_subprocess_exec, mock_path + ): + # Mock Path.is_file() to return True + mock_path.return_value.is_file.return_value = True + + # Mock the subprocess to simulate sbatch error + process_mock = MagicMock() + process_mock.communicate = AsyncMock( + return_value=(b"", b"Error submitting job") ) - mock_create_subprocess_shell.return_value.returncode = 0 - mock_wait_for.return_value = (b"COMPLETED", b"") + process_mock.returncode = 1 + mock_create_subprocess_exec.return_value = process_mock + + job_id = await self.manager.submit_job(self.script_path) + self.assertIsNone(job_id) - # Call the _job_status method - status = asyncio.run(self.manager._job_status("1234")) + @patch("lib.module_utils.sjob_manager.Path") + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_exec") + async def test_submit_job_no_job_id(self, mock_create_subprocess_exec, mock_path): + # Mock Path.is_file() to return True + mock_path.return_value.is_file.return_value = True - # Assert the mocks were called correctly - mock_create_subprocess_shell.assert_called_once_with( - "sacct -n -X -o State -j 1234", - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + # Mock the subprocess to return output without job ID + process_mock = MagicMock() + process_mock.communicate = AsyncMock( + return_value=(b"Submission output without job ID", b"") ) - mock_wait_for.assert_called_once() + process_mock.returncode = 0 + mock_create_subprocess_exec.return_value = process_mock + + job_id = await self.manager.submit_job(self.script_path) + self.assertIsNone(job_id) + + @patch("lib.module_utils.sjob_manager.Path") + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_exec") + async def test_submit_job_timeout(self, mock_create_subprocess_exec, mock_path): + # Mock Path.is_file() to return True + mock_path.return_value.is_file.return_value = True + + # Mock the subprocess to simulate a timeout + async def mock_communicate(): + await asyncio.sleep(0.1) + raise asyncio.TimeoutError() + + # Mock the subprocess to simulate a timeout + mock_create_subprocess_exec.side_effect = asyncio.TimeoutError - # Assert the correct status was returned + job_id = await self.manager.submit_job(self.script_path) + self.assertIsNone(job_id) + + @patch("lib.module_utils.sjob_manager.asyncio.sleep", new_callable=AsyncMock) + @patch("lib.module_utils.sjob_manager.SlurmJobManager._job_status") + async def test_monitor_job_completed(self, mock_job_status, mock_sleep): + # Mock _job_status to return 'COMPLETED' after a few calls + mock_job_status.side_effect = ["PENDING", "RUNNING", "COMPLETED"] + + await self.manager.monitor_job(self.job_id, self.sample) + self.assertEqual(self.sample.status, "processed") + + @patch("lib.module_utils.sjob_manager.asyncio.sleep", new_callable=AsyncMock) + @patch("lib.module_utils.sjob_manager.SlurmJobManager._job_status") + async def test_monitor_job_failed(self, mock_job_status, mock_sleep): + # Mock _job_status to return 'FAILED' + mock_job_status.return_value = "FAILED" + + await self.manager.monitor_job(self.job_id, self.sample) + self.assertEqual(self.sample.status, "processing_failed") + + @patch("lib.module_utils.sjob_manager.asyncio.sleep", new_callable=AsyncMock) + @patch("lib.module_utils.sjob_manager.SlurmJobManager._job_status") + async def test_monitor_job_unexpected_status(self, mock_job_status, mock_sleep): + # Mock _job_status to return 'UNKNOWN_STATUS' a few times, then 'COMPLETED' + mock_job_status.side_effect = ["UNKNOWN_STATUS"] * 3 + ["COMPLETED"] + + await self.manager.monitor_job(self.job_id, self.sample) + self.assertEqual(self.sample.status, "processed") + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_success(self, mock_create_subprocess_shell): + # Mock the subprocess to return a valid status + process_mock = MagicMock() + process_mock.communicate = AsyncMock(return_value=(b"COMPLETED", b"")) + mock_create_subprocess_shell.return_value = process_mock + + status = await self.manager._job_status(self.job_id) self.assertEqual(status, "COMPLETED") + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_error(self, mock_create_subprocess_shell): + # Mock the subprocess to return stderr + process_mock = MagicMock() + process_mock.communicate = AsyncMock(return_value=(b"", b"sacct error")) + mock_create_subprocess_shell.return_value = process_mock + + status = await self.manager._job_status(self.job_id) + self.assertIsNone(status) + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_timeout(self, mock_create_subprocess_shell): + # Mock the subprocess to simulate a timeout + mock_create_subprocess_shell.side_effect = asyncio.TimeoutError + + status = await self.manager._job_status(self.job_id) + self.assertIsNone(status) + + def test_check_status_completed(self): + # Test check_status with 'COMPLETED' status + self.manager.check_status(self.job_id, "COMPLETED", self.sample) + self.assertEqual(self.sample.status, "processed") + + def test_check_status_failed(self): + # Test check_status with 'FAILED' status + self.manager.check_status(self.job_id, "FAILED", self.sample) + self.assertEqual(self.sample.status, "processing_failed") + + def test_check_status_unexpected(self): + # Test check_status with an unexpected status + self.manager.check_status(self.job_id, "UNKNOWN_STATUS", self.sample) + self.assertEqual(self.sample.status, "processing_failed") + + @patch("lib.module_utils.sjob_manager.custom_logger") + def test_init_with_configs(self, mock_custom_logger): + # Mock configs to return custom polling interval + with patch( + "lib.module_utils.sjob_manager.configs", {"job_monitor_poll_interval": 5.0} + ): + manager = SlurmJobManager() + self.assertEqual(manager.polling_interval, 5.0) + + @patch("lib.module_utils.sjob_manager.custom_logger") + def test_init_with_default_configs(self, mock_custom_logger): + # Mock configs to be empty + with patch("lib.module_utils.sjob_manager.configs", {}): + manager = SlurmJobManager() + self.assertEqual(manager.polling_interval, 10.0) + + @patch("lib.module_utils.sjob_manager.Path") + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_exec") + async def test_submit_job_exception(self, mock_create_subprocess_exec, mock_path): + # Mock Path.is_file() to return True + mock_path.return_value.is_file.return_value = True + + # Simulate an exception during subprocess creation + mock_create_subprocess_exec.side_effect = Exception("Unexpected error") + + job_id = await self.manager.submit_job(self.script_path) + self.assertIsNone(job_id) + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_exception(self, mock_create_subprocess_shell): + # Simulate an exception during subprocess creation + mock_create_subprocess_shell.side_effect = Exception("Unexpected error") + + status = await self.manager._job_status(self.job_id) + self.assertIsNone(status) + + @patch("lib.module_utils.sjob_manager.SlurmJobManager._job_status") + async def test_monitor_job_no_status(self, mock_job_status): + # Mock _job_status to return None + mock_job_status.return_value = None + + # We need to prevent an infinite loop; we'll let it run only once + with patch( + "lib.module_utils.sjob_manager.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + mock_sleep.side_effect = asyncio.CancelledError + + with self.assertRaises(asyncio.CancelledError): + await self.manager.monitor_job(self.job_id, self.sample) + + def test_check_status_calls_post_process(self): + # Mock the sample's post_process method + self.sample.post_process = MagicMock() + + self.manager.check_status(self.job_id, "COMPLETED", self.sample) + self.sample.post_process.assert_called_once() + + def test_check_status_does_not_call_post_process(self): + # Mock the sample's post_process method + self.sample.post_process = MagicMock() + + self.manager.check_status(self.job_id, "FAILED", self.sample) + self.sample.post_process.assert_not_called() + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_with_multiple_lines(self, mock_create_subprocess_shell): + # Mock sacct output with multiple lines + process_mock = MagicMock() + process_mock.communicate = AsyncMock( + return_value=(b"COMPLETED\nCOMPLETED", b"") + ) + mock_create_subprocess_shell.return_value = process_mock + + status = await self.manager._job_status(self.job_id) + self.assertEqual(status, "COMPLETED\nCOMPLETED") + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_empty_output(self, mock_create_subprocess_shell): + # Mock sacct output with empty stdout and stderr + process_mock = MagicMock() + process_mock.communicate = AsyncMock(return_value=(b"", b"")) + mock_create_subprocess_shell.return_value = process_mock + + status = await self.manager._job_status(self.job_id) + self.assertIsNone(status) + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_shell") + async def test_job_status_decode_error(self, mock_create_subprocess_shell): + # Mock sacct output with bytes that cannot be decoded + process_mock = MagicMock() + process_mock.communicate = AsyncMock(return_value=(b"\xff\xfe", b"")) + mock_create_subprocess_shell.return_value = process_mock + + status = await self.manager._job_status(self.job_id) + self.assertIsNone(status) + + @patch("lib.module_utils.sjob_manager.asyncio.create_subprocess_exec") + async def test_submit_job_decode_error(self, mock_create_subprocess_exec): + # Mock sbatch output with bytes that cannot be decoded + process_mock = MagicMock() + process_mock.communicate = AsyncMock(return_value=(b"\xff\xfe", b"")) + process_mock.returncode = 0 + mock_create_subprocess_exec.return_value = process_mock + + job_id = await self.manager.submit_job(self.script_path) + self.assertIsNone(job_id) + if __name__ == "__main__": unittest.main() From 2ff72bceac142d6597e072a804d2b04c5a2c20e9 Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:03:07 +0100 Subject: [PATCH 2/7] Update unittests --- tests/test_common.py | 222 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 177 insertions(+), 45 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index 17debbd..d387fb8 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,88 +1,220 @@ +import os import unittest from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, mock_open, patch from lib.core_utils.common import YggdrasilUtilities class TestYggdrasilUtilities(unittest.TestCase): - @patch("lib.utils.common.importlib.import_module") + def setUp(self): + # Backup original values + self.original_module_cache = YggdrasilUtilities.module_cache.copy() + self.original_config_dir = YggdrasilUtilities.CONFIG_DIR + + # Reset module cache + YggdrasilUtilities.module_cache = {} + + # Use a temporary config directory + self.temp_config_dir = Path("/tmp/yggdrasil_test_config") + self.temp_config_dir.mkdir(parents=True, exist_ok=True) + YggdrasilUtilities.CONFIG_DIR = self.temp_config_dir + + def tearDown(self): + # Restore original values + YggdrasilUtilities.module_cache = self.original_module_cache + YggdrasilUtilities.CONFIG_DIR = self.original_config_dir + + # Clean up temporary config directory + for item in self.temp_config_dir.glob("*"): + item.unlink() + self.temp_config_dir.rmdir() + + @patch("importlib.import_module") def test_load_realm_class_success(self, mock_import_module): - # Mock successful class loading + # Mock module and class mock_module = MagicMock() mock_class = MagicMock() + setattr(mock_module, "MockClass", mock_class) mock_import_module.return_value = mock_module - mock_module.MyClass = mock_class - result = YggdrasilUtilities.load_realm_class("my_module.MyClass") + module_path = "some.module.MockClass" + result = YggdrasilUtilities.load_realm_class(module_path) - mock_import_module.assert_called_once_with("my_module") self.assertEqual(result, mock_class) + self.assertIn(module_path, YggdrasilUtilities.module_cache) + mock_import_module.assert_called_with("some.module") + + @patch("importlib.import_module") + def test_load_realm_class_module_not_found(self, mock_import_module): + # Simulate ImportError + mock_import_module.side_effect = ImportError("Module not found") + + module_path = "nonexistent.module.ClassName" + result = YggdrasilUtilities.load_realm_class(module_path) + + self.assertIsNone(result) + mock_import_module.assert_called_with("nonexistent.module") - @patch("lib.utils.common.importlib.import_module") - def test_load_realm_class_failure(self, mock_import_module): - # Mock import error - mock_import_module.side_effect = ImportError() + @patch("importlib.import_module") + def test_load_realm_class_attribute_error(self, mock_import_module): + # Module exists but class does not + mock_module = MagicMock() + mock_import_module.return_value = mock_module - result = YggdrasilUtilities.load_realm_class("non_existent_module.MyClass") + module_path = "some.module.MissingClass" + result = YggdrasilUtilities.load_realm_class(module_path) - mock_import_module.assert_called_once_with("non_existent_module") self.assertIsNone(result) + mock_import_module.assert_called_with("some.module") - @patch("lib.utils.common.importlib.import_module") + @patch("importlib.import_module") def test_load_module_success(self, mock_import_module): - # Mock successful module loading + # Mock module mock_module = MagicMock() mock_import_module.return_value = mock_module - result = YggdrasilUtilities.load_module("my_module") + module_path = "some.module" + result = YggdrasilUtilities.load_module(module_path) - mock_import_module.assert_called_once_with("my_module") self.assertEqual(result, mock_module) + self.assertIn(module_path, YggdrasilUtilities.module_cache) + mock_import_module.assert_called_with("some.module") - @patch("lib.utils.common.importlib.import_module") - def test_load_module_failure(self, mock_import_module): - # Mock import error - mock_import_module.side_effect = ImportError() + @patch("importlib.import_module") + def test_load_module_import_error(self, mock_import_module): + # Simulate ImportError + mock_import_module.side_effect = ImportError("Module not found") - result = YggdrasilUtilities.load_module("non_existent_module") + module_path = "nonexistent.module" + result = YggdrasilUtilities.load_module(module_path) - mock_import_module.assert_called_once_with("non_existent_module") self.assertIsNone(result) + mock_import_module.assert_called_with("nonexistent.module") - @patch("lib.utils.common.Path.exists") - def test_get_path_file_exists(self, mock_exists): - # Your input needed: Adjust the file path according to your project structure - mock_exists.return_value = True - expected_path = Path( - "/home/anastasios/Documents/git/Yggdrasil/yggdrasil_workspace/common/configurations/config.json" - ) # Replace with actual expected path - - result = YggdrasilUtilities.get_path("config.json") + def test_get_path_file_exists(self): + # Create a dummy config file + file_name = "config.yaml" + test_file = self.temp_config_dir / file_name + test_file.touch() - self.assertIsNotNone(result) - self.assertEqual(result, expected_path) + result = YggdrasilUtilities.get_path(file_name) - @patch("lib.utils.common.Path.exists") - def test_get_path_file_not_exists(self, mock_exists): - mock_exists.return_value = False + self.assertEqual(result, test_file) - result = YggdrasilUtilities.get_path("config.json") + def test_get_path_file_not_exists(self): + file_name = "missing_config.yaml" + result = YggdrasilUtilities.get_path(file_name) self.assertIsNone(result) - @patch.dict("lib.utils.common.os.environ", {"MY_VAR": "value"}) def test_env_variable_exists(self): - result = YggdrasilUtilities.env_variable("MY_VAR") - self.assertEqual(result, "value") + with patch.dict(os.environ, {"TEST_ENV_VAR": "test_value"}): + result = YggdrasilUtilities.env_variable("TEST_ENV_VAR") + self.assertEqual(result, "test_value") + + def test_env_variable_not_exists_with_default(self): + result = YggdrasilUtilities.env_variable( + "NONEXISTENT_ENV_VAR", default="default_value" + ) + self.assertEqual(result, "default_value") + + def test_env_variable_not_exists_no_default(self): + result = YggdrasilUtilities.env_variable("NONEXISTENT_ENV_VAR") + self.assertIsNone(result) + + @patch("builtins.open", new_callable=mock_open, read_data="123") + def test_get_last_processed_seq_file_exists(self, mock_file): + seq_file = self.temp_config_dir / ".last_processed_seq" + seq_file.touch() + + with patch.object(YggdrasilUtilities, "get_path", return_value=seq_file): + result = YggdrasilUtilities.get_last_processed_seq() + + self.assertEqual(result, "123") + + def test_get_last_processed_seq_file_not_exists(self): + with patch.object(YggdrasilUtilities, "get_path", return_value=None): + result = YggdrasilUtilities.get_last_processed_seq() + self.assertEqual(result, "0") # Default value as per method + + @patch("builtins.open", new_callable=mock_open) + def test_save_last_processed_seq_success(self, mock_file): + seq_file = self.temp_config_dir / ".last_processed_seq" + + with patch.object(YggdrasilUtilities, "get_path", return_value=seq_file): + YggdrasilUtilities.save_last_processed_seq("456") + + mock_file.assert_called_with(seq_file, "w") + mock_file().write.assert_called_with("456") + + def test_save_last_processed_seq_no_seq_file(self): + with patch.object(YggdrasilUtilities, "get_path", return_value=None): + # Should handle gracefully + YggdrasilUtilities.save_last_processed_seq("789") + + def test_module_cache_persistence(self): + # Mock module + mock_module = MagicMock() + with patch("importlib.import_module", return_value=mock_module) as mock_import: + module_path = "some.module" + + # First call + result1 = YggdrasilUtilities.load_module(module_path) + # Second call should use cache + result2 = YggdrasilUtilities.load_module(module_path) + + self.assertEqual(result1, result2) + mock_import.assert_called_once_with("some.module") + + @patch("builtins.open", new_callable=mock_open, read_data="") + def test_get_last_processed_seq_empty_file(self, mock_file): + seq_file = self.temp_config_dir / ".last_processed_seq" + seq_file.touch() + + with patch.object(YggdrasilUtilities, "get_path", return_value=seq_file): + result = YggdrasilUtilities.get_last_processed_seq() + + self.assertEqual(result, "0") # Assumes default when file is empty + + @patch("builtins.open", new_callable=mock_open, read_data="abc") + def test_get_last_processed_seq_invalid_content(self, mock_file): + seq_file = self.temp_config_dir / ".last_processed_seq" + seq_file.touch() + + with patch.object(YggdrasilUtilities, "get_path", return_value=seq_file): + result = YggdrasilUtilities.get_last_processed_seq() + + self.assertEqual(result, "abc") # Returns content as-is + + @patch("builtins.open", side_effect=Exception("File error")) + def test_get_last_processed_seq_file_error(self, mock_file): + seq_file = self.temp_config_dir / ".last_processed_seq" + + with patch.object(YggdrasilUtilities, "get_path", return_value=seq_file): + result = YggdrasilUtilities.get_last_processed_seq() + self.assertEqual(result, "0") # Should handle exception and return default + + @patch("builtins.open", side_effect=Exception("File error")) + def test_save_last_processed_seq_file_error(self, mock_file): + seq_file = self.temp_config_dir / ".last_processed_seq" + + with patch.object(YggdrasilUtilities, "get_path", return_value=seq_file): + # Should handle exception gracefully + YggdrasilUtilities.save_last_processed_seq("123") - @patch.dict("lib.utils.common.os.environ", {}, clear=True) - def test_env_variable_not_exists(self): - result = YggdrasilUtilities.env_variable("MY_VAR", default="default") - self.assertEqual(result, "default") + def test_get_path_with_relative_file_name(self): + # Use relative path components in file name + file_name = "../outside_config.yaml" + result = YggdrasilUtilities.get_path(file_name) + self.assertIsNone(result) # Should not allow navigating outside config dir - # TODO: Additional test cases or scenarios + def test_get_path_with_absolute_file_name(self): + # Use absolute path + file_name = "/etc/passwd" + result = YggdrasilUtilities.get_path(file_name) + self.assertIsNone(result) # Should not allow absolute paths if __name__ == "__main__": From fd8fd2f5fcc74f89b751fa3124f22a0b0c1d42bd Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:03:33 +0100 Subject: [PATCH 3/7] Update unittests --- tests/test_config_loader.py | 193 ++++++++++++++++++++++++++++-------- 1 file changed, 149 insertions(+), 44 deletions(-) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index e6ae437..7ce6cf1 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -1,3 +1,4 @@ +import json import types import unittest from pathlib import Path @@ -9,52 +10,156 @@ class TestConfigLoader(unittest.TestCase): def setUp(self): + # Create a ConfigLoader instance for testing self.config_loader = ConfigLoader() + self.mock_config_data = {"key1": "value1", "key2": "value2"} + self.mock_config_json = json.dumps(self.mock_config_data) - @patch("lib.utils.config_loader.Ygg.get_path", return_value=Path("dummy_file_path")) - @patch("builtins.open", new_callable=mock_open, read_data='{"key": "value"}') - @patch("json.load", return_value={"key": "value"}) - def test_load_config(self, mock_json_load, mock_file, mock_get_path): - # Create an instance of the ConfigLoader class - config_loader = ConfigLoader() - # Call the load_config method - config = config_loader.load_config("dummy_file_name") - # Assert the Ygg.get_path function was called - mock_get_path.assert_called_once_with("dummy_file_name") - # Assert the file was opened - mock_file.assert_called_once_with(Path("dummy_file_path"), "r") - # Assert the json.load function was called - mock_json_load.assert_called_once() - # Assert the config was loaded correctly - self.assertEqual(config, types.MappingProxyType({"key": "value"})) - - @patch("builtins.open", new_callable=mock_open, read_data='{"key": "value"}') - @patch("json.load", return_value={"key": "value"}) - def test_load_config_path(self, mock_json_load, mock_file): - # Call the load_config_path method - config = self.config_loader.load_config_path("/path/to/config.json") - # Assert the file was opened - mock_file.assert_called_once_with(Path("/path/to/config.json"), "r") - # Assert the json.load function was called - mock_json_load.assert_called_once() - # Assert the config was loaded correctly - self.assertEqual(config, types.MappingProxyType({"key": "value"})) - - def test_getitem2(self): - # Load some configuration data - self.config_loader._config = types.MappingProxyType({"key": "value"}) - # Call the __getitem__ method - value = self.config_loader["key"] - # Assert the correct value was returned - self.assertEqual(value, "value") - - def test_getitem(self): - # Mock the _config dictionary - self.config_loader._config = types.MappingProxyType({"key": "value"}) - # Call the __getitem__ method - value = self.config_loader["key"] - # Assert the correct value was returned - self.assertEqual(value, "value") + def test_init(self): + # Test that _config is initialized to None + self.assertIsNone(self.config_loader._config) + + def test_load_config_success(self): + # Test loading config from a file name using Ygg.get_path + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path, patch( + "builtins.open", mock_open(read_data=self.mock_config_json) + ): + mock_get_path.return_value = Path("/path/to/config.json") + config = self.config_loader.load_config("config.json") + self.assertEqual(config, types.MappingProxyType(self.mock_config_data)) + self.assertEqual( + self.config_loader._config, + types.MappingProxyType(self.mock_config_data), + ) + + def test_load_config_path_success(self): + # Test loading config from a full path + with patch("builtins.open", mock_open(read_data=self.mock_config_json)): + config = self.config_loader.load_config_path("/path/to/config.json") + self.assertEqual(config, types.MappingProxyType(self.mock_config_data)) + self.assertEqual( + self.config_loader._config, + types.MappingProxyType(self.mock_config_data), + ) + + def test_load_config_file_not_found(self): + # Test behavior when config file is not found + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path: + mock_get_path.return_value = None + config = self.config_loader.load_config("nonexistent.json") + self.assertEqual(config, types.MappingProxyType({})) + self.assertEqual(self.config_loader._config, types.MappingProxyType({})) + + def test_load_config_path_file_not_found(self): + # Test behavior when config file path is invalid + with patch("pathlib.Path.open", side_effect=FileNotFoundError()): + with self.assertRaises(FileNotFoundError): + self.config_loader.load_config_path("/invalid/path/config.json") + + def test_load_config_invalid_json(self): + # Test behavior when config file contains invalid JSON + invalid_json = "{key1: value1" # Missing quotes and closing brace + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path, patch( + "builtins.open", mock_open(read_data=invalid_json) + ): + mock_get_path.return_value = Path("/path/to/config.json") + with self.assertRaises(json.JSONDecodeError): + self.config_loader.load_config("config.json") + + def test_load_config_empty_file(self): + # Test behavior when config file is empty + empty_json = "" + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path, patch( + "builtins.open", mock_open(read_data=empty_json) + ): + mock_get_path.return_value = Path("/path/to/config.json") + with self.assertRaises(json.JSONDecodeError): + self.config_loader.load_config("config.json") + + def test_getitem_existing_key(self): + # Test __getitem__ with an existing key + self.config_loader._config = types.MappingProxyType(self.mock_config_data) + self.assertEqual(self.config_loader["key1"], "value1") + + def test_getitem_nonexistent_key(self): + # Test __getitem__ with a nonexistent key + self.config_loader._config = types.MappingProxyType(self.mock_config_data) + self.assertIsNone(self.config_loader["nonexistent_key"]) + + def test_getitem_no_config_loaded(self): + # Test __getitem__ when no config has been loaded + self.config_loader._config = None + self.assertIsNone(self.config_loader["key1"]) + + def test_config_immutable(self): + # Test that the configuration data is immutable + self.config_loader._config = types.MappingProxyType(self.mock_config_data) + with self.assertRaises(TypeError): + original_dict = self.mock_config_data + with self.assertRaises(TypeError): + original_dict["key1"] = "new_value" + + def test_load_config_type_error(self): + # Test handling of TypeError during json.load + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path, patch( + "builtins.open", mock_open(read_data=self.mock_config_json) + ), patch("json.load", side_effect=TypeError("Type error")): + mock_get_path.return_value = Path("/path/to/config.json") + with self.assertRaises(TypeError): + self.config_loader.load_config("config.json") + + def test_load_config_unexpected_exception(self): + # Test handling of an unexpected exception during file loading + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path, patch( + "builtins.open", side_effect=Exception("Unexpected error") + ): + mock_get_path.return_value = Path("/path/to/config.json") + with self.assertRaises(Exception) as context: + self.config_loader.load_config("config.json") + self.assertEqual(str(context.exception), "Unexpected error") + + def test_load_config_path_unexpected_exception(self): + # Test handling of an unexpected exception during file loading with load_config_path + with patch("builtins.open", side_effect=Exception("Unexpected error")): + with self.assertRaises(Exception) as context: + self.config_loader.load_config_path("/path/to/config.json") + self.assertEqual(str(context.exception), "Unexpected error") + + def test_config_manager_instance(self): + # Test that config_manager is an instance of ConfigLoader + from lib.core_utils.config_loader import config_manager + + self.assertIsInstance(config_manager, ConfigLoader) + + def test_configs_loaded(self): + # Test that configs are loaded when the module is imported + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path, patch( + "builtins.open", mock_open(read_data=self.mock_config_json) + ): + mock_get_path.return_value = Path("/path/to/config.json") + # Reload the module to trigger the code at the module level + import sys + + if "config_loader" in sys.modules: + del sys.modules["config_loader"] + from lib.core_utils import config_loader + + self.assertEqual( + config_loader.configs, types.MappingProxyType(self.mock_config_data) + ) + + def test_load_config_with_directory_traversal(self): + # Test that directory traversal in file_name is handled safely + with patch("lib.core_utils.config_loader.Ygg.get_path") as mock_get_path: + mock_get_path.return_value = Path("/path/to/../../etc/passwd") + with self.assertRaises(Exception): + self.config_loader.load_config("../../../etc/passwd") + + def test_load_config_path_with_invalid_path(self): + # Test that invalid paths in load_config_path are handled + with patch("pathlib.Path.open", side_effect=FileNotFoundError()): + with self.assertRaises(FileNotFoundError): + self.config_loader.load_config_path("/invalid/path/../../etc/passwd") if __name__ == "__main__": From e0191fc79b2e9ce86440246fb5889367f0f16f03 Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:03:47 +0100 Subject: [PATCH 4/7] Update unittests --- tests/test_logging_utils.py | 454 +++++++++++++++++++++++++++++++++--- 1 file changed, 427 insertions(+), 27 deletions(-) diff --git a/tests/test_logging_utils.py b/tests/test_logging_utils.py index c1c6228..4c96ffd 100644 --- a/tests/test_logging_utils.py +++ b/tests/test_logging_utils.py @@ -1,42 +1,442 @@ +import logging +import os import unittest +from datetime import datetime +from pathlib import Path from unittest.mock import MagicMock, patch -from lib.core_utils.logging_utils import configure_logging +from lib.core_utils.logging_utils import configure_logging, custom_logger -class TestConfigureLogging(unittest.TestCase): - @patch("lib.utils.logging_utils.logging") - @patch("lib.utils.logging_utils.datetime") - @patch("lib.utils.logging_utils.Path") - def test_configure_logging(self, mock_path, mock_datetime, mock_logging): - # Set up the mocks - mock_path.return_value.mkdir.return_value = None - mock_datetime.now.return_value.strftime.return_value = "2022-01-01_00.00.00" - mock_logging.DEBUG = 10 - mock_logging.INFO = 20 - mock_logging.getLogger.return_value = MagicMock() +class TestLoggingUtils(unittest.TestCase): - # Call the function + def setUp(self): + # Backup original logging handlers and level + self.original_handlers = logging.getLogger().handlers.copy() + self.original_level = logging.getLogger().level + + # Clear existing handlers + logging.getLogger().handlers = [] + + # Mock configs + self.mock_configs = {"yggdrasil_log_dir": "/tmp/yggdrasil_logs"} + self.patcher_configs = patch( + "lib.core_utils.logging_utils.configs", self.mock_configs + ) + self.patcher_configs.start() + + # Mock datetime.datetime to control the timestamp + self.patcher_datetime = patch("lib.core_utils.logging_utils.datetime") + self.mock_datetime = self.patcher_datetime.start() + + # Create a mock datetime instance + mock_now = MagicMock() + mock_now.strftime.return_value = "2021-01-01_12.00.00" + # Set datetime.now() to return our mock datetime instance + self.mock_datetime.now.return_value = mock_now + + # Mock Path.mkdir to prevent actual directory creation + self.patcher_mkdir = patch("pathlib.Path.mkdir") + self.mock_mkdir = self.patcher_mkdir.start() + + # Mock logging.basicConfig to track calls + self.patcher_basicConfig = patch("logging.basicConfig") + self.mock_basicConfig = self.patcher_basicConfig.start() + + # Mock logging.FileHandler to prevent it from actually opening a file + self.patcher_filehandler = patch( + "lib.core_utils.logging_utils.logging.FileHandler", MagicMock() + ) + self.mock_filehandler = self.patcher_filehandler.start() + + def tearDown(self): + # Restore original logging handlers and level + logging.getLogger().handlers = self.original_handlers + logging.getLogger().level = self.original_level + + # Stop all patches + self.patcher_configs.stop() + self.patcher_datetime.stop() + self.patcher_mkdir.stop() + self.patcher_basicConfig.stop() + self.patcher_filehandler.stop() + + def test_configure_logging_default(self): + # Test configure_logging with default parameters (debug=False) + configure_logging() + + expected_log_dir = Path(self.mock_configs["yggdrasil_log_dir"]) + expected_log_file = expected_log_dir / "yggdrasil_2021-01-01_12.00.00.log" + expected_log_level = logging.INFO + expected_log_format = "%(asctime)s [%(name)s][%(levelname)s] %(message)s" + + self.mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + handlers = [logging.FileHandler(expected_log_file)] + self.mock_basicConfig.assert_called_once_with( + level=expected_log_level, format=expected_log_format, handlers=handlers + ) + + def test_configure_logging_debug_true(self): + # Test configure_logging with debug=True configure_logging(debug=True) - # Assert the mocks were called correctly - mock_path.assert_called_once_with("yggdrasil_workspace/logs") - mock_path.return_value.mkdir.assert_called_once_with( - parents=True, exist_ok=True + expected_log_dir = Path(self.mock_configs["yggdrasil_log_dir"]) + expected_log_file = expected_log_dir / "yggdrasil_2021-01-01_12.00.00.log" + expected_log_level = logging.DEBUG + expected_log_format = "%(asctime)s [%(name)s][%(levelname)s] %(message)s" + + self.mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + handlers = [logging.FileHandler(expected_log_file), logging.StreamHandler()] + self.mock_basicConfig.assert_called_once_with( + level=expected_log_level, format=expected_log_format, handlers=handlers + ) + + def test_configure_logging_creates_log_directory(self): + # Ensure that configure_logging attempts to create the log directory + configure_logging() + + expected_log_dir = Path(self.mock_configs["yggdrasil_log_dir"]) + self.mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + self.assertEqual(self.mock_mkdir.call_args[0], ()) + self.assertEqual( + self.mock_mkdir.call_args[1], {"parents": True, "exist_ok": True} ) - mock_datetime.now.assert_called_once() - mock_datetime.now.return_value.strftime.assert_called_once_with( - "%Y-%m-%d_%H.%M.%S" + + def test_configure_logging_handles_existing_directory(self): + # Test that no exception is raised if the directory already exists + self.mock_mkdir.side_effect = FileExistsError + + configure_logging() + + # The test passes if no exception is raised + + def test_configure_logging_invalid_log_dir(self): + # Test handling when the log directory is invalid + self.mock_mkdir.side_effect = PermissionError("Permission denied") + + with self.assertRaises(PermissionError): + configure_logging() + + def test_configure_logging_logs_to_correct_file(self): + # Mock logging.FileHandler to prevent file creation + with patch("logging.FileHandler") as mock_file_handler: + configure_logging() + + expected_log_dir = Path(self.mock_configs["yggdrasil_log_dir"]) + expected_log_file = expected_log_dir / "yggdrasil_2021-01-01_12.00.00.log" + + mock_file_handler.assert_called_once_with(expected_log_file) + + def test_custom_logger_returns_logger(self): + # Test that custom_logger returns a Logger instance with the correct name + module_name = "test_module" + logger = custom_logger(module_name) + + self.assertIsInstance(logger, logging.Logger) + self.assertEqual(logger.name, module_name) + + def test_custom_logger_same_logger(self): + # Test that calling custom_logger multiple times with the same name returns the same logger + module_name = "test_module" + logger1 = custom_logger(module_name) + logger2 = custom_logger(module_name) + + self.assertIs(logger1, logger2) + + def test_logging_levels_suppressed(self): + # Test that logging levels for specified noisy libraries are set to WARNING + noisy_libraries = ["matplotlib", "numba", "h5py", "PIL"] + for lib in noisy_libraries: + logger = logging.getLogger(lib) + self.assertEqual(logger.level, logging.WARNING) + + def test_logging_configuration_reset_between_tests(self): + # Ensure that logging configuration does not leak between tests + configure_logging() + initial_handlers = logging.getLogger().handlers.copy() + initial_level = logging.getLogger().level + + # Simulate another logging configuration + configure_logging(debug=True) + new_handlers = logging.getLogger().handlers.copy() + new_level = logging.getLogger().level + + # Handlers and level should be updated + self.assertNotEqual(initial_handlers, new_handlers) + self.assertNotEqual(initial_level, new_level) + + def test_configure_logging_multiple_calls(self): + # Test that multiple calls to configure_logging update the logging configuration + configure_logging() + first_call_handlers = logging.getLogger().handlers.copy() + first_call_level = logging.getLogger().level + + configure_logging(debug=True) + second_call_handlers = logging.getLogger().handlers.copy() + second_call_level = logging.getLogger().level + + self.assertNotEqual(first_call_handlers, second_call_handlers) + self.assertNotEqual(first_call_level, second_call_level) + + def test_configure_logging_no_configs(self): + # Test behavior when configs do not contain 'yggdrasil_log_dir' + self.mock_configs.pop("yggdrasil_log_dir", None) + + with self.assertRaises(KeyError): + configure_logging() + + def test_configure_logging_with_invalid_log_file(self): + # Test handling when the log file cannot be created + with patch( + "logging.FileHandler", side_effect=PermissionError("Permission denied") + ): + with self.assertRaises(PermissionError): + configure_logging() + + def test_configure_logging_with_invalid_stream_handler(self): + # Test handling when StreamHandler cannot be initialized + with patch("logging.StreamHandler", side_effect=Exception("Stream error")): + with self.assertRaises(Exception): + configure_logging(debug=True) + + def test_configure_logging_with_existing_handlers(self): + # Test that existing handlers are replaced + logging.getLogger().handlers = [MagicMock()] + configure_logging() + self.assertEqual(len(logging.getLogger().handlers), 1) + self.mock_basicConfig.assert_called_once() + + def test_configure_logging_handler_types(self): + # Test that handlers are of correct types + with patch("logging.FileHandler") as mock_file_handler, patch( + "logging.StreamHandler" + ) as mock_stream_handler: + configure_logging(debug=True) + + handlers = [ + mock_file_handler.return_value, + mock_stream_handler.return_value, + ] + self.mock_basicConfig.assert_called_once_with( + level=logging.DEBUG, + format="%(asctime)s [%(name)s][%(levelname)s] %(message)s", + handlers=handlers, + ) + + def test_configure_logging_log_format(self): + # Test that the log format is set correctly + configure_logging() + expected_log_format = "%(asctime)s [%(name)s][%(levelname)s] %(message)s" + self.mock_basicConfig.assert_called_once() + self.assertEqual( + self.mock_basicConfig.call_args[1]["format"], expected_log_format ) - mock_logging.getLogger.assert_called_once() - mock_logging.getLogger.return_value.setLevel.assert_called_once_with(10) - mock_logging.getLogger.return_value.addHandler.assert_called() - # Call the function again with debug=False - configure_logging(debug=False) + def test_configure_logging_log_level_info(self): + # Test that the log level is set to INFO when debug=False + configure_logging() + self.mock_basicConfig.assert_called_once() + self.assertEqual(self.mock_basicConfig.call_args[1]["level"], logging.INFO) + + def test_configure_logging_log_level_debug(self): + # Test that the log level is set to DEBUG when debug=True + configure_logging(debug=True) + self.mock_basicConfig.assert_called_once() + self.assertEqual(self.mock_basicConfig.call_args[1]["level"], logging.DEBUG) + + def test_configure_logging_handlers_order(self): + # Test that handlers are in the correct order + with patch("logging.FileHandler") as mock_file_handler, patch( + "logging.StreamHandler" + ) as mock_stream_handler: + configure_logging(debug=True) + + handlers = [ + mock_file_handler.return_value, + mock_stream_handler.return_value, + ] + self.assertEqual(self.mock_basicConfig.call_args[1]["handlers"], handlers) + + def test_configure_logging_timestamp_format(self): + # Test that the timestamp in the log file name is correctly formatted + configure_logging() + + expected_timestamp = "2021-01-01_12.00.00" + expected_log_dir = Path(self.mock_configs["yggdrasil_log_dir"]) + expected_log_file = expected_log_dir / f"yggdrasil_{expected_timestamp}.log" + + with patch("logging.FileHandler") as mock_file_handler: + configure_logging() + mock_file_handler.assert_called_with(expected_log_file) + + def test_configure_logging_custom_timestamp(self): + # Test with a different timestamp + self.mock_datetime.now.return_value = datetime(2022, 2, 2, 14, 30, 0) + self.mock_datetime.now().strftime.return_value = "2022-02-02_14.30.00" + + configure_logging() + + expected_timestamp = "2022-02-02_14.30.00" + expected_log_dir = Path(self.mock_configs["yggdrasil_log_dir"]) + expected_log_file = expected_log_dir / f"yggdrasil_{expected_timestamp}.log" + + with patch("logging.FileHandler") as mock_file_handler: + configure_logging() + mock_file_handler.assert_called_with(expected_log_file) + + def test_configure_logging_invalid_configs_type(self): + # Test handling when configs is of invalid type + with patch("lib.core_utils.logging_utils.configs", None): + with self.assertRaises(TypeError): + configure_logging() + + def test_configure_logging_log_dir_is_file(self): + # Test behavior when the log directory path is actually a file + with patch("pathlib.Path.mkdir", side_effect=NotADirectoryError): + with self.assertRaises(NotADirectoryError): + configure_logging() + + def test_configure_logging_no_handlers(self): + # Test that logging.basicConfig is called with correct handlers + with patch("logging.basicConfig") as mock_basic_config: + configure_logging() + self.assertIn("handlers", mock_basic_config.call_args[1]) + + def test_custom_logger_different_names(self): + # Test that different module names return different loggers + logger1 = custom_logger("module1") + logger2 = custom_logger("module2") + self.assertNotEqual(logger1, logger2) + self.assertNotEqual(logger1.name, logger2.name) - # Assert the logger's level was set to INFO - mock_logging.getLogger.return_value.setLevel.assert_called_with(20) + def test_custom_logger_propagate_false(self): + # Test that the logger's propagate attribute is default (True) + logger = custom_logger("module") + self.assertTrue(logger.propagate) + + def test_custom_logger_level_not_set(self): + # Test that the logger's level is not explicitly set (inherits from root) + logger = custom_logger("module") + self.assertEqual(logger.level, logging.NOTSET) + + def test_configure_logging_without_debug_stream_handler(self): + # Test that StreamHandler is not added when debug=False + with patch("logging.StreamHandler") as mock_stream_handler: + configure_logging() + mock_stream_handler.assert_not_called() + + def test_configure_logging_with_debug_stream_handler(self): + # Test that StreamHandler is added when debug=True + with patch("logging.StreamHandler") as mock_stream_handler: + configure_logging(debug=True) + mock_stream_handler.assert_called_once() + + def test_configure_logging_handlers_are_set_correctly(self): + # Test that handlers are set correctly in the root logger + with patch("logging.FileHandler") as mock_file_handler, patch( + "logging.StreamHandler" + ) as mock_stream_handler: + configure_logging(debug=True) + + root_logger = logging.getLogger() + self.assertEqual(len(root_logger.handlers), 2) + self.assertIsInstance( + root_logger.handlers[0], mock_file_handler.return_value.__class__ + ) + self.assertIsInstance( + root_logger.handlers[1], mock_stream_handler.return_value.__class__ + ) + + def test_configure_logging_respects_existing_loggers(self): + # Test that existing loggers are not affected by configure_logging + existing_logger = logging.getLogger("existing") + existing_logger_level = existing_logger.level + existing_logger_handlers = existing_logger.handlers.copy() + + configure_logging() + + self.assertEqual(existing_logger.level, existing_logger_level) + self.assertEqual(existing_logger.handlers, existing_logger_handlers) + + def test_logging_messages_after_configuration(self): + # Test that logging messages are handled correctly after configuration + with patch("logging.FileHandler") as mock_file_handler: + mock_file_handler.return_value = MagicMock() + configure_logging() + logger = custom_logger("test_module") + logger.info("Test message") + + # Ensure that the message is handled by the file handler + mock_file_handler.return_value.emit.assert_called() + + def test_suppressed_loggers_levels(self): + # Ensure that suppressed loggers have their levels set to WARNING + suppressed_loggers = ["matplotlib", "numba", "h5py", "PIL"] + for logger_name in suppressed_loggers: + logger = logging.getLogger(logger_name) + self.assertEqual(logger.level, logging.WARNING) + + def test_suppressed_loggers_do_not_propagate(self): + # Ensure that suppressed loggers still propagate messages + suppressed_loggers = ["matplotlib", "numba", "h5py", "PIL"] + for logger_name in suppressed_loggers: + logger = logging.getLogger(logger_name) + self.assertTrue(logger.propagate) + + def test_logging_basic_config_called_once(self): + # Ensure that logging.basicConfig is called only once + configure_logging() + self.mock_basicConfig.assert_called_once() + + def test_configure_logging_with_relative_log_dir(self): + # Test handling when 'yggdrasil_log_dir' is a relative path + self.mock_configs["yggdrasil_log_dir"] = "relative/path/to/logs" + configure_logging() + + expected_log_dir = Path("relative/path/to/logs") + self.mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + def test_configure_logging_with_env_var_in_log_dir(self): + # Test handling when 'yggdrasil_log_dir' contains an environment variable + self.mock_configs["yggdrasil_log_dir"] = "${HOME}/logs" + with patch.dict(os.environ, {"HOME": "/home/testuser"}): + configure_logging() + + expected_log_dir = Path("/home/testuser/logs") + self.mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + def test_configure_logging_invalid_log_format(self): + # Test handling when log_format is invalid + with patch("logging.basicConfig") as mock_basic_config: + configure_logging() + mock_basic_config.assert_called_once() + self.assertIn("format", mock_basic_config.call_args[1]) + + def test_configure_logging_with_custom_handlers(self): + # Test that custom handlers can be added if the code is modified in the future + # Since the current code does not support this, we check that handlers are as expected + configure_logging() + root_logger = logging.getLogger() + self.assertEqual(len(root_logger.handlers), 1) + self.assertIsInstance(root_logger.handlers[0], logging.FileHandler) + + def test_configure_logging_with_no_handlers(self): + # Test that an error is raised if handlers list is empty + with patch("logging.basicConfig") as mock_basic_config: + with patch( + "lib.core_utils.logging_utils.logging.FileHandler", + side_effect=Exception("Handler error"), + ): + with self.assertRaises(Exception): + configure_logging() + + def test_configure_logging_multiple_times(self): + # Test that multiple calls to configure_logging do not cause errors + configure_logging() + configure_logging(debug=True) + self.assertTrue(True) # Test passes if no exception is raised if __name__ == "__main__": From cefa56e636c11e8810281f9c1b544bccd7ed3f2c Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:04:03 +0100 Subject: [PATCH 5/7] Update unittests --- tests/test_slurm_utils.py | 281 ++++++++++++++++++++++++++++++++++---- 1 file changed, 258 insertions(+), 23 deletions(-) diff --git a/tests/test_slurm_utils.py b/tests/test_slurm_utils.py index 6aa9617..bd8ad3b 100644 --- a/tests/test_slurm_utils.py +++ b/tests/test_slurm_utils.py @@ -1,32 +1,267 @@ import unittest -from unittest.mock import mock_open, patch +from pathlib import Path +from unittest.mock import MagicMock, mock_open, patch from lib.module_utils.slurm_utils import generate_slurm_script class TestGenerateSlurmScript(unittest.TestCase): - def test_generate_slurm_script(self): - # Define the input arguments - args_dict = { - "job_name": "test_batch", - "yaml_filepath": "/home/user/path/to.yaml", - } - template_fpath = "slurm_template.sh" - output_fpath = "slurm_script.sh" - template_content = "{job_name}\n{yaml_filepath}\n" - expected_script_content = "test_batch\n/home/user/path/to.yaml\n" - - # Patch the 'open' function and mock its behavior - with patch("builtins.open", mock_open(read_data=template_content)) as mock_file: - # Call the function under test - generate_slurm_script(args_dict, template_fpath, output_fpath) - - # Assert that the 'open' function was called with the correct arguments - mock_file.assert_any_call(template_fpath, "r") - mock_file.assert_any_call(output_fpath, "w") - - # Assert that the 'write' method of the file object was called with the expected content - mock_file().write.assert_called_once_with(expected_script_content) + + def setUp(self): + self.args_dict = {"job_name": "test_job", "time": "01:00:00"} + self.template_content = ( + "#!/bin/bash\n#SBATCH --job-name={job_name}\n#SBATCH --time={time}\n" + ) + self.expected_script = ( + "#!/bin/bash\n#SBATCH --job-name=test_job\n#SBATCH --time=01:00:00\n" + ) + self.template_fpath = "template.slurm" + self.output_fpath = "output.slurm" + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open, read_data="") + def test_generate_slurm_script_file_not_found(self, mock_file, mock_path): + # Simulate FileNotFoundError when opening the template file + mock_template_path = MagicMock() + mock_template_path.open.side_effect = FileNotFoundError( + "Template file not found" + ) + mock_path.return_value = mock_template_path + + result = generate_slurm_script( + self.args_dict, self.template_fpath, self.output_fpath + ) + self.assertFalse(result) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open, read_data="") + def test_generate_slurm_script_missing_placeholder(self, mock_file, mock_path): + # Simulate KeyError due to missing placeholder in args_dict + incomplete_args_dict = {"job_name": "test_job"} # Missing 'time' key + mock_template_path = MagicMock() + mock_template_path.open.return_value.__enter__.return_value.read.return_value = ( + self.template_content + ) + mock_path.return_value = mock_template_path + + result = generate_slurm_script( + incomplete_args_dict, self.template_fpath, self.output_fpath + ) + self.assertFalse(result) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_success(self, mock_file, mock_path): + # Mock reading the template file and writing the output file + mock_template_file = mock_open(read_data=self.template_content).return_value + mock_output_file = mock_open().return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + # Mock Path objects + def side_effect(arg): + if arg == self.template_fpath: + return mock_template_path + elif arg == self.output_fpath: + return mock_output_path + else: + return Path(arg) + + mock_path.side_effect = side_effect + + result = generate_slurm_script( + self.args_dict, self.template_fpath, self.output_fpath + ) + self.assertTrue(result) + mock_template_file.read.assert_called_once() + mock_output_file.write.assert_called_once_with(self.expected_script) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_general_exception(self, mock_file, mock_path): + # Simulate a general exception during file writing + mock_template_file = mock_open(read_data=self.template_content).return_value + mock_output_file = mock_open().return_value + mock_output_file.write.side_effect = Exception("Write error") + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + # Mock Path objects + def side_effect(arg): + if arg == self.template_fpath: + return mock_template_path + elif arg == self.output_fpath: + return mock_output_path + else: + return Path(arg) + + mock_path.side_effect = side_effect + + result = generate_slurm_script( + self.args_dict, self.template_fpath, self.output_fpath + ) + self.assertFalse(result) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_empty_template(self, mock_file, mock_path): + # Test with an empty template + empty_template_content = "" + mock_template_file = mock_open(read_data=empty_template_content).return_value + mock_output_file = mock_open().return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + # Mock Path objects + mock_path.side_effect = lambda arg: ( + mock_template_path if arg == self.template_fpath else mock_output_path + ) + + result = generate_slurm_script({}, self.template_fpath, self.output_fpath) + self.assertTrue(result) + mock_output_file.write.assert_called_once_with("") + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_empty_args_dict(self, mock_file, mock_path): + # Test with empty args_dict but placeholders in template + mock_template_file = mock_open(read_data=self.template_content).return_value + mock_output_file = mock_open().return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + mock_path.side_effect = lambda arg: ( + mock_template_path if arg == self.template_fpath else mock_output_path + ) + + result = generate_slurm_script({}, self.template_fpath, self.output_fpath) + self.assertFalse(result) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_output_file_unwritable(self, mock_file, mock_path): + # Simulate exception when opening output file for writing + mock_template_file = mock_open(read_data=self.template_content).return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.side_effect = PermissionError( + "Cannot write to output file" + ) + + # Mock Path objects + mock_path.side_effect = lambda arg: ( + mock_template_path if arg == self.template_fpath else mock_output_path + ) + + result = generate_slurm_script( + self.args_dict, self.template_fpath, self.output_fpath + ) + self.assertFalse(result) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_non_string_args(self, mock_file, mock_path): + # Test with non-string values in args_dict + args_dict = {"job_name": "test_job", "nodes": 4, "time": "01:00:00"} + template_content = "#!/bin/bash\n#SBATCH --job-name={job_name}\n#SBATCH --nodes={nodes}\n#SBATCH --time={time}\n" + expected_script = "#!/bin/bash\n#SBATCH --job-name=test_job\n#SBATCH --nodes=4\n#SBATCH --time=01:00:00\n" + + mock_template_file = mock_open(read_data=template_content).return_value + mock_output_file = mock_open().return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + # Mock Path objects + mock_path.side_effect = lambda arg: ( + mock_template_path if arg == self.template_fpath else mock_output_path + ) + + result = generate_slurm_script( + args_dict, self.template_fpath, self.output_fpath + ) + self.assertTrue(result) + mock_output_file.write.assert_called_once_with(expected_script) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_template_syntax_error(self, mock_file, mock_path): + # Simulate ValueError due to invalid template syntax + invalid_template_content = "#!/bin/bash\n#SBATCH --job-name={job_name\n" + + mock_template_file = mock_open(read_data=invalid_template_content).return_value + mock_output_file = mock_open().return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + mock_path.side_effect = lambda arg: ( + mock_template_path if arg == self.template_fpath else mock_output_path + ) + + result = generate_slurm_script( + self.args_dict, self.template_fpath, self.output_fpath + ) + self.assertFalse(result) + + def test_generate_slurm_script_invalid_template_path_type(self): + # Test with invalid type for template_fpath + with self.assertRaises(TypeError): + generate_slurm_script(self.args_dict, None, self.output_fpath) + + def test_generate_slurm_script_invalid_output_path_type(self): + # Test with invalid type for output_fpath + with self.assertRaises(TypeError): + generate_slurm_script(self.args_dict, self.template_fpath, None) + + @patch("lib.module_utils.slurm_utils.Path") + @patch("builtins.open", new_callable=mock_open) + def test_generate_slurm_script_no_placeholders(self, mock_file, mock_path): + # Test template with no placeholders + template_content = "#!/bin/bash\n#SBATCH --partition=general\n" + expected_script = template_content + + mock_template_file = mock_open(read_data=template_content).return_value + mock_output_file = mock_open().return_value + + mock_template_path = MagicMock(spec=Path) + mock_template_path.open.return_value = mock_template_file + + mock_output_path = MagicMock(spec=Path) + mock_output_path.open.return_value = mock_output_file + + mock_path.side_effect = lambda arg: ( + mock_template_path if arg == self.template_fpath else mock_output_path + ) + + result = generate_slurm_script({}, self.template_fpath, self.output_fpath) + self.assertTrue(result) + mock_output_file.write.assert_called_once_with(expected_script) if __name__ == "__main__": From 46f1d62585317748ca5d0f656c917e467d4d34e5 Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:04:18 +0100 Subject: [PATCH 6/7] Add unittests --- tests/test_ngi_report_generator.py | 311 +++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 tests/test_ngi_report_generator.py diff --git a/tests/test_ngi_report_generator.py b/tests/test_ngi_report_generator.py new file mode 100644 index 0000000..74814b2 --- /dev/null +++ b/tests/test_ngi_report_generator.py @@ -0,0 +1,311 @@ +import subprocess +import unittest +from unittest.mock import MagicMock, patch + +from lib.module_utils.ngi_report_generator import generate_ngi_report + + +class TestGenerateNgiReport(unittest.TestCase): + + def setUp(self): + self.project_path = "/path/to/project" + self.project_id = "P12345" + self.user_name = "test_user" + self.sample_list = ["sample1", "sample2", "sample3"] + self.samples_str = "sample1 sample2 sample3" + self.activate_env_cmd = "source activate ngi_env" + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_success(self, mock_subprocess_run, mock_configs): + # Setup configs + mock_configs.get.return_value = self.activate_env_cmd + + # Setup subprocess.run to return success + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Report generated", stderr="" + ) + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, self.sample_list + ) + + self.assertTrue(result) + # Verify that subprocess.run was called with the correct command + expected_report_cmd = ( + f"ngi_reports project_summary -d {self.project_path} -p {self.project_id} " + f"-s '{self.user_name}' -y --no_txt --samples {self.samples_str}" + ) + expected_full_cmd = f"{self.activate_env_cmd} && {expected_report_cmd}" + mock_subprocess_run.assert_called_once_with( + expected_full_cmd, + shell=True, + text=True, + capture_output=True, + input="y\n", + ) + + @patch("lib.module_utils.ngi_report_generator.configs") + def test_generate_ngi_report_missing_activate_env_cmd(self, mock_configs): + # Configs return None for activate_ngi_cmd + mock_configs.get.return_value = None + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, self.sample_list + ) + + self.assertFalse(result) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_nonzero_returncode( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Simulate subprocess.run returning non-zero exit code + mock_subprocess_run.return_value = MagicMock( + returncode=1, stdout="", stderr="Error generating report" + ) + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, self.sample_list + ) + + self.assertFalse(result) + # Optionally, check that the error message was logged + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_subprocess_error( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Simulate subprocess.run raising SubprocessError + mock_subprocess_run.side_effect = subprocess.SubprocessError( + "Subprocess failed" + ) + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, self.sample_list + ) + + self.assertFalse(result) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_exception(self, mock_subprocess_run, mock_configs): + mock_configs.get.return_value = self.activate_env_cmd + + # Simulate subprocess.run raising a general Exception + mock_subprocess_run.side_effect = Exception("Unexpected error") + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, self.sample_list + ) + + self.assertFalse(result) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_empty_sample_list( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Setup subprocess.run to return success + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Report generated", stderr="" + ) + + empty_sample_list = [] + samples_str = "" + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, empty_sample_list + ) + + self.assertTrue(result) + # Verify that subprocess.run was called with the correct command + expected_report_cmd = ( + f"ngi_reports project_summary -d {self.project_path} -p {self.project_id} " + f"-s '{self.user_name}' -y --no_txt --samples {samples_str}" + ) + expected_full_cmd = f"{self.activate_env_cmd} && {expected_report_cmd}" + mock_subprocess_run.assert_called_once_with( + expected_full_cmd, + shell=True, + text=True, + capture_output=True, + input="y\n", + ) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_special_characters( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Setup subprocess.run to return success + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Report generated", stderr="" + ) + + # Use special characters in inputs + special_project_path = "/path/with special/chars & spaces" + special_user_name = "user & name" + special_sample_list = ["sample1", "sample two", "sample&three"] + samples_str = "sample1 sample two sample&three" + + result = generate_ngi_report( + special_project_path, + self.project_id, + special_user_name, + special_sample_list, + ) + + self.assertTrue(result) + # Verify that subprocess.run was called with the correct command + expected_report_cmd = ( + f"ngi_reports project_summary -d {special_project_path} -p {self.project_id} " + f"-s '{special_user_name}' -y --no_txt --samples {samples_str}" + ) + expected_full_cmd = f"{self.activate_env_cmd} && {expected_report_cmd}" + mock_subprocess_run.assert_called_once_with( + expected_full_cmd, + shell=True, + text=True, + capture_output=True, + input="y\n", + ) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_long_sample_list( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Create a long list of samples + long_sample_list = [f"sample{i}" for i in range(1000)] + samples_str = " ".join(long_sample_list) + + # Setup subprocess.run to return success + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Report generated", stderr="" + ) + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, long_sample_list + ) + + self.assertTrue(result) + # Verify that subprocess.run was called + expected_report_cmd = ( + f"ngi_reports project_summary -d {self.project_path} -p {self.project_id} " + f"-s '{self.user_name}' -y --no_txt --samples {samples_str}" + ) + expected_full_cmd = f"{self.activate_env_cmd} && {expected_report_cmd}" + mock_subprocess_run.assert_called_once_with( + expected_full_cmd, + shell=True, + text=True, + capture_output=True, + input="y\n", + ) + + @patch("lib.module_utils.ngi_report_generator.configs") + def test_generate_ngi_report_configs_error(self, mock_configs): + # Simulate configs.get raising an exception + mock_configs.get.side_effect = Exception("Configs error") + + result = generate_ngi_report( + self.project_path, self.project_id, self.user_name, self.sample_list + ) + + self.assertFalse(result) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_unicode_characters( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Use Unicode characters in inputs + unicode_project_path = "/path/to/项目" + unicode_user_name = "用户" + unicode_sample_list = ["样品一", "样品二"] + + samples_str = " ".join(unicode_sample_list) + + # Setup subprocess.run to return success + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="报告已生成", stderr="" + ) + + result = generate_ngi_report( + unicode_project_path, + self.project_id, + unicode_user_name, + unicode_sample_list, + ) + + self.assertTrue(result) + # Verify that subprocess.run was called with the correct command + expected_report_cmd = ( + f"ngi_reports project_summary -d {unicode_project_path} -p {self.project_id} " + f"-s '{unicode_user_name}' -y --no_txt --samples {samples_str}" + ) + expected_full_cmd = f"{self.activate_env_cmd} && {expected_report_cmd}" + mock_subprocess_run.assert_called_once_with( + expected_full_cmd, + shell=True, + text=True, + capture_output=True, + input="y\n", + ) + + @patch("lib.module_utils.ngi_report_generator.configs") + @patch("lib.module_utils.ngi_report_generator.subprocess.run") + def test_generate_ngi_report_input_injection( + self, mock_subprocess_run, mock_configs + ): + mock_configs.get.return_value = self.activate_env_cmd + + # Attempt to inject additional commands via inputs + malicious_user_name = "user_name'; rm -rf /; echo '" + samples_str = "sample1 sample2" + + # Setup subprocess.run to return success + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Report generated", stderr="" + ) + + result = generate_ngi_report( + self.project_path, + self.project_id, + malicious_user_name, + self.sample_list[:2], + ) + + self.assertTrue(result) + # Verify that subprocess.run was called with the correct (escaped) command + expected_report_cmd = ( + f"ngi_reports project_summary -d {self.project_path} -p {self.project_id} " + f"-s '{malicious_user_name}' -y --no_txt --samples {samples_str}" + ) + expected_full_cmd = f"{self.activate_env_cmd} && {expected_report_cmd}" + mock_subprocess_run.assert_called_once_with( + expected_full_cmd, + shell=True, + text=True, + capture_output=True, + input="y\n", + ) + + +if __name__ == "__main__": + unittest.main() From 6616686e368090b1b69540c12960de457d17c0ae Mon Sep 17 00:00:00 2001 From: glrs <5999366+glrs@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:04:28 +0100 Subject: [PATCH 7/7] Add unittests --- tests/test_report_transfer.py | 386 ++++++++++++++++++++++++++++++++++ 1 file changed, 386 insertions(+) create mode 100644 tests/test_report_transfer.py diff --git a/tests/test_report_transfer.py b/tests/test_report_transfer.py new file mode 100644 index 0000000..ce4f621 --- /dev/null +++ b/tests/test_report_transfer.py @@ -0,0 +1,386 @@ +import subprocess +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from lib.module_utils.report_transfer import transfer_report + + +class TestTransferReport(unittest.TestCase): + + def setUp(self): + self.report_path = Path("/path/to/report") + self.project_id = "project123" + self.sample_id = "sample456" + self.remote_dir_base = "/remote/destination" + self.server = "example.com" + self.user = "user" + self.ssh_key = "/path/to/ssh_key" + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_success(self, mock_subprocess_run, mock_configs): + # Set up configs + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "destination": self.remote_dir_base, + "ssh_key": self.ssh_key, + } + + # Set up subprocess.run to succeed + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Transfer complete", stderr="" + ) + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is True + self.assertTrue(result) + + # Assert subprocess.run was called with correct arguments + expected_remote_dir = ( + f"{self.remote_dir_base}/{self.project_id}/{self.sample_id}" + ) + expected_remote_path = f"{self.user}@{self.server}:{expected_remote_dir}/" + expected_rsync_command = [ + "rsync", + "-avz", + "--rsync-path", + f"mkdir -p '{expected_remote_dir}' && rsync", + "-e", + f"ssh -i {self.ssh_key}", + str(self.report_path), + expected_remote_path, + ] + mock_subprocess_run.assert_called_once_with( + expected_rsync_command, + check=True, + text=True, + capture_output=True, + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.logging") + def test_transfer_report_missing_config_key(self, mock_logging, mock_configs): + # Set up configs to raise KeyError for missing 'server' key + mock_configs.__getitem__.side_effect = KeyError("server") + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is False + self.assertFalse(result) + + # Assert that logging.error was called with the missing key + mock_logging.error.assert_called_with( + "Missing configuration for report transfer: 'server'" + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_subprocess_calledprocesserror( + self, mock_subprocess_run, mock_configs + ): + # Set up configs + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "destination": self.remote_dir_base, + "ssh_key": self.ssh_key, + } + + # Set up subprocess.run to raise CalledProcessError + mock_subprocess_run.side_effect = subprocess.CalledProcessError( + returncode=1, cmd="rsync", stderr="Error in rsync" + ) + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is False + self.assertFalse(result) + + # Assert that subprocess.run was called + mock_subprocess_run.assert_called_once() + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_general_exception(self, mock_subprocess_run, mock_configs): + # Set up configs + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "destination": self.remote_dir_base, + "ssh_key": self.ssh_key, + } + + # Set up subprocess.run to raise a general Exception + mock_subprocess_run.side_effect = Exception("Unexpected error") + + # Mock logging + with patch("lib.module_utils.report_transfer.logging") as mock_logging: + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is False + self.assertFalse(result) + + # Assert that logging.error was called with the exception message + mock_logging.error.assert_any_call( + "Unexpected error during report transfer: Unexpected error" + ) + mock_logging.error.assert_any_call("RSYNC output: ") + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_no_ssh_key(self, mock_subprocess_run, mock_configs): + # Set up configs without ssh_key + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "destination": self.remote_dir_base, + # ssh_key is optional + } + + # Set up subprocess.run to succeed + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Transfer complete", stderr="" + ) + + # Call the function without sample_id + result = transfer_report(self.report_path, self.project_id) + + # Assert the result is True + self.assertTrue(result) + + # Assert subprocess.run was called with correct arguments + expected_remote_dir = f"{self.remote_dir_base}/{self.project_id}" + expected_remote_path = f"{self.user}@{self.server}:{expected_remote_dir}/" + expected_rsync_command = [ + "rsync", + "-avz", + "--rsync-path", + f"mkdir -p '{expected_remote_dir}' && rsync", + "-e", + "ssh", + str(self.report_path), + expected_remote_path, + ] + mock_subprocess_run.assert_called_once_with( + expected_rsync_command, + check=True, + text=True, + capture_output=True, + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_without_sample_id(self, mock_subprocess_run, mock_configs): + # Set up configs + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "destination": self.remote_dir_base, + "ssh_key": self.ssh_key, + } + + # Set up subprocess.run to succeed + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Transfer complete", stderr="" + ) + + # Call the function without sample_id + result = transfer_report(self.report_path, self.project_id) + + # Assert the result is True + self.assertTrue(result) + + # Assert subprocess.run was called with correct arguments + expected_remote_dir = f"{self.remote_dir_base}/{self.project_id}" + expected_remote_path = f"{self.user}@{self.server}:{expected_remote_dir}/" + expected_rsync_command = [ + "rsync", + "-avz", + "--rsync-path", + f"mkdir -p '{expected_remote_dir}' && rsync", + "-e", + f"ssh -i {self.ssh_key}", + str(self.report_path), + expected_remote_path, + ] + mock_subprocess_run.assert_called_once_with( + expected_rsync_command, + check=True, + text=True, + capture_output=True, + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.logging") + def test_transfer_report_missing_destination(self, mock_logging, mock_configs): + # Set up configs missing 'destination' + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "ssh_key": self.ssh_key, + # 'destination' key is missing + } + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is False + self.assertFalse(result) + + # Assert that logging.error was called with the missing key + mock_logging.error.assert_called_with( + "Missing configuration for report transfer: 'destination'" + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.logging") + def test_transfer_report_nonexistent_report_path(self, mock_logging, mock_configs): + # Set up configs + mock_configs.__getitem__.return_value = { + "server": self.server, + "user": self.user, + "destination": self.remote_dir_base, + "ssh_key": self.ssh_key, + } + + # Assume report_path does not exist; since the function does not check this, it proceeds + # Mock subprocess.run to simulate rsync failure due to nonexistent report_path + with patch( + "lib.module_utils.report_transfer.subprocess.run" + ) as mock_subprocess_run: + mock_subprocess_run.side_effect = subprocess.CalledProcessError( + returncode=1, cmd="rsync", stderr="No such file or directory" + ) + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is False + self.assertFalse(result) + + # Assert that logging.error was called with rsync error + mock_logging.error.assert_called_with( + "Failed to transfer report:\nNo such file or directory" + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_unicode_characters( + self, mock_subprocess_run, mock_configs + ): + # Set up configs with Unicode characters + unicode_server = "例子.com" + unicode_user = "用户" + unicode_destination = "/远程/目的地" + + mock_configs.__getitem__.return_value = { + "server": unicode_server, + "user": unicode_user, + "destination": unicode_destination, + "ssh_key": self.ssh_key, + } + + # Set up subprocess.run to succeed + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="传输完成", stderr="" + ) + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is True + self.assertTrue(result) + + # Assert subprocess.run was called with correct arguments containing Unicode characters + expected_remote_dir = ( + f"{unicode_destination}/{self.project_id}/{self.sample_id}" + ) + expected_remote_path = f"{unicode_user}@{unicode_server}:{expected_remote_dir}/" + expected_rsync_command = [ + "rsync", + "-avz", + "--rsync-path", + f"mkdir -p '{expected_remote_dir}' && rsync", + "-e", + f"ssh -i {self.ssh_key}", + str(self.report_path), + expected_remote_path, + ] + mock_subprocess_run.assert_called_once_with( + expected_rsync_command, + check=True, + text=True, + capture_output=True, + ) + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.logging") + def test_transfer_report_invalid_config_type(self, mock_logging, mock_configs): + # Set up configs['report_transfer'] to be None + mock_configs.__getitem__.return_value = None + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is False + self.assertFalse(result) + + # Assert that logging.error was called + mock_logging.error.assert_called() + + @patch("lib.module_utils.report_transfer.configs") + @patch("lib.module_utils.report_transfer.subprocess.run") + def test_transfer_report_non_string_config_values( + self, mock_subprocess_run, mock_configs + ): + # Set up configs with non-string value for 'server' + mock_configs.__getitem__.return_value = { + "server": 123, # Non-string value + "user": self.user, + "destination": self.remote_dir_base, + "ssh_key": self.ssh_key, + } + + # Set up subprocess.run to succeed + mock_subprocess_run.return_value = MagicMock( + returncode=0, stdout="Transfer complete", stderr="" + ) + + # Call the function + result = transfer_report(self.report_path, self.project_id, self.sample_id) + + # Assert the result is True + self.assertTrue(result) + + # Assert subprocess.run was called with '123' converted to string + expected_remote_dir = ( + f"{self.remote_dir_base}/{self.project_id}/{self.sample_id}" + ) + expected_remote_path = f"{self.user}@123:{expected_remote_dir}/" + expected_rsync_command = [ + "rsync", + "-avz", + "--rsync-path", + f"mkdir -p '{expected_remote_dir}' && rsync", + "-e", + f"ssh -i {self.ssh_key}", + str(self.report_path), + expected_remote_path, + ] + mock_subprocess_run.assert_called_once_with( + expected_rsync_command, + check=True, + text=True, + capture_output=True, + ) + + +if __name__ == "__main__": + unittest.main()