Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validation fixes #350

Merged
merged 9 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"h5py>=3.6.0",
"xarray>=0.20.2",
"PyYAML>=6.0",
"numpy>=1.21.2",
"numpy>=1.21.2,<2.0.0",
"pandas>=1.3.2",
"ase>=3.19.0",
"mergedeep",
Expand Down
123 changes: 106 additions & 17 deletions src/pynxtools/dataconverter/nexus_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
It also allows for adding further nodes from the inheritance chain on the fly.
"""

from functools import reduce
from typing import Any, List, Literal, Optional, Set, Tuple, Union

import lxml.etree as ET
Expand All @@ -41,6 +42,7 @@
is_appdef,
remove_namespace_from_tag,
)
from pynxtools.definitions.dev_tools.utils.nxdl_utils import get_nx_namefit

NexusType = Literal[
"NX_BINARY",
Expand Down Expand Up @@ -139,15 +141,18 @@ class NexusNode(NodeMixin):
optionality: Literal["required", "recommended", "optional"] = "required"
variadic: bool = False
inheritance: List[ET._Element]
is_a: List["NexusNode"]
parent_of: List["NexusNode"]

def _set_optionality(self):
if not self.inheritance:
return
if self.inheritance[0].attrib.get("recommended"):
self.optionality = "recommended"
elif (
self.inheritance[0].attrib.get("optional")
or self.inheritance[0].attrib.get("minOccurs") == "0"
elif self.inheritance[0].attrib.get("required"):
self.optionality = "required"
elif self.inheritance[0].attrib.get("optional") or (
isinstance(self, NexusGroup) and self.occurrence_limits[0] == 0
):
self.optionality = "optional"

Expand All @@ -172,6 +177,8 @@ def __init__(
else:
self.inheritance = []
self.parent = parent
self.is_a = []
self.parent_of = []

def _construct_inheritance_chain_from_parent(self):
if self.parent is None:
Expand Down Expand Up @@ -221,18 +228,33 @@ def search_child_with_name(
direct_child = next((x for x in self.children if x.name == name), None)
if direct_child is not None:
return direct_child
if name in self.get_all_children_names():
if name in self.get_all_direct_children_names():
return self.add_inherited_node(name)
return None

def get_all_children_names(
self, depth: Optional[int] = None, only_appdef: bool = False
def get_all_direct_children_names(
self,
node_type: Optional[str] = None,
nx_class: Optional[str] = None,
depth: Optional[int] = None,
only_appdef: bool = False,
) -> Set[str]:
"""
Get all children names of the current node up to a certain depth.
Only `field`, `group` `choice` or `attribute` are considered as children.

Args:
node_type (Optional[str], optional):
The tags of the children to consider.
This should either be "field", "group", "choice" or "attribute".
If None all tags are considered.
Defaults to None.
nx_class (Optional[str], optional):
The NeXus class of the group to consider.
This is only used if `node_type` is "group".
It should contain the preceding `NX` and the class name in lowercase,
e.g., "NXentry".
Defaults to None.
depth (Optional[int], optional):
The inheritance depth up to which get children names.
`depth=1` will return only the children of the current node.
Expand All @@ -251,18 +273,24 @@ def get_all_children_names(
if depth is not None and (not isinstance(depth, int) or depth < 0):
raise ValueError("Depth must be a positive integer or None")

tag_type = ""
if node_type == "group" and nx_class is not None:
tag_type = f"[@type='{nx_class}']"

if node_type is not None:
search_tags = f"nx:{node_type}{tag_type}"
else:
search_tags = (
"*[self::nx:field or self::nx:group "
"or self::nx:attribute or self::nx:choice]"
)

names = set()
for elem in self.inheritance[:depth]:
if only_appdef and not is_appdef(elem):
break

for subelems in elem.xpath(
(
r"*[self::nx:field or self::nx:group "
r"or self::nx:attribute or self::nx:choice]"
),
namespaces=namespaces,
):
for subelems in elem.xpath(search_tags, namespaces=namespaces):
if "name" in subelems.attrib:
names.add(subelems.attrib["name"])
elif "type" in subelems.attrib:
Expand Down Expand Up @@ -360,6 +388,33 @@ def _build_inheritance_chain(self, xml_elem: ET._Element) -> List[ET._Element]:
else f"nx:group[@type='{xml_elem.attrib['type']}']",
namespaces=namespaces,
)
if not inherited_elem and name is not None:
# Try to namefit
groups = elem.findall(
f"nx:group[@type='{xml_elem.attrib['type']}']",
namespaces=namespaces,
)
best_group = None
best_score = -1
for group in groups:
if name in group.attrib and not contains_uppercase(
group.attrib["name"]
):
continue
group_name = (
group.attrib.get("name")
if "name" in group.attrib
else group.attrib["type"][2:].upper()
)

