Skip to content

Commit

Permalink
Improve test coverage (#4)
Browse files Browse the repository at this point in the history
* add test for AtomsData
* Drop PickledFunction data
* add test for create_env
  • Loading branch information
superstar54 authored Dec 2, 2024
1 parent 9f35808 commit e059d36
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 220 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ keywords = ["aiida", "plugin"]
requires-python = ">=3.9"
dependencies = [
"aiida-core>=2.3,<3",
"ase",
"cloudpickle",
"voluptuous"
]
Expand All @@ -46,7 +47,6 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"

[project.entry-points."aiida.data"]
"pythonjob.pickled_data" = "aiida_pythonjob.data.pickled_data:PickledData"
"pythonjob.pickled_function" = "aiida_pythonjob.data.pickled_function:PickledFunction"
"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
"pythonjob.builtins.int" = "aiida.orm.nodes.data.int:Int"
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
Expand Down
2 changes: 0 additions & 2 deletions src/aiida_pythonjob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
__version__ = "0.1.3"

from .calculations import PythonJob
from .data import PickledData, PickledFunction
from .launch import prepare_pythonjob_inputs
from .parsers import PythonJobParser

__all__ = (
"PythonJob",
"PickledData",
"PickledFunction",
"prepare_pythonjob_inputs",
"PythonJobParser",
)
42 changes: 3 additions & 39 deletions src/aiida_pythonjob/calculations/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import (
Data,
Dict,
FolderData,
List,
RemoteData,
Expand All @@ -19,8 +20,6 @@
to_aiida_type,
)

from aiida_pythonjob.data.pickled_function import PickledFunction, to_pickled_function

__all__ = ("PythonJob",)


Expand All @@ -42,31 +41,11 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
:param spec: the calculation job process spec to define.
"""
super().define(spec)
spec.input(
"function",
valid_type=PickledFunction,
serializer=to_pickled_function,
required=False,
)
spec.input(
"function_source_code",
valid_type=Str,
serializer=to_aiida_type,
required=False,
)
spec.input("function_name", valid_type=Str, serializer=to_aiida_type, required=False)
spec.input("function_data", valid_type=Dict, serializer=to_aiida_type, required=False)
spec.input("process_label", valid_type=Str, serializer=to_aiida_type, required=False)
spec.input_namespace(
"function_inputs", valid_type=Data, required=False
) # , serializer=serialize_to_aiida_nodes)
spec.input(
"function_outputs",
valid_type=List,
default=lambda: List(),
required=False,
serializer=to_aiida_type,
help="The information of the output ports",
)
spec.input(
"parent_folder",
valid_type=(RemoteData, FolderData, SinglefileData),
Expand Down Expand Up @@ -155,21 +134,6 @@ def on_create(self) -> None:
super().on_create()
self.node.label = self._build_process_label()

def get_function_data(self) -> dict[str, t.Any]:
"""Get the function data.
:returns: The function data.
"""
if "function" in self.inputs:
metadata = self.inputs.function.metadata
metadata["source_code"] = metadata["import_statements"] + "\n" + metadata["source_code_without_decorator"]
return metadata
else:
return {
"source_code": self.inputs.function_source_code.value,
"name": self.inputs.function_name.value,
}

def prepare_for_submission(self, folder: Folder) -> CalcInfo:
"""Prepare the calculation for submission.
Expand All @@ -192,7 +156,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
parent_folder_name = self.inputs.parent_folder_name.value
else:
parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME
function_data = self.get_function_data()
function_data = self.inputs.function_data.get_dict()
# create python script to run the function
script = f"""
import pickle
Expand Down
3 changes: 1 addition & 2 deletions src/aiida_pythonjob/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .pickled_data import PickledData
from .pickled_function import PickledFunction
from .serializer import general_serializer, serialize_to_aiida_nodes

__all__ = ("PickledData", "PickledFunction", "serialize_to_aiida_nodes", "general_serializer")
__all__ = ("PickledData", "serialize_to_aiida_nodes", "general_serializer")
145 changes: 0 additions & 145 deletions src/aiida_pythonjob/data/pickled_function.py

This file was deleted.

57 changes: 31 additions & 26 deletions src/aiida_pythonjob/launch.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@
from __future__ import annotations

import inspect
import os
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

from aiida.orm import AbstractCode, Computer, FolderData, List, SinglefileData, Str
from aiida import orm

from .data.pickled_function import PickledFunction
from .data.serializer import serialize_to_aiida_nodes
from .utils import get_or_create_code
from .utils import build_function_data, get_or_create_code


def prepare_pythonjob_inputs(
function: Optional[Callable[..., Any]] = None,
function_inputs: Optional[Dict[str, Any]] = None,
function_outputs: Optional[Dict[str, Any]] = None,
code: Optional[AbstractCode] = None,
function_outputs: Optional[List[str | dict]] = None,
code: Optional[orm.AbstractCode] = None,
command_info: Optional[Dict[str, str]] = None,
computer: Union[str, Computer] = "localhost",
computer: Union[str, orm.Computer] = "localhost",
metadata: Optional[Dict[str, Any]] = None,
upload_files: Dict[str, str] = {},
process_label: Optional[str] = None,
pickled_function: Optional[PickledFunction] = None,
function_data: dict | None = None,
**kwargs: Any,
) -> Dict[str, Any]:
pass
"""Prepare the inputs for PythonJob"""

if function is None and pickled_function is None:
raise ValueError("Either function or pickled_function must be provided")
if function is not None and pickled_function is not None:
raise ValueError("Only one of function or pickled_function should be provided")
# if function is a function, convert it to a PickledFunction
if function is None and function_data is None:
raise ValueError("Either function or function_data must be provided")
if function is not None and function_data is not None:
raise ValueError("Only one of function or function_data should be provided")
# if function is a function, inspect it and get the source code
if function is not None and inspect.isfunction(function):
executor = PickledFunction.build_callable(function)
if pickled_function is not None:
executor = pickled_function
function_data = build_function_data(function)
new_upload_files = {}
# change the string in the upload files to SingleFileData, or FolderData
for key, source in upload_files.items():
Expand All @@ -42,10 +41,10 @@ def prepare_pythonjob_inputs(
new_key = key.replace(".", "_dot_")
if isinstance(source, str):
if os.path.isfile(source):
new_upload_files[new_key] = SinglefileData(file=source)
new_upload_files[new_key] = orm.SinglefileData(file=source)
elif os.path.isdir(source):
new_upload_files[new_key] = FolderData(tree=source)
elif isinstance(source, (SinglefileData, FolderData)):
new_upload_files[new_key] = orm.FolderData(tree=source)
elif isinstance(source, (orm.SinglefileData, orm.FolderData)):
new_upload_files[new_key] = source
else:
raise ValueError(f"Invalid upload file type: {type(source)}, {source}")
Expand All @@ -54,24 +53,30 @@ def prepare_pythonjob_inputs(
command_info = command_info or {}
code = get_or_create_code(computer=computer, **command_info)
# get the source code of the function
function_name = executor["name"]
if executor.get("is_pickle", False):
function_source_code = executor["import_statements"] + "\n" + executor["source_code_without_decorator"]
function_name = function_data["name"]
if function_data.get("is_pickle", False):
function_source_code = (
function_data["import_statements"] + "\n" + function_data["source_code_without_decorator"]
)
else:
function_source_code = f"from {executor['module']} import {function_name}"
function_source_code = f"from {function_data['module']} import {function_name}"

# serialize the kwargs into AiiDA Data
function_inputs = function_inputs or {}
function_inputs = serialize_to_aiida_nodes(function_inputs)
# transfer the args to kwargs
inputs = {
"process_label": process_label or "PythonJob<{}>".format(function_name),
"function_source_code": Str(function_source_code),
"function_name": Str(function_name),
"function_data": orm.Dict(
{
"source_code": function_source_code,
"name": function_name,
"outputs": function_outputs or [],
}
),
"code": code,
"function_inputs": function_inputs,
"upload_files": new_upload_files,
"function_outputs": List(function_outputs),
"metadata": metadata or {},
**kwargs,
}
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_pythonjob/parsers/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def parse(self, **kwargs):
"""
import pickle

function_outputs = self.node.inputs.function_outputs.get_list()
function_outputs = self.node.inputs.function_data.get_dict()["outputs"]
if len(function_outputs) == 0:
function_outputs = [{"name": "result"}]
self.output_list = function_outputs
Expand Down
Loading

0 comments on commit e059d36

Please sign in to comment.