Skip to content

Commit

Permalink
test entry point
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 3, 2024
1 parent f310ce1 commit ed502c6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/aiida_pythonjob/data/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def get_serializer_from_entry_points() -> dict:

# ts = time.time()
configs = load_config()
print("configs: ", configs)
serializers = configs.get("serializers", {})
excludes = serializers.get("excludes", [])
# Retrieve the entry points for 'aiida.data' and store them in a dictionary
Expand All @@ -35,10 +36,13 @@ def get_serializer_from_entry_points() -> dict:

# print("Time to load entry points: ", time.time() - ts)
# check if there are duplicates
print("serializers: ", serializers)
for key, value in eps.items():
print("key: ", key)
if len(value) > 1:
if key in serializers:
[ep for ep in value if ep.name == serializers[key]]
print("key: ", serializers[key])
print("name: ", ep.name)
eps[key] = [ep for ep in value if ep.name == serializers[key]]
if not eps[key]:
raise ValueError(f"Entry point {serializers[key]} not found for {key}")
Expand Down
58 changes: 58 additions & 0 deletions tests/test_entry_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from importlib.metadata import EntryPoint, EntryPoints
from unittest.mock import patch

import pytest


# Helper function to mock EntryPoint creation
def create_entry_point(name, value, group):
return EntryPoint(name=name, value=value, group=group)


def create_mock_entry_points(py_version, entry_point_list):
if py_version >= (3, 10):
# Mock the EntryPoints object for Python 3.10+
return EntryPoints(entry_point_list)
else:
# Return a dictionary for older Python versions
return {"aiida.data": entry_point_list}


@patch("aiida_pythonjob.data.serializer.load_config")
@patch("importlib.metadata.entry_points")
@patch("sys.version_info", new=(3, 10))
def test_get_serializer_from_entry_points(mock_entry_points, mock_load_config):
# Mock the configuration
mock_load_config.return_value = {
"serializers": {
"excludes": ["excluded_entry"],
}
}
# Mock entry points
mock_ep_1 = create_entry_point("xyz.abc.Abc", "xyz.abc:AbcData", "aiida.data")
mock_ep_2 = create_entry_point("xyz.abc.Bcd", "xyz.abc:BcdData", "aiida.data")
mock_ep_3 = create_entry_point("xyz.abc.Cde", "xyz.abc:CdeData", "aiida.data")
mock_ep_4 = create_entry_point("another_xyz.abc.Cde", "another_xyz.abc:CdeData", "aiida.data")

mock_entry_points.return_value = create_mock_entry_points((3, 10), [mock_ep_1, mock_ep_2, mock_ep_3, mock_ep_4])

# Import the function and run
from aiida_pythonjob.data.serializer import get_serializer_from_entry_points

with pytest.raises(ValueError, match="Duplicate entry points for abc.Cde"):
get_serializer_from_entry_points()
# Mock the configuration
mock_load_config.return_value = {
"serializers": {
"excludes": ["excluded_entry"],
"abc.Cde": "another_xyz.abc.Cde",
}
}
result = get_serializer_from_entry_points()
# Assert results
expected = {
"abc.Abc": [mock_ep_1],
"abc.Bcd": [mock_ep_2],
"abc.Cde": [mock_ep_4],
}
assert result == expected

0 comments on commit ed502c6

Please sign in to comment.