score = get_nx_namefit(name, group_name)
if get_nx_namefit(name, group_name) >= best_score:
best_group = group
best_score = score

if best_group is not None:
inherited_elem = [best_group]

if inherited_elem and inherited_elem[0] not in inheritance_chain:
inheritance_chain.append(inherited_elem[0])
bc_xml_root, _ = get_nxdl_root_and_path(xml_elem.attrib["type"])
Expand Down Expand Up @@ -432,13 +487,15 @@ def add_inherited_node(self, name: str) -> Optional["NexusNode"]:
"""
for elem in self.inheritance:
xml_elem = elem.xpath(
f"*[self::nx:field or self::nx:group or self::nx:attribute][@name='{name}']",
"*[self::nx:field or self::nx:group or"
f" self::nx:attribute or self::nx:choice][@name='{name}']",
namespaces=namespaces,
)
if not xml_elem:
# Find group by naming convention
xml_elem = elem.xpath(
f"*[self::nx:group][@type='NX{name.lower()}']",
"*[self::nx:group or self::nx:choice]"
f"[@type='NX{name.lower()}' and not(@name)]",
namespaces=namespaces,
)
if xml_elem:
Expand All @@ -462,7 +519,7 @@ class NexusChoice(NexusNode):
type: Literal["choice"] = "choice"

def __init__(self, **data) -> None:
super().__init__(**data)
super().__init__(type=self.type, **data)
self._construct_inheritance_chain_from_parent()
self._set_optionality()

Expand All @@ -489,6 +546,37 @@ class NexusGroup(NexusNode):
Optional[int],
] = (None, None)

def _check_sibling_namefit(self):
if not self.variadic:
return
domna marked this conversation as resolved.
Show resolved Hide resolved
for sibling in self.parent.get_all_direct_children_names(
node_type=self.type, nx_class=self.nx_class
):
if sibling == self.name or not contains_uppercase(sibling):
continue
if get_nx_namefit(sibling, self.name) >= -1:
fit = self.parent.search_child_with_name(sibling)
if (
self.inheritance[0] != fit.inheritance[0]
and fit.inheritance[0] in self.inheritance
):
fit.is_a.append(self)
self.parent_of.append(fit)

min_occurs = (
0 if self.occurrence_limits[0] is None else self.occurrence_limits[0]
)
min_occurs = 1 if self.optionality == "required" else min_occurs

required_children = reduce(
lambda x, y: x + (1 if y.optionality == "required" else 0),
self.parent_of,
0,
)

if required_children >= min_occurs:
self.optionality = "optional"

def _set_occurence_limits(self):
if not self.inheritance:
return
Expand All @@ -511,6 +599,7 @@ def __init__(self, nx_class: str, **data) -> None:
self.nx_class = nx_class
self._set_occurence_limits()
self._set_optionality()
self._check_sibling_namefit()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -638,7 +727,7 @@ def populate_tree_from_parents(node: NexusNode):
node (NexusNode):
The current node from which to populate the tree.
"""
for child in node.get_all_children_names(only_appdef=True):
for child in node.get_all_direct_children_names(only_appdef=True):
child_node = node.search_child_with_name(child)
populate_tree_from_parents(child_node)

