diff --git a/src/aiida_pythonjob/calculations/pythonjob.py b/src/aiida_pythonjob/calculations/pythonjob.py
index 8fd12b5..5299efd 100644
--- a/src/aiida_pythonjob/calculations/pythonjob.py
+++ b/src/aiida_pythonjob/calculations/pythonjob.py
@@ -125,7 +125,7 @@ def _build_process_label(self) -> str:
         if "process_label" in self.inputs:
             return self.inputs.process_label.value
         else:
-            data = self.get_function_data()
+            data = self.inputs.function_data.get_dict()
             return f"PythonJob<{data['name']}>"
 
     def on_create(self) -> None:
diff --git a/src/aiida_pythonjob/parsers/pythonjob.py b/src/aiida_pythonjob/parsers/pythonjob.py
index 2fb659f..c3b7281 100644
--- a/src/aiida_pythonjob/parsers/pythonjob.py
+++ b/src/aiida_pythonjob/parsers/pythonjob.py
@@ -28,57 +28,61 @@ def parse(self, **kwargs):
         self.output_list = function_outputs
         # first we remove nested outputs, e.g., "add_multiply.add"
         top_level_output_list = [output for output in self.output_list if "." not in output["name"]]
-        exit_code = 0
         try:
             with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
                 results = pickle.load(handle)
                 if isinstance(results, tuple):
                     if len(top_level_output_list) != len(results):
-                        self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
+                        return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
                     for i in range(len(top_level_output_list)):
                         top_level_output_list[i]["value"] = self.serialize_output(results[i], top_level_output_list[i])
-                elif isinstance(results, dict) and len(top_level_output_list) > 1:
+                elif isinstance(results, dict):
                     # pop the exit code if it exists
                     exit_code = results.pop("exit_code", 0)
-                    for output in top_level_output_list:
-                        if output.get("required", False):
+                    if exit_code:
+                        if isinstance(exit_code, dict):
+                            exit_code = ExitCode(exit_code["status"], exit_code["message"])
+                        elif isinstance(exit_code, int):
+                            exit_code = ExitCode(exit_code)
+                        return exit_code
+
+                    if len(top_level_output_list) > 1:
+                        for output in top_level_output_list:
                             if output["name"] not in results:
-                                self.exit_codes.ERROR_MISSING_OUTPUT
-                        output["value"] = self.serialize_output(results.pop(output["name"]), output)
-                    # if there are any remaining results, raise an warning
-                    if results:
-                        self.logger.warning(
-                            f"Found extra results that are not included in the output: {results.keys()}"
-                        )
-                elif isinstance(results, dict) and len(top_level_output_list) == 1:
-                    exit_code = results.pop("exit_code", 0)
-                    # if output name in results, use it
-                    if top_level_output_list[0]["name"] in results:
-                        top_level_output_list[0]["value"] = self.serialize_output(
-                            results[top_level_output_list[0]["name"]],
-                            top_level_output_list[0],
-                        )
-                    # otherwise, we assume the results is the output
-                    else:
-                        top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
+                                if output.get("required", True):
+                                    return self.exit_codes.ERROR_MISSING_OUTPUT
+                            else:
+                                output["value"] = self.serialize_output(results.pop(output["name"]), output)
+                        # if there are any remaining results, raise an warning
+                        if results:
+                            self.logger.warning(
+                                f"Found extra results that are not included in the output: {results.keys()}"
+                            )
+                    elif len(top_level_output_list) == 1:
+                        # if output name in results, use it
+                        if top_level_output_list[0]["name"] in results:
+                            top_level_output_list[0]["value"] = self.serialize_output(
+                                results[top_level_output_list[0]["name"]],
+                                top_level_output_list[0],
+                            )
+                        # otherwise, we assume the results is the output
+                        else:
+                            top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
                 elif len(top_level_output_list) == 1:
-                    # otherwise, we assume the results is the output
+                    # otherwise it returns a single value, we assume the results is the output
                     top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0])
                 else:
-                    raise ValueError("The number of results does not match the number of outputs.")
+                    return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
                 for output in top_level_output_list:
                     self.out(output["name"], output["value"])
-                if exit_code:
-                    if isinstance(exit_code, dict):
-                        exit_code = ExitCode(exit_code["status"], exit_code["message"])
-                    elif isinstance(exit_code, int):
-                        exit_code = ExitCode(exit_code)
-                    return exit_code
         except OSError:
             return self.exit_codes.ERROR_READING_OUTPUT_FILE
         except ValueError as exception:
             self.logger.error(exception)
             return self.exit_codes.ERROR_INVALID_OUTPUT
+        except Exception as exception:
+            self.logger.error(exception)
+            return self.exit_codes.ERROR_INVALID_OUTPUT
 
     def find_output(self, name):
         """Find the output with the given name."""
diff --git a/tests/test_parser.py b/tests/test_parser.py
new file mode 100644
index 0000000..d020469
--- /dev/null
+++ b/tests/test_parser.py
@@ -0,0 +1,80 @@
+import pathlib
+import tempfile
+
+import cloudpickle as pickle
+from aiida import orm
+from aiida.common.links import LinkType
+from aiida_pythonjob.parsers import PythonJobParser
+
+
+def create_retrieved_folder(result: dict):
+    # Create a retrieved ``FolderData`` node with results
+    with tempfile.TemporaryDirectory() as tmpdir:
+        dirpath = pathlib.Path(tmpdir)
+        with open((dirpath / "results.pickle"), "wb") as handle:
+            pickle.dump(result, handle)
+        folder_data = orm.FolderData(tree=dirpath.absolute())
+    return folder_data
+
+
+def create_process_node(result: dict, function_data: dict):
+    node = orm.CalcJobNode()
+    node.set_process_type("aiida.calculations:pythonjob.pythonjob")
+    function_data = orm.Dict(function_data)
+    retrieved = create_retrieved_folder(result)
+    node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data")
+    retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved")
+    function_data.store()
+    node.store()
+    retrieved.store()
+    return node
+
+
+def create_parser(result, function_data):
+    node = create_process_node(result, function_data)
+    parser = PythonJobParser(node=node)
+    return parser
+
+
+def test_tuple_result(fixture_localhost):
+    result = (1, 2, 3)
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code is None
+    assert len(parser.outputs) == 3
+
+
+def test_tuple_result_mismatch(fixture_localhost):
+    result = (1, 2)
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code == parser.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH
+
+
+def test_dict_result(fixture_localhost):
+    result = {"a": 1, "b": 2, "c": 3}
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code is None
+    assert len(parser.outputs) == 3
+
+
+def test_dict_result_missing(fixture_localhost):
+    result = {"a": 1, "b": 2}
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT
+
+
+def test_exit_code(fixture_localhost):
+    result = {"exit_code": {"status": 1, "message": "error"}}
+    function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]}
+    parser = create_parser(result, function_data)
+    exit_code = parser.parse()
+    assert exit_code is not None
+    assert exit_code.status == 1
+    assert exit_code.message == "error"
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
deleted file mode 100644
index e69de29..0000000
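
Note (not part of the diff): a minimal sketch of the results-dictionary convention that the reworked parser handles. A user function may include a reserved "exit_code" entry, given either as an int or as a dict with "status" and "message"; the parser pops it and, if non-zero, converts it to an AiiDA ExitCode and returns it instead of registering outputs. The function name "add" and the status/message values below are illustrative assumptions, not code from this repository.

from aiida.engine import ExitCode


def add(x, y):
    # Hypothetical user function run remotely by PythonJob.
    if x == 0 and y == 0:
        # The reserved "exit_code" key signals a custom failure to the parser.
        return {"sum": 0, "exit_code": {"status": 410, "message": "Both inputs are zero."}}
    return {"sum": x + y}


# Parser-side conversion, mirroring the logic added in this diff:
results = add(0, 0)
exit_code = results.pop("exit_code", 0)
if exit_code:
    if isinstance(exit_code, dict):
        exit_code = ExitCode(exit_code["status"], exit_code["message"])
    elif isinstance(exit_code, int):
        exit_code = ExitCode(exit_code)
assert exit_code.status == 410
assert exit_code.message == "Both inputs are zero."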