diff --git a/src/nomad_material_processing/utils.py b/src/nomad_material_processing/utils.py index 984132b..9d0dc5c 100644 --- a/src/nomad_material_processing/utils.py +++ b/src/nomad_material_processing/utils.py @@ -16,27 +16,86 @@ # limitations under the License. # +import json +import yaml +from nomad.datamodel.context import ClientContext def get_reference(upload_id, entry_id): - return f'../uploads/{upload_id}/archive/{entry_id}#data' + return f"../uploads/{upload_id}/archive/{entry_id}" -def get_entry_id_from_file_name(file_name, archive): +def get_entry_id_from_file_name(filename, upload_id): from nomad.utils import hash - return hash(archive.metadata.upload_id, file_name) + return hash(upload_id, filename) -def create_archive(entity, archive, file_name) -> str: - import json - from nomad.datamodel.context import ClientContext - if isinstance(archive.m_context, ClientContext): + +def nan_equal(a, b): + """ + Compare two values with NaN values. + """ + if isinstance(a, float) and isinstance(b, float): + return a == b or (math.isnan(a) and math.isnan(b)) + elif isinstance(a, dict) and isinstance(b, dict): + return dict_nan_equal(a, b) + elif isinstance(a, list) and isinstance(b, list): + return list_nan_equal(a, b) + else: + return a == b + + +def list_nan_equal(list1, list2): + """ + Compare two lists with NaN values. + """ + if len(list1) != len(list2): + return False + for a, b in zip(list1, list2): + if not nan_equal(a, b): + return False + return True + + +def dict_nan_equal(dict1, dict2): + """ + Compare two dictionaries with NaN values. + """ + if set(dict1.keys()) != set(dict2.keys()): + return False + for key in dict1: + if not nan_equal(dict1[key], dict2[key]): + return False + return True + + +def create_archive( + entry_dict, context, filename, file_type, logger, *, overwrite: bool = False +): + if isinstance(context, ClientContext): return None - if not archive.m_context.raw_path_exists(file_name): - entity_entry = entity.m_to_dict(with_root_def=True) - with archive.m_context.raw_file(file_name, 'w') as outfile: - json.dump({"data": entity_entry}, outfile) - archive.m_context.process_updated_raw_file(file_name) + if context.raw_path_exists(filename): + with context.raw_file(filename, "r") as file: + existing_dict = yaml.safe_load(file) + if context.raw_path_exists(filename) and not dict_nan_equal( + existing_dict, entry_dict + ): + logger.error( + f"{filename} archive file already exists. " + f"You are trying to overwrite it with a different content. " + f"To do so, remove the existing archive and click reprocess again." + ) + if ( + not context.raw_path_exists(filename) + or existing_dict == entry_dict + or overwrite + ): + with context.raw_file(filename, "w") as newfile: + if file_type == "json": + json.dump(entry_dict, newfile) + elif file_type == "yaml": + yaml.dump(entry_dict, newfile) + context.upload.process_updated_raw_file(filename, allow_modify=True) + return get_reference( - archive.metadata.upload_id, - get_entry_id_from_file_name(file_name, archive) - ) + context.upload_id, get_entry_id_from_file_name(filename, context.upload_id) + ) \ No newline at end of file