Skip to content

Commit

Permalink
PythonJob: automatically serialize the inputs into AiiDA data (#85)
Browse files Browse the repository at this point in the history
First search the AiiDA data entry points based on the module and class name, for example, `ase.atoms.Atoms`. If an entry point is found, use it to serialize the value; if not found, use `GeneralData` to serialize the value.

Add more data entry points: int, float, str, bool, list and dict
  • Loading branch information
superstar54 committed May 23, 2024
1 parent d00e989 commit feced0c
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 259 deletions.
2 changes: 1 addition & 1 deletion aiida_workgraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
from .decorator import node, build_node


__version__ = "0.2.5"
__version__ = "0.2.6"

__all__ = ["WorkGraph", "Node", "node", "build_node"]
14 changes: 8 additions & 6 deletions aiida_workgraph/calculations/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
)


from .general_data import GeneralData

__all__ = ("PythonJob",)


Expand Down Expand Up @@ -55,7 +53,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
spec.input(
"function_name", valid_type=Str, serializer=to_aiida_type, required=False
)
spec.input_namespace("kwargs", valid_type=Data, required=False)
spec.input_namespace(
"kwargs", valid_type=Data, required=False
) # , serializer=general_serializer)
spec.input(
"output_name_list",
valid_type=List,
Expand Down Expand Up @@ -188,12 +188,14 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
# create pickle file for the inputs
input_values = {}
for key, value in inputs.items():
if isinstance(value, GeneralData):
if isinstance(value, Data) and hasattr(value, "value"):
# get the value of the pickled data
input_values[key] = value.value
else:
raise ValueError(f"Unsupported data type: {type(value)}")
# save the value as a pickle file, the path is absolute
raise ValueError(
f"Input data {value} is not supported. Only AiiDA data Node with a value attribute is allowed. "
)
# save the value as a pickle file, the path is absolute
filename = "inputs.pickle"
with folder.open(filename, "wb") as handle:
pickle.dump(input_values, handle)
Expand Down
30 changes: 24 additions & 6 deletions aiida_workgraph/calculations/python_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Parser for an `PythonJob` job."""
from aiida.parsers.parser import Parser
from .general_data import GeneralData
from aiida_workgraph.orm import general_serializer


class PythonParser(Parser):
Expand All @@ -14,18 +14,36 @@ def parse(self, **kwargs):
with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
results = pickle.load(handle)
output_name_list = self.node.inputs.output_name_list.get_list()
# output_name_list exclude ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved']
output_name_list = [
name
for name in output_name_list
if name
not in [
"_wait",
"_outputs",
"remote_folder",
"remote_stash",
"retrieved",
]
]
outputs = {}
if isinstance(results, tuple):
if len(output_name_list) != len(results):
raise ValueError(
"The number of results does not match the number of output_name_list."
)
for i in range(len(output_name_list)):
self.out(output_name_list[i].name, GeneralData(results[i]))
elif isinstance(results, dict):
for key, value in results.items():
self.out(key, GeneralData(value))
outputs[output_name_list[i].name] = results[i]
outputs = general_serializer(outputs)
elif isinstance(results, dict) and len(results) == len(
output_name_list
):
outputs = general_serializer(results)
else:
self.out("result", GeneralData(results))
outputs = general_serializer({"result": results})
for key, value in outputs.items():
self.out(key, value)
except OSError:
return self.exit_codes.ERROR_READING_OUTPUT_FILE
except ValueError:
Expand Down
16 changes: 3 additions & 13 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
self.to_context(**{name: process})
elif node["metadata"]["node_type"].upper() in ["PYTHONJOB"]:
from aiida_workgraph.calculations.python import PythonJob
from aiida_workgraph.calculations.general_data import GeneralData
from aiida_workgraph.orm.serializer import general_serializer
from aiida_workgraph.utils import get_or_create_code

print("node type: Python.")
Expand All @@ -849,20 +849,10 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
# get the source code of the function
function_name = executor.__name__
function_source_code = node["executor"]["function_source_code"]
inputs = {}
# save all kwargs to inputs port
for key, value in kwargs.items():
if isinstance(value, orm.Node):
if not hasattr(value, "value"):
raise ValueError(
"Only AiiDA data Node with a value attribute is allowed."
)
inputs[key] = value
else:
inputs[key] = GeneralData(value)
# outputs
output_name_list = [output["name"] for output in node["outputs"]]

# serialize the kwargs into AiiDA Data
inputs = general_serializer(kwargs)
# transfer the args to kwargs
process = self.submit(
PythonJob,
Expand Down
7 changes: 7 additions & 0 deletions aiida_workgraph/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .general_data import GeneralData
from .serializer import general_serializer

__all__ = (
"GeneralData",
"general_serializer",
)
128 changes: 128 additions & 0 deletions aiida_workgraph/orm/atoms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""`Data` sub class to represent a list."""

from aiida.orm import Data
from ase import Atoms

__all__ = ("AtomsData",)


class AtomsData(Data):
    """``Data`` node wrapping an :class:`ase.Atoms` structure.

    The atoms object is pickled into the node repository under ``atoms.pkl``;
    the chemical formula is additionally stored as a node attribute so that
    structures can be queried without loading the pickle.
    """

    # Lazily populated cache of the unpickled Atoms (only used once stored).
    _cached_atoms = None

    def __init__(self, value=None, **kwargs):
        """Initialise an ``AtomsData`` node instance.

        :param value: the :class:`ase.Atoms` to store; may alternatively be
            passed through the ``atoms`` keyword. Defaults to an empty
            ``Atoms`` object.
        """
        data = value or kwargs.pop("atoms", Atoms())
        super().__init__(**kwargs)
        self.set_atoms(data)

    @property
    def value(self):
        """Return the wrapped :class:`ase.Atoms` object."""
        return self.get_atoms()

    def initialize(self):
        super().initialize()
        # Drop any cached copy: the repository content is authoritative.
        self._cached_atoms = None

    def __getitem__(self, item):
        return self.get_atoms()[item]

    def __setitem__(self, key, value):
        data = self.get_atoms()
        data[key] = value
        if not self._using_atoms_reference():
            self.set_atoms(data)

    def __delitem__(self, key):
        data = self.get_atoms()
        del data[key]
        if not self._using_atoms_reference():
            self.set_atoms(data)

    def __len__(self):
        return len(self.get_atoms())

    def __str__(self):
        return f"{super().__str__()} : {self.get_atoms()}"

    def __eq__(self, other):
        # Bug fix: a plain ``ase.Atoms`` has no ``get_atoms`` method, so the
        # previous ``other.get_atoms()`` call raised AttributeError. Unwrap
        # the other operand only when it is itself an ``AtomsData``.
        if isinstance(other, AtomsData):
            return self.get_atoms() == other.get_atoms()
        return self.get_atoms() == other

    def append(self, value):
        data = self.get_atoms()
        data.append(value)
        if not self._using_atoms_reference():
            self.set_atoms(data)

    def extend(self, value):  # pylint: disable=arguments-renamed
        data = self.get_atoms()
        data.extend(value)
        if not self._using_atoms_reference():
            self.set_atoms(data)

    def get_atoms(self):
        """Return the contents of this node.

        :return: the stored :class:`ase.Atoms`
        """
        import pickle

        def load_from_repository():
            filename = "atoms.pkl"
            # Open in binary mode as the content is a pickle byte stream.
            with self.base.repository.open(filename, mode="rb") as handle:
                return pickle.loads(handle.read())  # pylint: disable=unexpected-keyword-arg

        # Cache only once the node is stored (hence immutable); otherwise
        # always re-read from the repository to pick up mutations.
        if not self.is_stored:
            return load_from_repository()

        if self._cached_atoms is None:
            self._cached_atoms = load_from_repository()

        return self._cached_atoms

    def set_atoms(self, atoms):
        """Set the contents of this node.

        :param atoms: the :class:`ase.Atoms` to store
        :raises TypeError: if ``atoms`` is not an ``ase.Atoms`` instance
        """
        import pickle

        if not isinstance(atoms, Atoms):
            raise TypeError("Must supply Atoms type")
        self.base.repository.put_object_from_bytes(pickle.dumps(atoms), "atoms.pkl")
        # Store the chemical formula as an attribute for querying purposes.
        formula = atoms.get_chemical_formula()
        self.base.attributes.set("formula", formula)

    def _using_atoms_reference(self):
        """
        Tell the class whether ``get_atoms`` returns a reference rather than
        a copy of the underlying object, in which case ``set_atoms`` need not
        be called after in-place mutation. This knowledge is essential to
        make sure this class is performant.
        Currently the implementation assumes that if the node is stored it
        serves the cached (shared) object, which is a reference.
        :return: True if ``get_atoms`` returns a reference to the underlying
            object, False otherwise.
        :rtype: bool
        """
        return self.is_stored
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
"""`Data` sub class to represent any data using pickle."""

from aiida.orm import Data
from aiida import orm

__all__ = ("GeneralData",)

class Dict(orm.Dict):
    """AiiDA ``Dict`` variant exposing its content through a ``value`` property."""

    @property
    def value(self):
        """Return the stored content as a plain Python ``dict``."""
        return self.get_dict()


class List(orm.List):
    """AiiDA ``List`` variant exposing its content through a ``value`` property."""

    @property
    def value(self):
        """Return the stored content as a plain Python ``list``."""
        return self.get_list()


class GeneralData(Data):
class GeneralData(orm.Data):
"""`Data to represent a pickled value."""

def __init__(self, value=None, **kwargs):
Expand Down
47 changes: 47 additions & 0 deletions aiida_workgraph/orm/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from .general_data import GeneralData
from aiida import orm
from importlib.metadata import entry_points


# Retrieve the entry points for 'aiida.data' once at import time and index
# them by name for O(1) serializer lookup.
# NOTE: ``entry_points()`` returning a dict (with ``.get``) was deprecated in
# Python 3.10 and removed in 3.12; prefer the selectable ``group=`` API and
# fall back to the legacy interface on older interpreters.
try:
    eps = {ep.name: ep for ep in entry_points(group="aiida.data")}
except TypeError:  # Python < 3.10: entry_points() accepts no arguments
    eps = {ep.name: ep for ep in entry_points().get("aiida.data", [])}


def general_serializer(inputs):
    """Serialize a dict of raw Python values into AiiDA data nodes.

    Each value is converted according to the following rules:

    1. An existing ``orm.Data`` instance is passed through unchanged,
       provided it exposes a ``value`` attribute (so the underlying Python
       object can later be recovered).
    2. Otherwise, an ``aiida.data`` entry point matching the value's
       fully-qualified class name (e.g. ``ase.atoms.Atoms``) is used.
    3. As a last resort the value is pickled into a ``GeneralData`` node.

    :param inputs: mapping of port name to raw value or ``Data`` node.
    :return: mapping of port name to AiiDA ``Data`` node.
    :raises ValueError: if a ``Data`` node lacks a ``value`` attribute, or
        serialization of a raw value fails.
    """
    new_inputs = {}
    for key, value in inputs.items():
        if isinstance(value, orm.Data):
            if not hasattr(value, "value"):
                raise ValueError(
                    "Only AiiDA data Node with a value attribute is allowed."
                )
            new_inputs[key] = value
            continue
        # Look up a dedicated serializer registered under the value's
        # fully-qualified class name, e.g. ``ase.atoms.Atoms``.
        value_type = type(value)
        ep_key = f"{value_type.__module__}.{value_type.__name__}"
        try:
            if ep_key in eps:
                new_inputs[key] = eps[ep_key].load()(value)
            else:
                # No dedicated serializer found: fall back to pickling.
                new_inputs[key] = GeneralData(value)
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in the traceback instead of being swallowed.
            raise ValueError(f"Error in serializing {key}: {e}") from e

    return new_inputs
Loading

0 comments on commit feced0c

Please sign in to comment.