forked from minhnh/rdf-utils
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add model parsing and tests for distributions
- support secorolab/metamodels#13 - add DistributionModel for parsing univariate and multivariate versions of the uniform and normal distributions, as well as the uniform rotation (3D) distribution - add method to sample using info from DistributionModel - add SampledQuantityModel that caches the sample when requested - add unit test on a valid model that checks sampling of all supported distributions - minor: use check_shacl_constraints func in test_python_model - minor: remove black precommit hook since it conflicts with ruff
- Loading branch information
Showing
8 changed files
with
389 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
# SPDX-License-Identifier: MPL-2.0 | ||
from typing import Any, Optional | ||
import numpy as np | ||
from rdflib import BNode, Literal, URIRef, Graph | ||
from rdf_utils.collection import load_list_re | ||
from rdf_utils.models.common import ModelBase | ||
from rdf_utils.namespace import NS_MM_DISTRIB | ||
|
||
|
||
# URIs of the distribution metamodel terms, all resolved in the NS_MM_DISTRIB namespace.
# Names follow the pattern URI_DISTRIB_TYPE_* for node types and URI_DISTRIB_PRED_* for
# predicates.

# Generic distribution type and the predicate holding its (positive integer) dimension
URI_DISTRIB_TYPE_DISTRIB = NS_MM_DISTRIB["Distribution"]
URI_DISTRIB_PRED_DIM = NS_MM_DISTRIB["dimension"]

# Continuous / discrete type terms (not referenced by the code in this file)
URI_DISTRIB_TYPE_CONT = NS_MM_DISTRIB["Continuous"]
URI_DISTRIB_TYPE_DISCR = NS_MM_DISTRIB["Discrete"]

# Uniform distribution and its bound predicates (a single literal or a list per bound)
URI_DISTRIB_TYPE_UNIFORM = NS_MM_DISTRIB["Uniform"]
URI_DISTRIB_PRED_UPPER = NS_MM_DISTRIB["upper-bound"]
URI_DISTRIB_PRED_LOWER = NS_MM_DISTRIB["lower-bound"]

# Normal distribution: mean, plus standard deviation (dimension 1) or covariance (otherwise)
URI_DISTRIB_TYPE_NORMAL = NS_MM_DISTRIB["Normal"]
URI_DISTRIB_PRED_MEAN = NS_MM_DISTRIB["mean"]
URI_DISTRIB_PRED_STD = NS_MM_DISTRIB["standard-deviation"]
URI_DISTRIB_PRED_COV = NS_MM_DISTRIB["covariance"]

# Uniform distribution over rotations; carries no parameters in this file
URI_DISTRIB_TYPE_UNIFORM_ROT = NS_MM_DISTRIB["UniformRotation"]

# Quantity whose value is sampled from the distribution it links to
URI_DISTRIB_TYPE_SAMPLED_QUANTITY = NS_MM_DISTRIB["SampledQuantity"]
URI_DISTRIB_PRED_FROM_DISTRIB = NS_MM_DISTRIB["from-distribution"]
|
||
|
||
def _get_float_from_literal(literal: Literal) -> float:
    """Convert the Python value of an RDF literal to a float.

    Args:
        literal: Literal whose ``toPython()`` value is passed to ``float()``.

    Returns:
        The literal's value as a ``float``.

    Raises:
        ValueError: If the value cannot be converted. The underlying error is
            chained as ``__cause__``.
    """
    try:
        lit_val = literal.toPython()
        return float(lit_val)
    except (TypeError, ValueError) as e:
        # float() raises TypeError rather than ValueError for non-numeric objects
        # (e.g. None or a datetime), so both are reported with the same context.
        raise ValueError(f"can't convert literal '{literal}' as float: {e}") from e
