diff --git a/dev_tools/docs/anchor_list.py b/dev_tools/docs/anchor_list.py index df668ead01..1b411041df 100644 --- a/dev_tools/docs/anchor_list.py +++ b/dev_tools/docs/anchor_list.py @@ -114,6 +114,10 @@ def write(self): return contents = dict( _metadata=dict( + # datetime=datetime.datetime.now(datetime.UTC).isoformat(), + # the next line is the py3.9 supported way of getting the datetime + # this will become deprecated however in py3.12 for which the + # line above-mentioned is a fix, which however does not work in py3.9 datetime=datetime.datetime.utcnow().isoformat(), title="NeXus NXDL vocabulary.", subtitle="Anchors for all NeXus fields, groups, " diff --git a/dev_tools/docs/nxdl.py b/dev_tools/docs/nxdl.py index 15d50948ab..3d22baf41c 100644 --- a/dev_tools/docs/nxdl.py +++ b/dev_tools/docs/nxdl.py @@ -311,13 +311,38 @@ def _get_doc_blocks(ns, node): out_blocks.append("\n".join(out_lines)) return out_blocks + def _handle_multiline_docstring(self, blocks): + link_pattern = re.compile(r"\.\. _([^:]+):(.*)") + + links = [] + docstring = "" + expanded_blocks = [] + + for block in blocks: + expanded_blocks += block.split("\n") + + for block in expanded_blocks: + if not block: + continue + + link_match = link_pattern.search(block) + if link_match is not None: + links.append((link_match.group(1), link_match.group(2).strip())) + else: + docstring += " " + block.strip().replace("\n", " ") + + for name, target in links: + docstring = docstring.replace(f"`{name}`_", f"`{name} <{target}>`_") + + return docstring + def _get_doc_line(self, ns, node): blocks = self._get_doc_blocks(ns, node) if len(blocks) == 0: return "" if len(blocks) > 1: - raise Exception(f"Unexpected multi-paragraph doc [{'|'.join(blocks)}]") - return re.sub(r"\n", " ", blocks[0]) + return self._handle_multiline_docstring(blocks) + return blocks[0].replace("\n", " ") def _get_minOccurs(self, node): """ diff --git a/dev_tools/tests/test_nxdl_utils.py b/dev_tools/tests/test_nxdl_utils.py index 7b10494bf3..2f7e677685 100644 --- a/dev_tools/tests/test_nxdl_utils.py +++ b/dev_tools/tests/test_nxdl_utils.py @@ -2,9 +2,10 @@ """ -import os +from pathlib import Path import lxml.etree as ET +import pytest from ..utils import nxdl_utils as nexus @@ -32,8 +33,8 @@ def test_get_nexus_classes_units_attributes(): def test_get_node_at_nxdl_path(): """Test to verify if we receive the right XML element for a given NXDL path""" - local_dir = os.path.abspath(os.path.dirname(__file__)) - nxdl_file_path = os.path.join(local_dir, "./NXtest.nxdl.xml") + local_dir = Path(__file__).resolve().parent + nxdl_file_path = local_dir / "NXtest.nxdl.xml" elem = ET.parse(nxdl_file_path).getroot() node = nexus.get_node_at_nxdl_path("/ENTRY/NXODD_name", elem=elem) assert node.attrib["type"] == "NXdata" @@ -48,11 +49,144 @@ def test_get_node_at_nxdl_path(): ) assert node.attrib["name"] == "long_name" + nxdl_file_path = ( + local_dir.parent.parent / "contributed_definitions" / "NXiv_temp.nxdl.xml" + ) + elem = ET.parse(nxdl_file_path).getroot() + node = nexus.get_node_at_nxdl_path( + "/ENTRY/INSTRUMENT/ENVIRONMENT/voltage_controller", elem=elem + ) + assert node.attrib["name"] == "voltage_controller" + + node = nexus.get_node_at_nxdl_path( + "/ENTRY/INSTRUMENT/ENVIRONMENT/voltage_controller/calibration_time", elem=elem + ) + assert node.attrib["name"] == "calibration_time" + def test_get_inherited_nodes(): """Test to verify if we receive the right XML element list for a given NXDL path.""" - local_dir = os.path.abspath(os.path.dirname(__file__)) - nxdl_file_path = os.path.join(local_dir, "./NXtest.nxdl.xml") + local_dir = Path(__file__).resolve().parent + nxdl_file_path = local_dir / "NXtest.nxdl.xml" + elem = ET.parse(nxdl_file_path).getroot() (_, _, elist) = nexus.get_inherited_nodes(nxdl_path="/ENTRY/NXODD_name", elem=elem) assert len(elist) == 3 + + nxdl_file_path = ( + local_dir.parent.parent / "contributed_definitions" / "NXiv_temp.nxdl.xml" + ) + + elem = ET.parse(nxdl_file_path).getroot() + (_, _, elist) = nexus.get_inherited_nodes( + nxdl_path="/ENTRY/INSTRUMENT/ENVIRONMENT", elem=elem + ) + assert len(elist) == 3 + + (_, _, elist) = nexus.get_inherited_nodes( + nxdl_path="/ENTRY/INSTRUMENT/ENVIRONMENT/voltage_controller", elem=elem + ) + assert len(elist) == 4 + + (_, _, elist) = nexus.get_inherited_nodes( + nxdl_path="/ENTRY/INSTRUMENT/ENVIRONMENT/voltage_controller", + nx_name="NXiv_temp", + ) + assert len(elist) == 4 + + +@pytest.mark.parametrize( + "hdf_name,concept_name,should_fit", + [ + ("source_pump", "sourceType", False), + ("source_pump", "sourceTYPE", True), + ("source pump", "sourceTYPE", False), + ("source", "sourceTYPE", False), + ("source123", "SOURCE", True), + ("1source", "SOURCE", True), + ("_source", "SOURCE", True), + ("same_name", "same_name", True), + ("angular_energy_resolution", "angularNresolution", True), + ("angularresolution", "angularNresolution", False), + ("Name with some whitespaces in it", "ENTRY", False), + ("simple_name", "TEST", True), + (".test", "TEST", False), + ], +) +def test_namefitting(hdf_name, concept_name, should_fit): + """Test namefitting of nexus concept names""" + if should_fit: + assert nexus.get_nx_namefit(hdf_name, concept_name, name_partial=True) > -1 + else: + assert nexus.get_nx_namefit(hdf_name, concept_name, name_partial=True) == -1 + + +@pytest.mark.parametrize( + "hdf_name,concept_name, score", + [ + ("test_name", "TEST_name", 9), + ("te_name", "TEST_name", 7), + ("my_other_name", "TEST_name", 5), + ("test_name", "test_name", 18), + ("test_other", "test_name", -1), + ("my_fancy_yet_long_name", "my_SOME_name", 8), + ("something", "XXXX", 0), + ("something", "OTHER", 1), + ], +) +def test_namefitting_scores(hdf_name, concept_name, score): + """Test namefitting of nexus concept names""" + assert nexus.get_nx_namefit(hdf_name, concept_name, name_partial=True) == score + + +@pytest.mark.parametrize( + "better_fit,better_ref,worse_fit,worse_ref", + [ + ("sourcetype", "sourceTYPE", "source_pump", "sourceTYPE"), + ("source_pump", "sourceTYPE", "source_pump", "TEST"), + ], +) +def test_namefitting_precedence(better_fit, better_ref, worse_fit, worse_ref): + """Test if namefitting follows proper precedence rules""" + + assert nexus.get_nx_namefit( + better_fit, better_ref, name_partial=True + ) > nexus.get_nx_namefit(worse_fit, worse_ref) + + +@pytest.mark.parametrize( + "string_obj, decode, expected", + [ + # Test with lists of bytes and strings + ([b"bytes", "string"], True, ["bytes", "string"]), + ([b"bytes", "string"], False, [b"bytes", "string"]), + ([b"bytes", b"more_bytes", "string"], True, ["bytes", "more_bytes", "string"]), + ( + [b"bytes", b"more_bytes", "string"], + False, + [b"bytes", b"more_bytes", "string"], + ), + ([b"fixed", b"length", b"strings"], True, ["fixed", "length", "strings"]), + ([b"fixed", b"length", b"strings"], False, [b"fixed", b"length", b"strings"]), + # Test with nested lists + ([[b"nested1"], [b"nested2"]], True, [["nested1"], ["nested2"]]), + ([[b"nested1"], [b"nested2"]], False, [[b"nested1"], [b"nested2"]]), + # Test with bytes + (b"single", True, "single"), + (b"single", False, b"single"), + # Test with str + ("single", True, "single"), + ("single", False, "single"), + # Test with int + (123, True, 123), + (123, False, 123), + ], +) +def test_decode_or_not(string_obj, decode, expected): + # Handle normal cases + result = nexus.decode_or_not(elem=string_obj, decode=decode) + if isinstance(expected, list): + assert isinstance(result, list), f"Expected list, but got {type(result)}" + # Handle all other cases + else: + assert result == expected, f"Failed for {string_obj} with decode={decode}" diff --git a/dev_tools/utils/nxdl_utils.py b/dev_tools/utils/nxdl_utils.py index 50e033c681..c74b02e74e 100644 --- a/dev_tools/utils/nxdl_utils.py +++ b/dev_tools/utils/nxdl_utils.py @@ -1,6 +1,5 @@ # pylint: disable=too-many-lines -"""Parse NeXus definition files -""" +"""Parse NeXus definition files""" import os import re @@ -8,11 +7,51 @@ from functools import lru_cache from glob import glob from pathlib import Path +from typing import List +from typing import Optional import lxml.etree as ET from lxml.etree import ParseError as xmlER +def decode_or_not(elem, encoding: str = "utf-8", decode: bool = True): + """ + Decodes a byte array to a string if necessary. All other types are returned untouched. + If `decode` is False, the initial value is returned without decoding, including for byte arrays. + + Args: + elem: Any Python object that may need decoding. + encoding: The encoding scheme to use. Default is "utf-8". + decode: A boolean flag indicating whether to perform decoding. + + Returns: + A decoded string (in case of a byte string) or the initial value. + If `decode` is False, always returns the initial value. + + Raises: + UnicodeDecodeError: If a byte string cannot be decoded using the provided encoding. + """ + if not decode: + return elem + + # Handle lists of bytes or strings + elif isinstance(elem, list): + if not elem: + return elem # Return an empty list unchanged + + decoded_list = [decode_or_not(x, encoding, decode) for x in elem] + return decoded_list + + if isinstance(elem, bytes): + try: + return elem.decode(encoding) + except UnicodeDecodeError as e: + e.add_note(f"Error decoding bytes object: {elem}") + raise + + return elem + + def remove_namespace_from_tag(tag): """Helper function to remove the namespace from an XML tag.""" @@ -46,8 +85,8 @@ def get_app_defs_names(): Path(nexus_def_path) / "contributed_definitions" / "*.nxdl.xml" ) - files = sorted(glob(app_def_path_glob)) - for nexus_file in sorted(contrib_def_path_glob): + files = sorted(glob(str(app_def_path_glob))) + for nexus_file in sorted(glob(str(contrib_def_path_glob))): root = get_xml_root(nexus_file) if root.attrib["category"] == "application": files.append(nexus_file) @@ -93,7 +132,13 @@ def get_hdf_info_parent(hdf_info): """Get the hdf_info for the parent of an hdf_node in an hdf_info""" if "hdf_path" not in hdf_info: return {"hdf_node": hdf_info["hdf_node"].parent} - node = get_hdf_parent(hdf_info) + node = ( + get_hdf_root(hdf_info["hdf_node"]) + if "hdf_root" not in hdf_info + else hdf_info["hdf_root"] + ) + for child_name in hdf_info["hdf_path"].split("/")[1:-1]: + node = node[child_name] return {"hdf_node": node, "hdf_path": get_parent_path(hdf_info["hdf_path"])} @@ -104,33 +149,85 @@ def get_nx_class(nxdl_elem): return nxdl_elem.attrib.get("type", "NX_CHAR") -def get_nx_namefit(hdf_name, name, name_any=False): - """Checks if an HDF5 node name corresponds to a child of the NXDL element - uppercase letters in front can be replaced by arbitrary name, but - uppercase to lowercase match is preferred, - so such match is counted as a measure of the fit""" +def get_nx_namefit( + hdf_name: str, name: str, name_any: bool = False, name_partial: bool = False +) -> int: + """ + Checks if an HDF5 node name corresponds to a child of the NXDL element. + Groups of uppercase letters anywhere in the name are treated as freely + choosable parts of this name. + + If a match is found, this function returns twice the length of the + name for an exact match. If there is no exact match, the function + returns the number of matching characters (case insensitive). If + `name_any` is set to True, it returns zero instead of a count of + matches. All uppercase groups are considered independently, and + lowercase matches do not depend on uppercase group lengths. For example, + calling `get_nx_namefit("my_fancy_yet_long_name", "my_SOME_name")` + would return a score of 8 for the lowercase matches `my_..._name`. + + All characters in `[a-zA-Z0-9_.]` are considered for matching to an + uppercase letter. Any other character in the name will result in + a non-match and return -1. Periods at the beginning or end of the + `hdf_name` are not allowed; only exact matches will be considered. + + Examples: + + * `get_nx_namefit("test_name", "TEST_name")` returns 9 + * `get_nx_namefit("te_name", "TEST_name")` returns 7 + * `get_nx_namefit("my_other_name", "TEST_name")` returns 5 + * `get_nx_namefit("test_name", "test_name")` returns 18 + * `get_nx_namefit("test_other", "test_name")` returns -1 + * `get_nx_namefit("something", "XXXX")` returns 0 + * `get_nx_namefit("something", "OTHER")` returns 1 + + Args: + hdf_name (str): The hdf_name, containing the name of the HDF5 node. + name (str): The concept name to match against. + name_any (bool, optional): + Accept any name and return either 0 (match) or -1 (no match). + Defaults to False. + name_partial (bool, optional): + If set to True, the function will return the total length of the name + plus the number of matching characters, minus the count of uppercase + letters in the concept name. This allows for partial matches to + contribute to the score. Defaults to False. + + Returns: + int: -1 if no match is found or the number of matching + characters (case insensitive). + """ + path_regex = r"([a-zA-Z0-9_.]+)" + if name == hdf_name: return len(name) * 2 - # count leading capitals - counting = 0 - while counting < len(name) and name[counting].isupper(): - counting += 1 - if ( - name_any - or counting == len(name) - or (counting > 0 and hdf_name.endswith(name[counting:])) - ): # if potential fit - # count the matching chars - fit = 0 - for i in range(min(counting, len(hdf_name))): - if hdf_name[i].upper() == name[i]: - fit += 1 - else: - break - if fit == min(counting, len(hdf_name)): # accept only full fits as better fits - return fit - return 0 - return -1 # no fit + if hdf_name.startswith(".") or hdf_name.endswith("."): + # Don't match anything with a dot at the beginning or end + return -1 + + uppercase_parts = re.findall(r"[A-Z]+(?:_[A-Z]+)*", name) + + regex_name = name + uppercase_count = 0 + for up in uppercase_parts: + uppercase_count += len(up) + regex_name = regex_name.replace(up, path_regex) + + name_match = re.search(rf"^{regex_name}$", hdf_name) + if name_match is None: + return 0 if name_any else -1 + + match_count = 0 + for uppercase, match in zip(uppercase_parts, name_match.groups()): + for s1, s2 in zip(uppercase.upper(), match.upper()): + if s1 == s2: + match_count += 1 + + if name_partial: + return len(name) + match_count - uppercase_count + elif name_any: + return match_count + return -1 def get_nx_classes(): @@ -210,6 +307,28 @@ def get_node_name(node): return name +def is_name_type(child, name_type_value: str) -> bool: + """ + Determines if the child XML element's nameType attribute is equal to + the specified value or if the child is a group without nameType and name attributes. + + Args: + child: The XML element to check. + name_type_value (str): The nameType value to compare against ("any" or "partial"). + + """ + if child.attrib.get("nameType") == name_type_value: + return True + + if name_type_value == "any" and ( + get_local_name_from_xml(child) == "group" + and "nameType" not in child.attrib + and "name" not in child.attrib + ): + return True + return False + + def belongs_to(nxdl_elem, child, name, class_type=None, hdf_name=None): """Checks if an HDF5 node name corresponds to a child of the NXDL element uppercase letters in front can be replaced by arbitraty name, but @@ -222,24 +341,14 @@ def belongs_to(nxdl_elem, child, name, class_type=None, hdf_name=None): return True if not hdf_name: # search for name fits is only allowed for hdf_nodes return False - try: # check if nameType allows different name - name_any = bool(child.attrib["nameType"] == "any") - except KeyError: - name_any = False - params = [act_htmlname, chk_name, name_any, nxdl_elem, child, name] - return belongs_to_capital(params) + name_any = is_name_type(child, "any") - -def belongs_to_capital(params): - """Checking continues for Upper case""" - (act_htmlname, chk_name, name_any, nxdl_elem, child, name) = params # or starts with capital and no reserved words used - if ( - (name_any or (act_htmlname[0].isalpha() and act_htmlname[0].isupper())) - and name != "doc" - and name != "enumeration" - ): - fit = get_nx_namefit(chk_name, act_htmlname, name_any) # check if name fits + name_partial = is_name_type(child, "partial") + if (name_any or name_partial) and name != "doc" and name != "enumeration": + fit = get_nx_namefit( + chk_name, act_htmlname, name_any=name_any, name_partial=name_partial + ) # check if name fits if fit < 0: return False for child2 in nxdl_elem: @@ -251,10 +360,17 @@ def belongs_to_capital(params): ): continue # check if the name of another sibling fits better - name_any2 = child2.attrib.get("nameType") == "any" - fit2 = get_nx_namefit(chk_name, get_node_name(child2), name_any2) - if fit2 > fit: - return False + name_any2 = is_name_type(child2, "any") + name_partial2 = child2.attrib.get("nameType") == "partial" + if name_partial2 or name_any2: + fit2 = get_nx_namefit( + chk_name, + get_node_name(child2), + name_any=name_any2, + name_partial=name_partial2, + ) + if fit2 > fit: + return False # accept this fit return True return False @@ -301,18 +417,6 @@ def get_own_nxdl_child( name - nxdl name class_type - nxdl type or hdf classname (for groups, it is obligatory) hdf_name - hdf name""" - for child in nxdl_elem: - if not isinstance(child.tag, str): - continue - if child.attrib.get("name") == name: - return set_nxdlpath(child, nxdl_elem) - for child in nxdl_elem: - if not isinstance(child.tag, str): - continue - if child.attrib.get("name") == name: - child.set("nxdlbase", nxdl_elem.get("nxdlbase")) - return child - for child in nxdl_elem: if not isinstance(child.tag, str): continue @@ -402,7 +506,7 @@ def get_required_string(nxdl_elem): def write_doc_string(logger, doc, attr): """Simple function that prints a line in the logger if doc exists""" if doc: - logger.debug("@%s [NX_CHAR]", attr) + logger.debug(f"@{attr} [NX_CHAR]") return logger, doc, attr @@ -554,13 +658,15 @@ def get_doc(node, ntype, nxhtml, nxpath): doc_field = node.find("doc") if doc_field is not None: doc = doc_field.text - (index, enums) = get_enums(node) # enums - if index: + enums = get_enums(node) # enums + if enums is not None: enum_str = ( "\n " - + ("Possible values:" if enums.count(",") else "Obligatory value:") + + ("Possible values:" if len(enums) > 1 else "Obligatory value:") + "\n " - + enums + + "[" + + ",".join(enums) + + "]" + "\n" ) else: @@ -590,20 +696,26 @@ def get_namespace(element): return element.tag[element.tag.index("{") : element.tag.rindex("}") + 1] -def get_enums(node): - """Makes list of enumerations, if node contains any. - Returns comma separated STRING of enumeration values, if there are enum tag, - otherwise empty string.""" - # collect item values from enumeration tag, if any +def get_enums(node: ET._Element) -> Optional[List[str]]: + """ + Makes list of enumerations, if node contains any. + + Args: + node (ET._Element): The node to check for enumerations. + + Returns: + Optional[List[str]]: + Returns a list of the enumeration values if an enumeration was found. + If no enumeration was found it returns None. + """ namespace = get_namespace(node) enums = [] for enumeration in node.findall(f"{namespace}enumeration"): for item in enumeration.findall(f"{namespace}item"): enums.append(item.attrib["value"]) - enums = ",".join(enums) - if enums != "": - return (True, "[" + enums + "]") - return (False, "") # if there is no enumeration tag, returns empty string + if enums: + return enums + return None def add_base_classes(elist, nx_name=None, elem: ET.Element = None): @@ -731,7 +843,7 @@ def get_best_child(nxdl_elem, hdf_node, hdf_name, hdf_class_name, nexus_type): and nxdl_elem.attrib["name"] == "NXdata" and hdf_node is not None and hdf_node.parent is not None - and hdf_node.parent.attrs.get("NX_class") == "NXdata" + and decode_or_not(hdf_node.parent.attrs.get("NX_class")) == "NXdata" ): (fnd_child, fit) = get_best_nxdata_child(nxdl_elem, hdf_node, hdf_name) if fnd_child is not None: @@ -743,11 +855,15 @@ def get_best_child(nxdl_elem, hdf_node, hdf_name, hdf_class_name, nexus_type): if get_local_name_from_xml(child) == nexus_type and ( nexus_type != "group" or get_nx_class(child) == hdf_class_name ): - name_any = ( - "nameType" in nxdl_elem.attrib.keys() - and nxdl_elem.attrib["nameType"] == "any" - ) - fit = get_nx_namefit(hdf_name, get_node_name(child), name_any) + name_any = is_name_type(child, "any") + name_partial = is_name_type(child, "partial") + if name_partial or name_any: + fit = get_nx_namefit( + hdf_name, + get_node_name(child), + name_any=name_any, + name_partial=name_partial, + ) if fit > bestfit: bestfit = fit bestchild = set_nxdlpath(child, nxdl_elem) @@ -822,6 +938,9 @@ def get_node_at_nxdl_path( we are looking for or the root elem from a previously loaded NXDL file and finds the corresponding XML element with the needed attributes.""" try: + if nxdl_path.count("/") == 1 and not nxdl_path.upper().startswith("/ENTRY"): + elem = None + nx_name = "NXroot" (class_path, nxdlpath, elist) = get_inherited_nodes(nxdl_path, nx_name, elem) except ValueError as value_error: if exc: