diff --git a/src/aiida_pythonjob/parsers/pythonjob.py b/src/aiida_pythonjob/parsers/pythonjob.py index c3b7281..7e5fe04 100644 --- a/src/aiida_pythonjob/parsers/pythonjob.py +++ b/src/aiida_pythonjob/parsers/pythonjob.py @@ -45,8 +45,22 @@ def parse(self, **kwargs): elif isinstance(exit_code, int): exit_code = ExitCode(exit_code) return exit_code - - if len(top_level_output_list) > 1: + if len(top_level_output_list) == 1: + # if output name in results, use it + if top_level_output_list[0]["name"] in results: + top_level_output_list[0]["value"] = self.serialize_output( + results.pop(top_level_output_list[0]["name"]), + top_level_output_list[0], + ) + # if there are any remaining results, raise an warning + if len(results) > 0: + self.logger.warning( + f"Found extra results that are not included in the output: {results.keys()}" + ) + # otherwise, we assume the results is the output + else: + top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0]) + elif len(top_level_output_list) > 1: for output in top_level_output_list: if output["name"] not in results: if output.get("required", True): @@ -54,20 +68,11 @@ def parse(self, **kwargs): else: output["value"] = self.serialize_output(results.pop(output["name"]), output) # if there are any remaining results, raise an warning - if results: + if len(results) > 0: self.logger.warning( f"Found extra results that are not included in the output: {results.keys()}" ) - elif len(top_level_output_list) == 1: - # if output name in results, use it - if top_level_output_list[0]["name"] in results: - top_level_output_list[0]["value"] = self.serialize_output( - results[top_level_output_list[0]["name"]], - top_level_output_list[0], - ) - # otherwise, we assume the results is the output - else: - top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0]) + elif len(top_level_output_list) == 1: # otherwise it returns a single value, we assume the results is the output top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0]) @@ -80,9 +85,6 @@ def parse(self, **kwargs): except ValueError as exception: self.logger.error(exception) return self.exit_codes.ERROR_INVALID_OUTPUT - except Exception as exception: - self.logger.error(exception) - return self.exit_codes.ERROR_INVALID_OUTPUT def find_output(self, name): """Find the output with the given name.""" diff --git a/tests/test_data.py b/tests/test_data.py index 82c8f9c..a40b604 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,4 +1,5 @@ import aiida +from aiida_pythonjob.data import general_serializer from aiida_pythonjob.utils import get_required_imports @@ -44,3 +45,10 @@ def test_atoms_data(): atoms_data = AtomsData(atoms) assert atoms_data.value == atoms + + +def test_only_data_with_value(): + try: + general_serializer(aiida.orm.List([1])) + except ValueError as e: + assert str(e) == "Only AiiDA data Node with a value attribute is allowed." diff --git a/tests/test_parser.py b/tests/test_parser.py index d020469..1bef763 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -3,25 +3,26 @@ import cloudpickle as pickle from aiida import orm +from aiida.cmdline.utils.common import get_workchain_report from aiida.common.links import LinkType from aiida_pythonjob.parsers import PythonJobParser -def create_retrieved_folder(result: dict): +def create_retrieved_folder(result: dict, output_filename="results.pickle"): # Create a retrieved ``FolderData`` node with results with tempfile.TemporaryDirectory() as tmpdir: dirpath = pathlib.Path(tmpdir) - with open((dirpath / "results.pickle"), "wb") as handle: + with open((dirpath / output_filename), "wb") as handle: pickle.dump(result, handle) folder_data = orm.FolderData(tree=dirpath.absolute()) return folder_data -def create_process_node(result: dict, function_data: dict): +def create_process_node(result: dict, function_data: dict, output_filename: str = "results.pickle"): node = orm.CalcJobNode() node.set_process_type("aiida.calculations:pythonjob.pythonjob") function_data = orm.Dict(function_data) - retrieved = create_retrieved_folder(result) + retrieved = create_retrieved_folder(result, output_filename=output_filename) node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data") retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved") function_data.store() @@ -30,8 +31,8 @@ def create_process_node(result: dict, function_data: dict): return node -def create_parser(result, function_data): - node = create_process_node(result, function_data) +def create_parser(result, function_data, output_filename="results.pickle"): + node = create_process_node(result, function_data, output_filename=output_filename) parser = PythonJobParser(node=node) return parser @@ -55,11 +56,13 @@ def test_tuple_result_mismatch(fixture_localhost): def test_dict_result(fixture_localhost): result = {"a": 1, "b": 2, "c": 3} - function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} + function_data = {"outputs": [{"name": "a"}, {"name": "b"}]} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code is None - assert len(parser.outputs) == 3 + assert len(parser.outputs) == 2 + report = get_workchain_report(parser.node, levelname="WARNING") + assert "Found extra results that are not included in the output: dict_keys(['c'])" in report def test_dict_result_missing(fixture_localhost): @@ -70,6 +73,27 @@ def test_dict_result_missing(fixture_localhost): assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT +def test_dict_result_as_one_output(fixture_localhost): + result = {"a": 1, "b": 2, "c": 3} + function_data = {"outputs": [{"name": "result"}]} + parser = create_parser(result, function_data) + exit_code = parser.parse() + assert exit_code is None + assert len(parser.outputs) == 1 + assert parser.outputs["result"] == result + + +def test_dict_result_only_show_one_output(fixture_localhost): + result = {"a": 1, "b": 2} + function_data = {"outputs": [{"name": "a"}]} + parser = create_parser(result, function_data) + parser.parse() + assert len(parser.outputs) == 1 + assert parser.outputs["a"] == 1 + report = get_workchain_report(parser.node, levelname="WARNING") + assert "Found extra results that are not included in the output: dict_keys(['b'])" in report + + def test_exit_code(fixture_localhost): result = {"exit_code": {"status": 1, "message": "error"}} function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} @@ -78,3 +102,11 @@ def test_exit_code(fixture_localhost): assert exit_code is not None assert exit_code.status == 1 assert exit_code.message == "error" + + +def test_no_output_file(fixture_localhost): + result = {"a": 1, "b": 2, "c": 3} + function_data = {"outputs": [{"name": "result"}]} + parser = create_parser(result, function_data, output_filename="not_results.pickle") + exit_code = parser.parse() + assert exit_code == parser.exit_codes.ERROR_READING_OUTPUT_FILE diff --git a/tests/test_pythonjob.py b/tests/test_pythonjob.py index c1f4cb7..786659d 100644 --- a/tests/test_pythonjob.py +++ b/tests/test_pythonjob.py @@ -205,6 +205,30 @@ def add(x, y): assert "result.txt" in result["retrieved"].list_object_names() +def test_copy_files(fixture_localhost): + """Test function with copy files.""" + + def add(x, y): + z = x + y + with open("result.txt", "w") as f: + f.write(str(z)) + + def multiply(x_folder_name, y): + with open(f"{x_folder_name}/result.txt", "r") as f: + x = int(f.read()) + return x * y + + inputs = prepare_pythonjob_inputs(add, function_inputs={"x": 1, "y": 2}) + result, node = run_get_node(PythonJob, inputs=inputs) + inputs = prepare_pythonjob_inputs( + multiply, + function_inputs={"x_folder_name": "x_folder_name", "y": 2}, + copy_files={"x_folder_name": result["remote_folder"]}, + ) + result, node = run_get_node(PythonJob, inputs=inputs) + assert result["result"].value == 6 + + def test_exit_code(fixture_localhost): """Test function with exit code.""" from numpy import array