|
||
|
||
class DistributionModel(ModelBase):
    """Model of a distribution node whose parameters are parsed from an RDF graph.

    The node's types are checked in this order: uniform rotation, uniform, normal.
    Uniform and normal distributions have their parameters validated and stored
    with ``set_attr`` under the corresponding predicate URI; a uniform rotation
    has no parameters to load.

    Attributes:
        distrib_type: URI of the distribution type that was matched.
    """

    distrib_type: URIRef

    def __init__(self, distrib_id: URIRef, graph: Graph) -> None:
        """Load the distribution identified by ``distrib_id`` from ``graph``.

        Args:
            distrib_id: URI of the distribution node.
            graph: Graph containing the node and its properties.

        Raises:
            RuntimeError: If the node has none of the handled distribution types,
                or a bound/mean property is neither a literal nor a list node.
            AssertionError: If a property is missing or inconsistent with the
                declared dimension.
            ValueError: If a numeric property cannot be converted to float.
        """
        super().__init__(node_id=distrib_id, graph=graph)

        if URI_DISTRIB_TYPE_UNIFORM_ROT in self.types:
            self.distrib_type = URI_DISTRIB_TYPE_UNIFORM_ROT
        elif URI_DISTRIB_TYPE_UNIFORM in self.types:
            self.distrib_type = URI_DISTRIB_TYPE_UNIFORM
            self._load_uniform_distrib_attrs(graph=graph)
        elif URI_DISTRIB_TYPE_NORMAL in self.types:
            self.distrib_type = URI_DISTRIB_TYPE_NORMAL
            self._load_normal_distrib_attrs(graph=graph)
        else:
            raise RuntimeError(f"Distrib '{self.id}' has unhandled types: {self.types}")

    def _load_dimension(self, graph: Graph, distrib_name: str) -> int:
        """Read and validate the 'dimension' property of this distribution.

        Args:
            graph: Graph to query.
            distrib_name: Distribution kind used as prefix of the error messages
                (e.g. ``"Uniform"``).

        Returns:
            The dimension as a positive integer.

        Raises:
            AssertionError: If the property is not a literal holding a positive integer.
        """
        dim_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_DIM)
        assert isinstance(
            dim_node, Literal
        ), f"{distrib_name} distrib '{self.id}' does not have a Literal 'dimension': {dim_node}"
        dim = dim_node.toPython()
        # bool is a subclass of int, so a boolean literal must be rejected explicitly
        assert (
            isinstance(dim, int) and not isinstance(dim, bool) and dim > 0
        ), f"{distrib_name} distrib '{self.id}' does not have a positive integer 'dimension': {dim}"
        return dim

    def _load_uniform_bounds(self, graph: Graph, predicate: URIRef, label: str) -> list:
        """Read one bound property of a uniform distribution as a list.

        Args:
            graph: Graph to query.
            predicate: Predicate of the bound to read.
            label: Property name as it appears in the error message.

        Returns:
            A one-element list for a literal bound, otherwise whatever
            ``load_list_re`` returns for the list node.

        Raises:
            RuntimeError: If the property is neither a ``Literal`` nor a ``BNode``
                (including when it is missing).
        """
        bound_node = graph.value(subject=self.id, predicate=predicate)
        if isinstance(bound_node, Literal):
            return [_get_float_from_literal(bound_node)]
        if isinstance(bound_node, BNode):
            return load_list_re(graph=graph, first_node=bound_node, parse_uri=False, quiet=False)
        raise RuntimeError(
            f"Uniform distrib '{self.id}' has invalid type for {label}: {type(bound_node)}"
        )

    def _load_uniform_distrib_attrs(self, graph: Graph) -> None:
        """Parse, validate and store dimension, upper and lower bounds of a uniform distribution.

        Raises:
            AssertionError: If the number of bounds does not match the dimension, or
                some lower bound is not strictly less than its upper bound.
            RuntimeError: If a bound property has an invalid node type.
        """
        dim = self._load_dimension(graph=graph, distrib_name="Uniform")

        # labels are kept verbatim so that error messages stay unchanged
        upper_bounds = self._load_uniform_bounds(
            graph=graph, predicate=URI_DISTRIB_PRED_UPPER, label=":upper-bound"
        )
        lower_bounds = self._load_uniform_bounds(
            graph=graph, predicate=URI_DISTRIB_PRED_LOWER, label="lower-bound"
        )

        # check property dimensions
        assert (
            dim == len(lower_bounds) and dim == len(upper_bounds)
        ), f"Uniform distrib '{self.id}' has mismatching property dimensions: dim={dim}, upper bounds num={len(upper_bounds)}, lower bounds num={len(lower_bounds)}"

        # check lower bounds less than higher bounds (element-wise, strict)
        assert np.all(
            np.less(lower_bounds, upper_bounds)
        ), f"Uniform distrib '{self.id}': not all lower bounds less than upper bounds: lower={lower_bounds}, upper={upper_bounds}"

        # set attributes only once everything has been validated
        self.set_attr(key=URI_DISTRIB_PRED_DIM, val=dim)
        self.set_attr(key=URI_DISTRIB_PRED_UPPER, val=upper_bounds)
        self.set_attr(key=URI_DISTRIB_PRED_LOWER, val=lower_bounds)

    def _load_normal_distrib_attrs(self, graph: Graph) -> None:
        """Parse, validate and store the parameters of a normal distribution.

        The mean is always stored as a list. For dimension 1 the standard deviation
        is stored as a float; otherwise the covariance is stored as a
        ``(dim, dim)`` float numpy array.

        Raises:
            AssertionError: If a property is missing or does not match the dimension.
            RuntimeError: If 'mean' is neither a literal nor a list node.
            ValueError: If a value cannot be converted to float.
        """
        dim = self._load_dimension(graph=graph, distrib_name="Normal")
        self.set_attr(key=URI_DISTRIB_PRED_DIM, val=dim)

        # get mean
        mean_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_MEAN)
        if isinstance(mean_node, Literal):
            assert (
                dim == 1
            ), f"Normal distrib '{self.id}' has single mean '{mean_node}' but dimension '{dim}'"
            mean_val = _get_float_from_literal(mean_node)
            self.set_attr(key=URI_DISTRIB_PRED_MEAN, val=[mean_val])
        elif isinstance(mean_node, BNode):
            mean_vals = load_list_re(
                graph=graph, first_node=mean_node, parse_uri=False, quiet=False
            )
            assert (
                len(mean_vals) == dim
            ), f"Normal distrib '{self.id}': number of mean values ({len(mean_vals)}) does not match dimension ({dim})"
            self.set_attr(key=URI_DISTRIB_PRED_MEAN, val=mean_vals)
        else:
            raise RuntimeError(
                f"Normal distrib '{self.id}' has invalid type for 'mean': {type(mean_node)}"
            )

        # get standard deviation or covariance based on dimension
        if dim == 1:
            std_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_STD)
            assert isinstance(
                std_node, Literal
            ), f"Normal distrib '{self.id}' does not have a Literal 'standard-deviation': {std_node}"
            std = _get_float_from_literal(std_node)
            self.set_attr(key=URI_DISTRIB_PRED_STD, val=std)
            return

        cov_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_COV)
        assert isinstance(
            cov_node, BNode
        ), f"Normal distrib '{self.id}': 'covariance' property not a container, type={type(cov_node)}"
        cov_vals = load_list_re(graph=graph, first_node=cov_node, parse_uri=False, quiet=False)
        try:
            cov_mat = np.array(cov_vals, dtype=float)
        except (TypeError, ValueError) as e:
            # ragged rows raise ValueError, non-numeric entries may raise TypeError
            raise ValueError(
                f"Normal distrib '{self.id}', can't convert covariance to float numpy array: {e}\n{cov_vals}"
            ) from e
        assert cov_mat.shape == (
            dim,
            dim,
        ), f"Normal distrib '{self.id}': dimension='{dim}' doesn't match 'covariance' shape'{cov_mat.shape}'"
        self.set_attr(key=URI_DISTRIB_PRED_COV, val=cov_mat)