Expand Down
37 changes: 22 additions & 15 deletions src/pynxtools/dataconverter/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def get_variations_of(node: NexusNode, keys: Mapping[str, Any]) -> List[str]:
continue
if (
get_nx_namefit(name2fit, node.name) >= 0
and key not in node.parent.get_all_children_names()
and key not in node.parent.get_all_direct_children_names()
):
variations.append(key)
if nx_name is not None and not variations:
Expand Down Expand Up @@ -239,7 +239,7 @@ def check_nxdata():
data_node = node.search_child_with_name((signal, "DATA"))
data_bc_node = node.search_child_with_name("DATA")
data_node.inheritance.append(data_bc_node.inheritance[0])
for child in data_node.get_all_children_names():
for child in data_node.get_all_direct_children_names():
data_node.search_child_with_name(child)

handle_field(
Expand Down Expand Up @@ -271,7 +271,7 @@ def check_nxdata():
axis_node = node.search_child_with_name((axis, "AXISNAME"))
axis_bc_node = node.search_child_with_name("AXISNAME")
axis_node.inheritance.append(axis_bc_node.inheritance[0])
for child in axis_node.get_all_children_names():
for child in axis_node.get_all_direct_children_names():
axis_node.search_child_with_name(child)

handle_field(
Expand Down Expand Up @@ -328,13 +328,24 @@ def check_nxdata():

def handle_group(node: NexusGroup, keys: Mapping[str, Any], prev_path: str):
variants = get_variations_of(node, keys)
domna marked this conversation as resolved.
Show resolved Hide resolved
if not variants:
if node.optionality == "required" and node.type in missing_type_err:
collector.collect_and_log(
f"{prev_path}/{node.name}", missing_type_err.get(node.type), None
)
if node.parent_of:
for child in node.parent_of:
variants += get_variations_of(child, keys)
if (
not variants
and node.optionality == "required"
and node.type in missing_type_err
):
collector.collect_and_log(
f"{prev_path}/{node.name}",
missing_type_err.get(node.type),
None,
)
return
for variant in variants:
if variant in [node.name for node in node.parent_of]:
# Don't process if this is actually a sub-variant of this group
continue
nx_class, _ = split_class_and_name_of(variant)
if not isinstance(keys[variant], Mapping):
if nx_class is not None:
Expand Down Expand Up @@ -499,16 +510,12 @@ def is_documented(key: str, node: NexusNode) -> bool:
return True

for name in key[1:].replace("@", "").split("/"):
children = node.get_all_children_names()
children = node.get_all_direct_children_names()
best_name = best_namefit_of(name, children)
if best_name is None:
return False

resolver = Resolver("name", relax=True)
child_node = resolver.get(node, best_name)
node = (
node.add_inherited_node(best_name) if child_node is None else child_node
)
node = node.search_child_with_name(best_name)

if isinstance(mapping[key], dict) and "link" in mapping[key]:
# TODO: Follow link and check consistency with current field
Expand Down Expand Up @@ -612,7 +619,7 @@ def populate_full_tree(node: NexusNode, max_depth: Optional[int] = 5, depth: int
# but it does while recursing the tree and it should
# be fixed.
return
for child in node.get_all_children_names():
for child in node.get_all_direct_children_names():
child_node = node.search_child_with_name(child)
populate_full_tree(child_node, max_depth=max_depth, depth=depth + 1)

Expand Down
3 changes: 2 additions & 1 deletion tests/dataconverter/test_nexus_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def test_correct_extension_of_tree():
def get_node_fields(tree: NexusNode) -> List[Tuple[str, Any]]:
return list(
filter(
lambda x: not x[0].startswith("_") and x[0] not in "inheritance",
lambda x: not x[0].startswith("_")
and x[0] not in ("inheritance", "is_a", "parent_of"),
tree.__dict__.items(),
)
)
Expand Down
Loading
Loading