Skip to content

Commit

Permalink
Improved get_children_names function
Browse files Browse the repository at this point in the history
  • Loading branch information
domna committed May 21, 2024
1 parent 4fff6a2 commit c040e52
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 29 deletions.
40 changes: 31 additions & 9 deletions pynxtools/dataconverter/nexus_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,32 @@ def search_child_with_name(
direct_child = next((x for x in self.children if x.name == name), None)
if direct_child is not None:
return direct_child
if name in self.get_all_children_names():
if name in self.get_all_direct_children_names():
return self.add_inherited_node(name)
return None

def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
def get_all_direct_children_names(
self,
node_type: Optional[str] = None,
nx_class: Optional[str] = None,
depth: Optional[int] = None,
) -> Set[str]:
"""
Get all children names of the current node up to a certain depth.
Only `field`, `group`, `choice` or `attribute` are considered as children.
Args:
node_type (Optional[str], optional):
The tags of the children to consider.
This should either be "field", "group", "choice" or "attribute".
If None all tags are considered.
Defaults to None.
nx_class (Optional[str], optional):
The NeXus class of the group to consider.
This is only used if `node_type` is "group".
It should contain the preceding `NX` and the class name in lowercase,
e.g., "NXentry".
Defaults to None.
depth (Optional[int], optional):
The inheritance depth up to which to get children names.
`depth=1` will return only the children of the current node.
Expand All @@ -229,15 +245,21 @@ def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
if depth is not None and (not isinstance(depth, int) or depth < 0):
raise ValueError("Depth must be a positive integer or None")

tag_type = ""
if node_type == "group" and nx_class is not None:
tag_type = f"[@type='{nx_class}']"

if node_type is not None:
search_tags = f"*[self::nx:{node_type}{tag_type}]"
else:
search_tags = (
r"*[self::nx:field or self::nx:group "
r"or self::nx:attribute or self::nx:choice]"
)

names = set()
for elem in self.inheritance[:depth]:
for subelems in elem.xpath(
(
r"*[self::nx:field or self::nx:group "
r"or self::nx:attribute or self::nx:choice]"
),
namespaces=namespaces,
):
for subelems in elem.xpath(search_tags, namespaces=namespaces):
if "name" in subelems.attrib:
names.add(subelems.attrib["name"])
elif "type" in subelems.attrib:
Expand Down
53 changes: 33 additions & 20 deletions pynxtools/dataconverter/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,20 @@
from pynxtools.definitions.dev_tools.utils.nxdl_utils import get_nx_namefit


def best_namefit_of_(
name: str, concepts: Set[str], nx_class: Optional[str] = None
) -> str:
# TODO: Find the best namefit of name in concepts
# Consider nx_class if it is not None
...
def best_namefit_of_(name: str, concepts: Set[str]) -> Optional[str]:
    """
    Find the concept name that best fits `name`.

    An exact match in `concepts` is returned directly. Otherwise every
    concept is scored with `get_nx_namefit` and the highest-scoring
    concept is returned.

    Args:
        name (str): The name to find a fitting concept for.
        concepts (Set[str]): The candidate concept names.

    Returns:
        Optional[str]:
            The best-fitting concept name, or None if `concepts` is empty
            or the highest score is negative.
    """
    if not concepts:
        return None

    # An exact match needs no scoring.
    if name in concepts:
        return name

    best_match, score = max(
        ((concept, get_nx_namefit(name, concept)) for concept in concepts),
        key=lambda scored: scored[1],
    )
    # A negative best score is treated as "no concept fits".
    if score < 0:
        return None

    return best_match


def validate_hdf_group_against(appdef: str, data: h5py.Group):
Expand All @@ -64,9 +72,11 @@ def validate_hdf_group_against(appdef: str, data: h5py.Group):
# Allow for 10000 cache entries. This should be enough for most cases
@cached(
cache=LRUCache(maxsize=10000),
key=lambda path, _: hashkey(path),
key=lambda path, *_: hashkey(path),
)
def find_node_for(path: str, nx_class: Optional[str] = None) -> Optional[NexusNode]:
def find_node_for(
path: str, node_type: Optional[str] = None, nx_class: Optional[str] = None
) -> Optional[NexusNode]:
if path == "":
return tree

Expand All @@ -75,10 +85,7 @@ def find_node_for(path: str, nx_class: Optional[str] = None) -> Optional[NexusNo

best_child = best_namefit_of_(
last_elem,
# TODO: Consider renaming `get_all_children_names` to
# `get_all_direct_children_names`. Because that's what it is.
node.get_all_children_names(),
nx_class,
node.get_all_direct_children_names(nx_class=nx_class, node_type=node_type),
)
if best_child is None:
return None
Expand All @@ -92,15 +99,19 @@ def remove_from_req_fields(path: str):
def handle_group(path: str, data: h5py.Group):
    """
    Check that the HDF5 group at `path` has a matching node in the tree.

    If no node is found, a `ValidationProblem.MissingDocumentation` problem
    is reported for `path` through the collector.

    Args:
        path (str): The path of the group.
        data (h5py.Group):
            The group itself. Its `NX_class` attribute (if present) is used
            to restrict the lookup to groups of that NeXus class.
    """
    # `find_node_for` takes (path, node_type, nx_class). Passing the NX_class
    # as the second positional argument would bind it to `node_type`.
    # The arguments are passed positionally because the cache key function
    # of `find_node_for` only accepts positional arguments.
    node = find_node_for(path, "group", data.attrs.get("NX_class"))
    if node is None:
        collector.collect_and_log(
            path, ValidationProblem.MissingDocumentation, None
        )
        return

    # TODO: Do actual group checks

def handle_field(path: str, data: h5py.Dataset):
    """
    Check that the HDF5 dataset at `path` has a matching node in the tree.

    A documented field is removed from the required fields. An undocumented
    one is reported as `ValidationProblem.MissingDocumentation`.

    Args:
        path (str): The path of the field.
        data (h5py.Dataset): The dataset itself (currently not inspected).
    """
    if find_node_for(path) is not None:
        remove_from_req_fields(f"{path}")
        return

    collector.collect_and_log(path, ValidationProblem.MissingDocumentation, None)

Expand All @@ -110,7 +121,9 @@ def handle_attributes(path: str, attribute_names: h5py.AttributeManager):
for attr_name in attribute_names:
node = find_node_for(f"{path}/{attr_name}")
if node is None:
# TODO: Log undocumented
collector.collect_and_log(
path, ValidationProblem.MissingDocumentation, None
)
continue
remove_from_req_fields(f"{path}/@{attr_name}")

Expand Down Expand Up @@ -282,7 +295,7 @@ def get_variations_of(node: NexusNode, keys: Mapping[str, Any]) -> List[str]:
continue
if (
get_nx_namefit(name2fit, node.name) >= 0
and key not in node.parent.get_all_children_names()
and key not in node.parent.get_all_direct_children_names()
):
variations.append(key)
if nx_name is not None and not variations:
Expand Down Expand Up @@ -315,7 +328,7 @@ def check_nxdata():
data_node = node.search_child_with_name((signal, "DATA"))
data_bc_node = node.search_child_with_name("DATA")
data_node.inheritance.append(data_bc_node.inheritance[0])
for child in data_node.get_all_children_names():
for child in data_node.get_all_direct_children_names():
data_node.search_child_with_name(child)

handle_field(
Expand Down Expand Up @@ -347,7 +360,7 @@ def check_nxdata():
axis_node = node.search_child_with_name((axis, "AXISNAME"))
axis_bc_node = node.search_child_with_name("AXISNAME")
axis_node.inheritance.append(axis_bc_node.inheritance[0])
for child in axis_node.get_all_children_names():
for child in axis_node.get_all_direct_children_names():
axis_node.search_child_with_name(child)

handle_field(
Expand Down Expand Up @@ -575,7 +588,7 @@ def is_documented(key: str, node: NexusNode) -> bool:
return True

for name in key[1:].replace("@", "").split("/"):
children = node.get_all_children_names()
children = node.get_all_direct_children_names()
best_name = best_namefit_of(name, children)
if best_name is None:
return False
Expand Down Expand Up @@ -688,7 +701,7 @@ def populate_full_tree(node: NexusNode, max_depth: Optional[int] = 5, depth: int
# but it does while recursing the tree and it should
# be fixed.
return
for child in node.get_all_children_names():
for child in node.get_all_direct_children_names():
print(child)
child_node = node.search_child_with_name(child)
populate_full_tree(child_node, max_depth=max_depth, depth=depth + 1)
Expand Down

0 comments on commit c040e52

Please sign in to comment.