|
||
|
||
def sample_from_distrib(
    distrib: DistributionModel, size: Optional[int | tuple[int, ...]] = None
) -> Any:
    """Draw random sample(s) from the distribution described by ``distrib``.

    The distribution type is checked in the order uniform rotation, uniform, normal.

    Args:
        distrib: Distribution model providing ``types`` and the parsed attributes.
        size: Number or shape of samples; ``None`` draws a single sample.

            - Uniform rotation: a shape is collapsed to its total count, because
              ``Rotation.random`` is called with a number of rotations.
            - Uniform: forwarded unchanged to ``np.random.uniform`` as the output
              shape, so it must broadcast with the bounds (e.g. ``(n, dim)``).
            - Normal: forwarded unchanged to ``np.random.normal`` (dimension 1) or
              ``np.random.multivariate_normal`` (otherwise).

    Returns:
        A scipy ``Rotation`` for uniform rotations, otherwise the value returned by
        the numpy sampling function.

    Raises:
        RuntimeError: If scipy is needed but not installed, or the distribution has
            none of the handled types.
        AssertionError: If the attributes stored on ``distrib`` are not valid.
    """
    if URI_DISTRIB_TYPE_UNIFORM_ROT in distrib.types:
        try:
            from scipy.spatial.transform import Rotation
        except ImportError as e:
            raise RuntimeError("to sample random rotations, 'scipy' must be installed") from e

        # honor `size` instead of silently returning a single rotation
        num_rotations = None if size is None else int(np.prod(size))
        return Rotation.random(num=num_rotations)

    if URI_DISTRIB_TYPE_UNIFORM in distrib.types:
        lower_bounds = distrib.get_attr(key=URI_DISTRIB_PRED_LOWER)
        upper_bounds = distrib.get_attr(key=URI_DISTRIB_PRED_UPPER)
        assert isinstance(lower_bounds, list) and isinstance(
            upper_bounds, list
        ), f"Uniform distrib '{distrib.id}' does not have valid lower & upper bounds"
        # NOTE(review): unlike the multivariate normal case, `size` here is the full
        # output shape and has to be broadcast-compatible with the bounds.
        return np.random.uniform(lower_bounds, upper_bounds, size=size)

    if URI_DISTRIB_TYPE_NORMAL in distrib.types:
        dim = distrib.get_attr(key=URI_DISTRIB_PRED_DIM)
        assert (
            isinstance(dim, int) and dim > 0
        ), f"Normal distrib '{distrib.id}' does not have valid dimension: {dim}"

        mean = distrib.get_attr(key=URI_DISTRIB_PRED_MEAN)
        assert (
            isinstance(mean, list) and len(mean) == dim
        ), f"Normal distrib '{distrib.id}' does not have valid mean: {mean}"

        if dim == 1:
            std = distrib.get_attr(key=URI_DISTRIB_PRED_STD)
            assert isinstance(
                std, float
            ), f"Normal distrib '{distrib.id}' does not have valid standard deviation: {std}"
            return np.random.normal(loc=mean[0], scale=std, size=size)

        # multivariate normal
        cov = distrib.get_attr(key=URI_DISTRIB_PRED_COV)
        assert isinstance(
            cov, np.ndarray
        ), f"Normal distrib '{distrib.id}' does not have valid covariance: {cov}"
        return np.random.multivariate_normal(mean=mean, cov=cov, size=size)

    raise RuntimeError(f"Distrib '{distrib.id}' has unhandled types: {distrib.types}")
