From e059d36aa5d2447a0dfad6bfd4bceb60a75e094f Mon Sep 17 00:00:00 2001
From: Xing Wang
Date: Mon, 2 Dec 2024 16:30:25 +0100
Subject: [PATCH] Improve test coverage (#4)

* add test for AtomsData
* Drop PickledFunction data
* add test for create_env
---
 pyproject.toml                                |   2 +-
 src/aiida_pythonjob/__init__.py               |   2 -
 src/aiida_pythonjob/calculations/pythonjob.py |  42 +----
 src/aiida_pythonjob/data/__init__.py          |   3 +-
 src/aiida_pythonjob/data/pickled_function.py  | 145 ------------------
 src/aiida_pythonjob/launch.py                 |  57 +++----
 src/aiida_pythonjob/parsers/pythonjob.py      |   2 +-
 src/aiida_pythonjob/utils.py                  |  94 +++++++++++-
 tests/test_create_env.py                      | 104 +++++++++++++
 tests/test_data.py                            |  14 +-
 tests/test_utils.py                           |  13 ++
 11 files changed, 258 insertions(+), 220 deletions(-)
 delete mode 100644 src/aiida_pythonjob/data/pickled_function.py
 create mode 100644 tests/test_create_env.py
 create mode 100644 tests/test_utils.py

diff --git a/pyproject.toml b/pyproject.toml
index 28f218e..06c49c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ keywords = ["aiida", "plugin"]
 requires-python = ">=3.9"
 dependencies = [
     "aiida-core>=2.3,<3",
+    "ase",
     "cloudpickle",
     "voluptuous"
 ]
@@ -46,7 +47,6 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"

 [project.entry-points."aiida.data"]
 "pythonjob.pickled_data" = "aiida_pythonjob.data.pickled_data:PickledData"
-"pythonjob.pickled_function" = "aiida_pythonjob.data.pickled_function:PickledFunction"
 "pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
 "pythonjob.builtins.int" = "aiida.orm.nodes.data.int:Int"
 "pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
diff --git a/src/aiida_pythonjob/__init__.py b/src/aiida_pythonjob/__init__.py
index 6787136..12bfb33 100644
--- a/src/aiida_pythonjob/__init__.py
+++ b/src/aiida_pythonjob/__init__.py
@@ -3,14 +3,12 @@
 __version__ = "0.1.3"

 from .calculations import PythonJob
-from .data import PickledData, PickledFunction
 from .launch import prepare_pythonjob_inputs
 from .parsers import PythonJobParser

 __all__ = (
     "PythonJob",
     "PickledData",
-    "PickledFunction",
     "prepare_pythonjob_inputs",
     "PythonJobParser",
 )
diff --git a/src/aiida_pythonjob/calculations/pythonjob.py b/src/aiida_pythonjob/calculations/pythonjob.py
index d717c93..8fd12b5 100644
--- a/src/aiida_pythonjob/calculations/pythonjob.py
+++ b/src/aiida_pythonjob/calculations/pythonjob.py
@@ -11,6 +11,7 @@
 from aiida.engine import CalcJob, CalcJobProcessSpec
 from aiida.orm import (
     Data,
+    Dict,
     FolderData,
     List,
     RemoteData,
@@ -19,8 +20,6 @@
     to_aiida_type,
 )

-from aiida_pythonjob.data.pickled_function import PickledFunction, to_pickled_function
-
 __all__ = ("PythonJob",)

@@ -42,31 +41,11 @@ def define(cls, spec: CalcJobProcessSpec) -> None:  # type: ignore[override]

         :param spec: the calculation job process spec to define.
""" super().define(spec) - spec.input( - "function", - valid_type=PickledFunction, - serializer=to_pickled_function, - required=False, - ) - spec.input( - "function_source_code", - valid_type=Str, - serializer=to_aiida_type, - required=False, - ) - spec.input("function_name", valid_type=Str, serializer=to_aiida_type, required=False) + spec.input("function_data", valid_type=Dict, serializer=to_aiida_type, required=False) spec.input("process_label", valid_type=Str, serializer=to_aiida_type, required=False) spec.input_namespace( "function_inputs", valid_type=Data, required=False ) # , serializer=serialize_to_aiida_nodes) - spec.input( - "function_outputs", - valid_type=List, - default=lambda: List(), - required=False, - serializer=to_aiida_type, - help="The information of the output ports", - ) spec.input( "parent_folder", valid_type=(RemoteData, FolderData, SinglefileData), @@ -155,21 +134,6 @@ def on_create(self) -> None: super().on_create() self.node.label = self._build_process_label() - def get_function_data(self) -> dict[str, t.Any]: - """Get the function data. - - :returns: The function data. - """ - if "function" in self.inputs: - metadata = self.inputs.function.metadata - metadata["source_code"] = metadata["import_statements"] + "\n" + metadata["source_code_without_decorator"] - return metadata - else: - return { - "source_code": self.inputs.function_source_code.value, - "name": self.inputs.function_name.value, - } - def prepare_for_submission(self, folder: Folder) -> CalcInfo: """Prepare the calculation for submission. @@ -192,7 +156,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: parent_folder_name = self.inputs.parent_folder_name.value else: parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME - function_data = self.get_function_data() + function_data = self.inputs.function_data.get_dict() # create python script to run the function script = f""" import pickle diff --git a/src/aiida_pythonjob/data/__init__.py b/src/aiida_pythonjob/data/__init__.py index 8d543fc..8b81c44 100644 --- a/src/aiida_pythonjob/data/__init__.py +++ b/src/aiida_pythonjob/data/__init__.py @@ -1,5 +1,4 @@ from .pickled_data import PickledData -from .pickled_function import PickledFunction from .serializer import general_serializer, serialize_to_aiida_nodes -__all__ = ("PickledData", "PickledFunction", "serialize_to_aiida_nodes", "general_serializer") +__all__ = ("PickledData", "serialize_to_aiida_nodes", "general_serializer") diff --git a/src/aiida_pythonjob/data/pickled_function.py b/src/aiida_pythonjob/data/pickled_function.py deleted file mode 100644 index 4cf7d89..0000000 --- a/src/aiida_pythonjob/data/pickled_function.py +++ /dev/null @@ -1,145 +0,0 @@ -import inspect -import textwrap -from typing import Any, Callable, Dict, _SpecialForm, get_type_hints - -from .pickled_data import PickledData - - -class PickledFunction(PickledData): - """Data class to represent a pickled Python function.""" - - def __init__(self, value=None, **kwargs): - """Initialize a PickledFunction node instance. 
-
-        :param value: a Python function
-        """
-        super().__init__(**kwargs)
-        if not callable(value):
-            raise ValueError("value must be a callable Python function")
-        self.set_value(value)
-        self.set_attribute(value)
-
-    def __str__(self):
-        return f"PickledFunction<{self.base.attributes.get('function_name')}> pk={self.pk}"
-
-    @property
-    def metadata(self):
-        """Return a dictionary of metadata."""
-        return {
-            "name": self.base.attributes.get("name"),
-            "import_statements": self.base.attributes.get("import_statements"),
-            "source_code": self.base.attributes.get("source_code"),
-            "source_code_without_decorator": self.base.attributes.get("source_code_without_decorator"),
-            "type": "function",
-            "is_pickle": True,
-        }
-
-    @classmethod
-    def build_callable(cls, func):
-        """Return the executor for this node."""
-        import cloudpickle as pickle
-
-        executor = {
-            "executor": pickle.dumps(func),
-            "type": "function",
-            "is_pickle": True,
-        }
-        executor.update(cls.inspect_function(func))
-        return executor
-
-    def set_attribute(self, value):
-        """Set the contents of this node by pickling the provided function.
-
-        :param value: The Python function to pickle and store.
-        """
-        # Serialize the function and extract metadata
-        serialized_data = self.inspect_function(value)
-
-        # Store relevant metadata
-        self.base.attributes.set("name", serialized_data["name"])
-        self.base.attributes.set("import_statements", serialized_data["import_statements"])
-        self.base.attributes.set("source_code", serialized_data["source_code"])
-        self.base.attributes.set(
-            "source_code_without_decorator",
-            serialized_data["source_code_without_decorator"],
-        )
-
-    @classmethod
-    def inspect_function(cls, func: Callable) -> Dict[str, Any]:
-        """Serialize a function for storage or transmission."""
-        try:
-            # we need save the source code explicitly, because in the case of jupyter notebook,
-            # the source code is not saved in the pickle file
-            source_code = inspect.getsource(func)
-            # Split the source into lines for processing
-            source_code_lines = source_code.split("\n")
-            function_source_code = "\n".join(source_code_lines)
-            # Find the first line of the actual function definition
-            for i, line in enumerate(source_code_lines):
-                if line.strip().startswith("def "):
-                    break
-            function_source_code_without_decorator = "\n".join(source_code_lines[i:])
-            function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator)
-            # we also need to include the necessary imports for the types used in the type hints.
-            try:
-                required_imports = cls.get_required_imports(func)
-            except Exception as e:
-                required_imports = {}
-                print(f"Failed to get required imports for function {func.__name__}: {e}")
-            # Generate import statements
-            import_statements = "\n".join(
-                f"from {module} import {', '.join(types)}" for module, types in required_imports.items()
-            )
-        except Exception as e:
-            print(f"Failed to inspect function {func.__name__}: {e}")
-            function_source_code = ""
-            function_source_code_without_decorator = ""
-            import_statements = ""
-        return {
-            "name": func.__name__,
-            "source_code": function_source_code,
-            "source_code_without_decorator": function_source_code_without_decorator,
-            "import_statements": import_statements,
-        }
-
-    @classmethod
-    def get_required_imports(cls, func: Callable) -> Dict[str, set]:
-        """Retrieve type hints and the corresponding modules."""
-        type_hints = get_type_hints(func)
-        imports = {}
-
-        def add_imports(type_hint):
-            if isinstance(type_hint, _SpecialForm):  # Handle special forms like Any, Union, Optional
-                module_name = "typing"
-                type_name = type_hint._name or str(type_hint)
-            elif hasattr(type_hint, "__origin__"):  # This checks for higher-order types like List, Dict
-                module_name = type_hint.__module__
-                type_name = getattr(type_hint, "_name", None) or getattr(type_hint.__origin__, "__name__", None)
-                for arg in getattr(type_hint, "__args__", []):
-                    if arg is type(None):
-                        continue
-                    add_imports(arg)  # Recursively add imports for each argument
-            elif hasattr(type_hint, "__module__"):
-                module_name = type_hint.__module__
-                type_name = type_hint.__name__
-            else:
-                return  # If no module or origin, we can't import it, e.g., for literals
-
-            if type_name is not None:
-                if module_name not in imports:
-                    imports[module_name] = set()
-                imports[module_name].add(type_name)
-
-        for _, type_hint in type_hints.items():
-            add_imports(type_hint)
-
-        return imports
-
-
-def to_pickled_function(value):
-    """Convert a Python function to a `PickledFunction` instance."""
-    return PickledFunction(value)
-
-
-class PickledLocalFunction(PickledFunction):
-    """PickledFunction subclass for local functions."""
diff --git a/src/aiida_pythonjob/launch.py b/src/aiida_pythonjob/launch.py
index c8e9f5c..a08b568 100644
--- a/src/aiida_pythonjob/launch.py
+++ b/src/aiida_pythonjob/launch.py
@@ -1,39 +1,38 @@
+from __future__ import annotations
+
 import inspect
 import os
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

-from aiida.orm import AbstractCode, Computer, FolderData, List, SinglefileData, Str
+from aiida import orm

-from .data.pickled_function import PickledFunction
 from .data.serializer import serialize_to_aiida_nodes
-from .utils import get_or_create_code
+from .utils import build_function_data, get_or_create_code


 def prepare_pythonjob_inputs(
     function: Optional[Callable[..., Any]] = None,
     function_inputs: Optional[Dict[str, Any]] = None,
-    function_outputs: Optional[Dict[str, Any]] = None,
-    code: Optional[AbstractCode] = None,
+    function_outputs: Optional[List[str | dict]] = None,
+    code: Optional[orm.AbstractCode] = None,
     command_info: Optional[Dict[str, str]] = None,
-    computer: Union[str, Computer] = "localhost",
+    computer: Union[str, orm.Computer] = "localhost",
     metadata: Optional[Dict[str, Any]] = None,
     upload_files: Dict[str, str] = {},
     process_label: Optional[str] = None,
-    pickled_function: Optional[PickledFunction] = None,
+    function_data: dict | None = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
"""Prepare the inputs for PythonJob""" - if function is None and pickled_function is None: - raise ValueError("Either function or pickled_function must be provided") - if function is not None and pickled_function is not None: - raise ValueError("Only one of function or pickled_function should be provided") - # if function is a function, convert it to a PickledFunction + if function is None and function_data is None: + raise ValueError("Either function or function_data must be provided") + if function is not None and function_data is not None: + raise ValueError("Only one of function or function_data should be provided") + # if function is a function, inspect it and get the source code if function is not None and inspect.isfunction(function): - executor = PickledFunction.build_callable(function) - if pickled_function is not None: - executor = pickled_function + function_data = build_function_data(function) new_upload_files = {} # change the string in the upload files to SingleFileData, or FolderData for key, source in upload_files.items(): @@ -42,10 +41,10 @@ def prepare_pythonjob_inputs( new_key = key.replace(".", "_dot_") if isinstance(source, str): if os.path.isfile(source): - new_upload_files[new_key] = SinglefileData(file=source) + new_upload_files[new_key] = orm.SinglefileData(file=source) elif os.path.isdir(source): - new_upload_files[new_key] = FolderData(tree=source) - elif isinstance(source, (SinglefileData, FolderData)): + new_upload_files[new_key] = orm.FolderData(tree=source) + elif isinstance(source, (orm.SinglefileData, orm.FolderData)): new_upload_files[new_key] = source else: raise ValueError(f"Invalid upload file type: {type(source)}, {source}") @@ -54,11 +53,13 @@ def prepare_pythonjob_inputs( command_info = command_info or {} code = get_or_create_code(computer=computer, **command_info) # get the source code of the function - function_name = executor["name"] - if executor.get("is_pickle", False): - function_source_code = executor["import_statements"] + "\n" + executor["source_code_without_decorator"] + function_name = function_data["name"] + if function_data.get("is_pickle", False): + function_source_code = ( + function_data["import_statements"] + "\n" + function_data["source_code_without_decorator"] + ) else: - function_source_code = f"from {executor['module']} import {function_name}" + function_source_code = f"from {function_data['module']} import {function_name}" # serialize the kwargs into AiiDA Data function_inputs = function_inputs or {} @@ -66,12 +67,16 @@ def prepare_pythonjob_inputs( # transfer the args to kwargs inputs = { "process_label": process_label or "PythonJob<{}>".format(function_name), - "function_source_code": Str(function_source_code), - "function_name": Str(function_name), + "function_data": orm.Dict( + { + "source_code": function_source_code, + "name": function_name, + "outputs": function_outputs or [], + } + ), "code": code, "function_inputs": function_inputs, "upload_files": new_upload_files, - "function_outputs": List(function_outputs), "metadata": metadata or {}, **kwargs, } diff --git a/src/aiida_pythonjob/parsers/pythonjob.py b/src/aiida_pythonjob/parsers/pythonjob.py index d10ef9c..2fb659f 100644 --- a/src/aiida_pythonjob/parsers/pythonjob.py +++ b/src/aiida_pythonjob/parsers/pythonjob.py @@ -22,7 +22,7 @@ def parse(self, **kwargs): """ import pickle - function_outputs = self.node.inputs.function_outputs.get_list() + function_outputs = self.node.inputs.function_data.get_dict()["outputs"] if len(function_outputs) == 0: function_outputs = 
[{"name": "result"}] self.output_list = function_outputs diff --git a/src/aiida_pythonjob/utils.py b/src/aiida_pythonjob/utils.py index 6155df0..495d5ba 100644 --- a/src/aiida_pythonjob/utils.py +++ b/src/aiida_pythonjob/utils.py @@ -1,9 +1,99 @@ -from typing import Dict, List, Optional, Tuple, Union +import inspect +import textwrap +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, _SpecialForm, get_type_hints from aiida.common.exceptions import NotExistent from aiida.orm import Computer, InstalledCode, User, load_code, load_computer +def get_required_imports(func: Callable) -> Dict[str, set]: + """Retrieve type hints and the corresponding modules.""" + type_hints = get_type_hints(func) + imports = {} + + def add_imports(type_hint): + if isinstance(type_hint, _SpecialForm): # Handle special forms like Any, Union, Optional + module_name = "typing" + type_name = type_hint._name or str(type_hint) + elif hasattr(type_hint, "__origin__"): # This checks for higher-order types like List, Dict + module_name = type_hint.__module__ + type_name = getattr(type_hint, "_name", None) or getattr(type_hint.__origin__, "__name__", None) + for arg in getattr(type_hint, "__args__", []): + if arg is type(None): + continue + add_imports(arg) # Recursively add imports for each argument + elif hasattr(type_hint, "__module__"): + module_name = type_hint.__module__ + type_name = type_hint.__name__ + else: + return # If no module or origin, we can't import it, e.g., for literals + if type_name is not None: + if module_name not in imports: + imports[module_name] = set() + imports[module_name].add(type_name) + + for _, type_hint in type_hints.items(): + add_imports(type_hint) + return imports + + +def inspect_function(func: Callable) -> Dict[str, Any]: + """Serialize a function for storage or transmission.""" + # we need save the source code explicitly, because in the case of jupyter notebook, + # the source code is not saved in the pickle file + try: + source_code = inspect.getsource(func) + except OSError: + raise ValueError("Failed to get the source code of the function.") + + # Split the source into lines for processing + source_code_lines = source_code.split("\n") + function_source_code = "\n".join(source_code_lines) + # Find the first line of the actual function definition + for i, line in enumerate(source_code_lines): + if line.strip().startswith("def "): + break + function_source_code_without_decorator = "\n".join(source_code_lines[i:]) + function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator) + # we also need to include the necessary imports for the types used in the type hints. + try: + required_imports = get_required_imports(func) + except Exception as exception: + raise ValueError(f"Failed to get the required imports for the function: {exception}") + # Generate import statements + import_statements = "\n".join( + f"from {module} import {', '.join(types)}" for module, types in required_imports.items() + ) + return { + "name": func.__name__, + "source_code": function_source_code, + "source_code_without_decorator": function_source_code_without_decorator, + "import_statements": import_statements, + "is_pickle": True, + } + + +def build_function_data(func): + """Return the executor for this node.""" + import types + + if isinstance(func, (types.FunctionType, types.BuiltinFunctionType, type)): + # Check if callable is nested (contains dots in __qualname__ after the first segment) + if func.__module__ == "__main__" or "." 
+        if func.__module__ == "__main__" or "." in func.__qualname__.split(".", 1)[-1]:
+            # Local or nested callable, so pickle the callable
+            executor = inspect_function(func)
+        else:
+            # Global callable (function/class), store its module and name for reference
+            executor = {
+                "module": func.__module__,
+                "name": func.__name__,
+                "is_pickle": False,
+            }
+    else:
+        raise TypeError("Provided object is not a callable function or class.")
+    return executor
+
+
 def get_or_create_code(
     label: str = "python3",
     computer: Optional[Union[str, "Computer"]] = "localhost",
@@ -142,7 +232,7 @@ def create_conda_env(
         if retval != 0:
             return (
                 False,
-                f"The command `echo -n` returned a non-zero return code ({retval})",
+                f"The command returned a non-zero return code ({retval})",
             )

     template = """
diff --git a/tests/test_create_env.py b/tests/test_create_env.py
new file mode 100644
index 0000000..7721459
--- /dev/null
+++ b/tests/test_create_env.py
@@ -0,0 +1,104 @@
+from unittest.mock import MagicMock, patch
+
+
+def test_create_conda_env():
+    computer_name = "test_computer"
+    env_name = "test_env"
+    pip_packages = ["numpy", "pandas"]
+    modules = ["qe"]
+    variables = {"TEST_VAR": "test_value"}
+    conda_deps = ["scipy"]
+    python_version = "3.8"
+    shell = "posix"
+
+    # Mock the computer and related objects
+    mock_computer = MagicMock()
+    mock_computer.label = computer_name
+    mock_user = MagicMock()
+    mock_user.email = "test_user@test.com"
+    mock_authinfo = MagicMock()
+    mock_transport = MagicMock()
+    mock_scheduler = MagicMock()
+
+    mock_authinfo.get_transport.return_value = mock_transport
+    mock_computer.get_authinfo.return_value = mock_authinfo
+    mock_authinfo.computer.get_scheduler.return_value = mock_scheduler
+
+    # Mock successful transport behavior
+    mock_transport.exec_command_wait.return_value = (
+        0,  # retval
+        "Environment setup is complete.\n",  # stdout
+        "",  # stderr
+    )
+
+    # Patch `load_computer` and `User.collection.get_default` to return mocked objects
+    with (
+        patch("aiida_pythonjob.utils.load_computer", return_value=mock_computer),
+        patch("aiida_pythonjob.utils.User.collection.get_default", return_value=mock_user),
+    ):
+        from aiida_pythonjob.utils import create_conda_env
+
+        success, message = create_conda_env(
+            computer=computer_name,
+            name=env_name,
+            pip=pip_packages,
+            conda={"dependencies": conda_deps, "channels": ["conda-forge"]},
+            modules=modules,
+            variables=variables,
+            python_version=python_version,
+            shell=shell,
+        )
+
+    # Assertions for successful case
+    assert success is True
+    assert message == "Environment setup is complete."
+
+    # Validate that exec_command_wait was called with the generated script
+    mock_transport.exec_command_wait.assert_called_once()
+    called_script = mock_transport.exec_command_wait.call_args[0][0]
+    assert f"conda create -y -n {env_name} python={python_version}" in called_script
+    assert "pip install numpy pandas" in called_script
+    assert "conda config --prepend channels" in called_script
+    assert "module load qe" in called_script
+    assert "export TEST_VAR='test_value'" in called_script
+
+
+def test_create_conda_env_error_handling():
+    computer_name = "test_computer"
+    env_name = "test_env"
+
+    # Mock the computer and related objects
+    mock_computer = MagicMock()
+    mock_computer.label = computer_name
+    mock_user = MagicMock()
+    mock_user.email = "test_user@test.com"
+    mock_authinfo = MagicMock()
+    mock_transport = MagicMock()
+    mock_scheduler = MagicMock()
+
+    # Mock error in transport
+    mock_transport.exec_command_wait.return_value = (
+        1,  # retval
+        "",  # stdout
+        "Error creating environment",  # stderr
+    )
+
+    mock_authinfo.get_transport.return_value = mock_transport
+    mock_authinfo.computer.get_scheduler.return_value = mock_scheduler
+    mock_computer.get_authinfo.return_value = mock_authinfo
+
+    # Patch `load_computer` and `User.collection.get_default` to return mocked objects
+    with (
+        patch("aiida_pythonjob.utils.load_computer", return_value=mock_computer),
+        patch("aiida_pythonjob.utils.User.collection.get_default", return_value=mock_user),
+    ):
+        from aiida_pythonjob.utils import create_conda_env
+
+        success, message = create_conda_env(
+            computer=computer_name,
+            name=env_name,
+        )
+
+    # Assertions for failure case
+    assert success is False
+    assert "The command returned a non-zero return code" in message
diff --git a/tests/test_data.py b/tests/test_data.py
index cf3d981..82c8f9c 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,5 +1,5 @@
 import aiida
-from aiida_pythonjob import PickledFunction
+from aiida_pythonjob.utils import get_required_imports


 def test_typing():
@@ -16,7 +16,7 @@ def generate_structures(
     ) -> list[array]:
         pass

-    modules = PickledFunction.get_required_imports(generate_structures)
+    modules = get_required_imports(generate_structures)
     assert modules == {
         "typing": {"List"},
         "builtins": {"list", "float"},
@@ -34,3 +34,13 @@ def test_python_job():
     assert isinstance(new_inputs["a"], aiida.orm.Int)
     assert isinstance(new_inputs["b"], aiida.orm.Float)
     assert isinstance(new_inputs["c"], PickledData)
+
+
+def test_atoms_data():
+    from aiida_pythonjob.data.atoms import AtomsData
+    from ase.build import bulk
+
+    atoms = bulk("Si")
+
+    atoms_data = AtomsData(atoms)
+    assert atoms_data.value == atoms
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..508d741
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,13 @@
+from aiida_pythonjob.utils import build_function_data
+
+
+def test_build_function_data():
+    from math import sqrt
+
+    function_data = build_function_data(sqrt)
+    assert function_data == {"module": "math", "name": "sqrt", "is_pickle": False}
+    #
+    try:
+        function_data = build_function_data(1)
+    except Exception as e:
+        assert str(e) == "Provided object is not a callable function or class."
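
Note (not part of the patch): a minimal usage sketch of the reworked input preparation introduced above. It assumes a configured AiiDA profile with a working "localhost" computer; the `add` function and its inputs are purely illustrative.

    from aiida_pythonjob import prepare_pythonjob_inputs


    def add(x: int, y: int) -> int:
        return x + y


    # Assumes an AiiDA profile is loaded and a "localhost" computer exists.
    # The function source, name and declared outputs are now bundled into a
    # single ``function_data`` Dict node instead of the removed
    # ``function_source_code``/``function_name``/``function_outputs`` ports.
    inputs = prepare_pythonjob_inputs(
        function=add,
        function_inputs={"x": 1, "y": 2},
        function_outputs=[{"name": "result"}],
        computer="localhost",
    )
    print(inputs["function_data"].get_dict()["name"])  # prints "add"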