From ed502c6854f5a8f67475a02bcf401f922743789d Mon Sep 17 00:00:00 2001 From: superstar54 Date: Tue, 3 Dec 2024 08:27:49 +0100 Subject: [PATCH] test entry point --- src/aiida_pythonjob/data/serializer.py | 6 ++- tests/test_entry_points.py | 58 ++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/test_entry_points.py diff --git a/src/aiida_pythonjob/data/serializer.py b/src/aiida_pythonjob/data/serializer.py index cb56cf3..291f368 100644 --- a/src/aiida_pythonjob/data/serializer.py +++ b/src/aiida_pythonjob/data/serializer.py @@ -15,6 +15,7 @@ def get_serializer_from_entry_points() -> dict: # ts = time.time() configs = load_config() + print("configs: ", configs) serializers = configs.get("serializers", {}) excludes = serializers.get("excludes", []) # Retrieve the entry points for 'aiida.data' and store them in a dictionary @@ -35,10 +36,13 @@ def get_serializer_from_entry_points() -> dict: # print("Time to load entry points: ", time.time() - ts) # check if there are duplicates + print("serializers: ", serializers) for key, value in eps.items(): + print("key: ", key) if len(value) > 1: if key in serializers: - [ep for ep in value if ep.name == serializers[key]] + print("key: ", serializers[key]) + print("name: ", ep.name) eps[key] = [ep for ep in value if ep.name == serializers[key]] if not eps[key]: raise ValueError(f"Entry point {serializers[key]} not found for {key}") diff --git a/tests/test_entry_points.py b/tests/test_entry_points.py new file mode 100644 index 0000000..2c5ab9e --- /dev/null +++ b/tests/test_entry_points.py @@ -0,0 +1,58 @@ +from importlib.metadata import EntryPoint, EntryPoints +from unittest.mock import patch + +import pytest + + +# Helper function to mock EntryPoint creation +def create_entry_point(name, value, group): + return EntryPoint(name=name, value=value, group=group) + + +def create_mock_entry_points(py_version, entry_point_list): + if py_version >= (3, 10): + # Mock the EntryPoints object for Python 3.10+ + return EntryPoints(entry_point_list) + else: + # Return a dictionary for older Python versions + return {"aiida.data": entry_point_list} + + +@patch("aiida_pythonjob.data.serializer.load_config") +@patch("importlib.metadata.entry_points") +@patch("sys.version_info", new=(3, 10)) +def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config): + # Mock the configuration + mock_load_config.return_value = { + "serializers": { + "excludes": ["excluded_entry"], + } + } + # Mock entry points + mock_ep_1 = create_entry_point("xyz.abc.Abc", "xyz.abc:AbcData", "aiida.data") + mock_ep_2 = create_entry_point("xyz.abc.Bcd", "xyz.abc:BcdData", "aiida.data") + mock_ep_3 = create_entry_point("xyz.abc.Cde", "xyz.abc:CdeData", "aiida.data") + mock_ep_4 = create_entry_point("another_xyz.abc.Cde", "another_xyz.abc:CdeData", "aiida.data") + + mock_entry_points.return_value = create_mock_entry_points((3, 10), [mock_ep_1, mock_ep_2, mock_ep_3, mock_ep_4]) + + # Import the function and run + from aiida_pythonjob.data.serializer import get_serializer_from_entry_points + + with pytest.raises(ValueError, match="Duplicate entry points for abc.Cde"): + get_serializer_from_entry_points() + # Mock the configuration + mock_load_config.return_value = { + "serializers": { + "excludes": ["excluded_entry"], + "abc.Cde": "another_xyz.abc.Cde", + } + } + result = get_serializer_from_entry_points() + # Assert results + expected = { + "abc.Abc": [mock_ep_1], + "abc.Bcd": [mock_ep_2], + "abc.Cde": [mock_ep_4], + } + assert result == expected