|
||
|
||
class SampledQuantityModel(ModelBase):
    """Quantity whose value is drawn from a linked distribution.

    The most recently drawn value is remembered so it can be returned again
    without drawing a new one.

    Attributes:
        distribution: Model of the distribution this quantity links to.
    """

    distribution: DistributionModel
    # last value returned by sample(); None until the first draw
    _sampled_value: Optional[Any]

    def __init__(self, quantity_id: URIRef, graph: Graph) -> None:
        """Load the quantity node and the distribution it points to.

        Args:
            quantity_id: URI of the sampled quantity node.
            graph: Graph containing the quantity and its distribution.

        Raises:
            AssertionError: If the quantity does not link to a URI node via the
                'from-distribution' predicate.
        """
        super().__init__(node_id=quantity_id, graph=graph)
        self._sampled_value = None

        linked_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_FROM_DISTRIB)
        assert isinstance(
            linked_node, URIRef
        ), f"SampledQuantity '{self.id}' does not link to a distribution node: {linked_node}"
        self.distribution = DistributionModel(distrib_id=linked_node, graph=graph)

    def sample(self, resample: bool = True) -> Any:
        """Return a sample of this quantity.

        Args:
            resample: If true (default) a new value is always drawn. If false, the
                previously drawn value is returned when there is one.

        Returns:
            The newly drawn or remembered sample.
        """
        value = self._sampled_value
        if resample or value is None:
            value = sample_from_distrib(distrib=self.distribution)
            self._sampled_value = value
        return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.