Skip to content

Commit

Permalink
add model parsing and tests for distributions
Browse files Browse the repository at this point in the history
- support secorolab/metamodels#13
- add DistributionModel for parsing univariate and multivariate
  versions of the uniform and normal distributions, as well as uniform
  rotation (3D) distribution
- add method to sample using info from DistributionModel
- add SampledQuantityModel that caches the sample when requested
- add unit test on a valid model that checks sampling of all supported
  distributions
- minor: use check_shacl_constraints func in test_python_model
- minor: remove black precommit hook since conflicting with ruff
  • Loading branch information
minhnh committed Nov 19, 2024
1 parent bf317c8 commit 89e3ee9
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 19 deletions.
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
repos:
- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.6
hooks:
Expand Down
2 changes: 1 addition & 1 deletion src/rdf_utils/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, violation_str: str):

def check_shacl_constraints(graph: Graph, shacl_dict: Dict[str, str], quiet=False) -> bool:
"""
:param graph: rdfl.Graph to be checked
:param graph: rdflib.Graph to be checked
:param shacl_dict: mapping from SHACL path to graph format, e.g. URL -> "turtle"
:param quiet: if true will not throw an exception
"""
Expand Down
12 changes: 9 additions & 3 deletions src/rdf_utils/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,21 @@ class ModelBase(object):
types: set[URIRef]
_attributes: Dict[URIRef, Any]

def __init__(self, node_id: URIRef, graph: Optional[Graph] = None, types: Optional[set[URIRef]] = None) -> None:
def __init__(
self, node_id: URIRef, graph: Optional[Graph] = None, types: Optional[set[URIRef]] = None
) -> None:
self.id = node_id
if graph is not None:
self.types = get_node_types(graph=graph, node_id=node_id)
assert types is None, f"ModelBase.__init__: node '{node_id}': both 'graph' and 'types' args are not None"
assert (
types is None
), f"ModelBase.__init__: node '{node_id}': both 'graph' and 'types' args are not None"
elif types is not None:
self.types = types
else:
raise RuntimeError(f"ModelBase.__init__: node '{node_id}': neither 'graph' or 'types' specified")
raise RuntimeError(
f"ModelBase.__init__: node '{node_id}': neither 'graph' or 'types' specified"
)
assert len(self.types) > 0, f"node '{self.id}' has no type"

self._attributes = {}
Expand Down
244 changes: 244 additions & 0 deletions src/rdf_utils/models/distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: MPL-2.0
from typing import Any, Optional
import numpy as np
from rdflib import BNode, Literal, URIRef, Graph
from rdf_utils.collection import load_list_re
from rdf_utils.models.common import ModelBase
from rdf_utils.namespace import NS_MM_DISTRIB


# Terms looked up in the NS_MM_DISTRIB namespace. In this module the *_TYPE_*
# constants are matched against a node's types and the *_PRED_* constants are
# used as predicates when querying the graph.

# Generic distribution type and the predicate holding its dimension
URI_DISTRIB_TYPE_DISTRIB = NS_MM_DISTRIB["Distribution"]
URI_DISTRIB_PRED_DIM = NS_MM_DISTRIB["dimension"]

# Continuous / discrete type terms
URI_DISTRIB_TYPE_CONT = NS_MM_DISTRIB["Continuous"]
URI_DISTRIB_TYPE_DISCR = NS_MM_DISTRIB["Discrete"]

# Uniform distribution: type and its upper/lower bound predicates
URI_DISTRIB_TYPE_UNIFORM = NS_MM_DISTRIB["Uniform"]
URI_DISTRIB_PRED_UPPER = NS_MM_DISTRIB["upper-bound"]
URI_DISTRIB_PRED_LOWER = NS_MM_DISTRIB["lower-bound"]

# Normal distribution: type and its mean / standard deviation / covariance predicates
URI_DISTRIB_TYPE_NORMAL = NS_MM_DISTRIB["Normal"]
URI_DISTRIB_PRED_MEAN = NS_MM_DISTRIB["mean"]
URI_DISTRIB_PRED_STD = NS_MM_DISTRIB["standard-deviation"]
URI_DISTRIB_PRED_COV = NS_MM_DISTRIB["covariance"]

# Uniform rotation distribution type (takes no additional predicates in this module)
URI_DISTRIB_TYPE_UNIFORM_ROT = NS_MM_DISTRIB["UniformRotation"]

# Sampled quantity: type and the predicate linking it to its distribution node
URI_DISTRIB_TYPE_SAMPLED_QUANTITY = NS_MM_DISTRIB["SampledQuantity"]
URI_DISTRIB_PRED_FROM_DISTRIB = NS_MM_DISTRIB["from-distribution"]


def _get_float_from_literal(literal: Literal) -> float:
    """Convert an RDF literal to a Python float.

    :param literal: literal whose ``toPython()`` value is passed to ``float()``
    :return: the literal's value as a float
    :raises ValueError: if the literal's value cannot be converted to a float
    """
    try:
        return float(literal.toPython())
    except (TypeError, ValueError) as e:
        # float() raises TypeError rather than ValueError when handed a value
        # that is neither a number nor a string (e.g. a date/time object), so
        # both are caught to keep the single documented ValueError contract.
        raise ValueError(f"can't convert literal '{literal}' as float: {e}") from e


class DistributionModel(ModelBase):
    """Model of a probability distribution node read from an RDF graph.

    The constructor inspects the node's types and loads the matching attributes:

    - UniformRotation: no extra attributes are loaded
    - Uniform: 'dimension' (int), 'upper-bound' and 'lower-bound' (lists)
    - Normal: 'dimension' (int), 'mean' (list) and either 'standard-deviation'
      (float, when dimension is 1) or 'covariance' (``dim x dim`` numpy array)

    Attributes are stored with ``set_attr`` under the corresponding predicate URI.

    :raises RuntimeError: if the node has none of the handled distribution types,
        or a property node has an unexpected type
    :raises AssertionError: if a property is missing or fails validation
    :raises ValueError: if a value cannot be converted to float
    """

    # the distribution type selected for this node among the handled types
    distrib_type: URIRef

    def __init__(self, distrib_id: URIRef, graph: Graph) -> None:
        """
        :param distrib_id: URI of the distribution node
        :param graph: rdflib.Graph containing the distribution node
        """
        super().__init__(node_id=distrib_id, graph=graph)

        # UniformRotation is tested before Uniform, so a node carrying both
        # types is treated as a rotation distribution.
        if URI_DISTRIB_TYPE_UNIFORM_ROT in self.types:
            self.distrib_type = URI_DISTRIB_TYPE_UNIFORM_ROT
        elif URI_DISTRIB_TYPE_UNIFORM in self.types:
            self.distrib_type = URI_DISTRIB_TYPE_UNIFORM
            self._load_uniform_distrib_attrs(graph=graph)
        elif URI_DISTRIB_TYPE_NORMAL in self.types:
            self.distrib_type = URI_DISTRIB_TYPE_NORMAL
            self._load_normal_distrib_attrs(graph=graph)
        else:
            raise RuntimeError(f"Distrib '{self.id}' has unhandled types: {self.types}")

    def _load_dimension(self, graph: Graph, distrib_name: str) -> int:
        """Read and validate the 'dimension' property of this node.

        :param graph: rdflib.Graph containing the distribution node
        :param distrib_name: distribution name used as prefix in error messages
        :return: the dimension as a positive integer
        :raises AssertionError: if the property is not a Literal holding a positive int
        """
        dim_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_DIM)
        assert isinstance(
            dim_node, Literal
        ), f"{distrib_name} distrib '{self.id}' does not have a Literal 'dimension': {dim_node}"
        dim = dim_node.toPython()
        assert (
            isinstance(dim, int) and dim > 0
        ), f"{distrib_name} distrib '{self.id}' does not have a positive integer 'dimension': {dim}"
        return dim

    def _load_uniform_bounds(self, graph: Graph, predicate: URIRef, prop_name: str) -> list[Any]:
        """Read one bound property of a uniform distribution as a list.

        :param graph: rdflib.Graph containing the distribution node
        :param predicate: predicate URI of the bound property
        :param prop_name: property name used in the error message
        :return: a one-element list for a Literal bound, otherwise the result of
            ``load_list_re`` on the container node
        :raises RuntimeError: if the property is neither a Literal nor a BNode
        """
        node = graph.value(subject=self.id, predicate=predicate)
        if isinstance(node, Literal):
            return [_get_float_from_literal(node)]
        if isinstance(node, BNode):
            return load_list_re(graph=graph, first_node=node, parse_uri=False, quiet=False)
        raise RuntimeError(
            f"Uniform distrib '{self.id}' has invalid type for {prop_name}: {type(node)}"
        )

    def _load_uniform_distrib_attrs(self, graph: Graph) -> None:
        """Load and validate 'dimension', 'upper-bound' and 'lower-bound'.

        :param graph: rdflib.Graph containing the distribution node
        :raises AssertionError: if the number of bounds does not match the
            dimension, or a lower bound is not strictly less than its upper bound
        """
        dim = self._load_dimension(graph=graph, distrib_name="Uniform")

        # upper bounds are read before lower bounds, as in the original ordering,
        # so the first reported problem stays the same
        upper_bounds = self._load_uniform_bounds(
            graph=graph, predicate=URI_DISTRIB_PRED_UPPER, prop_name="upper-bound"
        )
        lower_bounds = self._load_uniform_bounds(
            graph=graph, predicate=URI_DISTRIB_PRED_LOWER, prop_name="lower-bound"
        )

        # check property dimensions
        assert (
            dim == len(lower_bounds) and dim == len(upper_bounds)
        ), f"Uniform distrib '{self.id}' has mismatching property dimensions: dim={dim}, upper bounds num={len(upper_bounds)}, lower bounds num={len(lower_bounds)}"

        # check lower bounds less than higher bounds (element-wise, strict)
        less_than = np.less(lower_bounds, upper_bounds)
        assert np.all(
            less_than
        ), f"Uniform distrib '{self.id}': not all lower bounds less than upper bounds: lower={lower_bounds}, upper={upper_bounds}"

        # attributes are only stored once every check has passed
        self.set_attr(key=URI_DISTRIB_PRED_DIM, val=dim)
        self.set_attr(key=URI_DISTRIB_PRED_UPPER, val=upper_bounds)
        self.set_attr(key=URI_DISTRIB_PRED_LOWER, val=lower_bounds)

    def _load_normal_distrib_attrs(self, graph: Graph) -> None:
        """Load and validate 'dimension', 'mean' and the spread of a normal distribution.

        For dimension 1 the 'standard-deviation' Literal is read; otherwise the
        'covariance' container is read and converted to a float numpy array.

        :param graph: rdflib.Graph containing the distribution node
        :raises AssertionError: if a property is missing or its size does not
            match the dimension
        :raises RuntimeError: if 'mean' is neither a Literal nor a BNode
        :raises ValueError: if the covariance cannot be converted to a float array
        """
        dim = self._load_dimension(graph=graph, distrib_name="Normal")
        self.set_attr(key=URI_DISTRIB_PRED_DIM, val=dim)

        # get mean: always stored as a list, even for a single Literal
        mean_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_MEAN)
        if isinstance(mean_node, Literal):
            assert (
                dim == 1
            ), f"Normal distrib '{self.id}' has single mean '{mean_node}' but dimension '{dim}'"
            mean_val = _get_float_from_literal(mean_node)
            self.set_attr(key=URI_DISTRIB_PRED_MEAN, val=[mean_val])
        elif isinstance(mean_node, BNode):
            mean_vals = load_list_re(
                graph=graph, first_node=mean_node, parse_uri=False, quiet=False
            )
            assert (
                len(mean_vals) == dim
            ), f"Normal distrib '{self.id}': number of mean values ({len(mean_vals)}) does not match dimension ({dim})"
            self.set_attr(key=URI_DISTRIB_PRED_MEAN, val=mean_vals)
        else:
            raise RuntimeError(
                f"Normal distrib '{self.id}' has invalid type for 'mean': {type(mean_node)}"
            )

        # get standard deviation or covariance based on dimension
        if dim == 1:
            std_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_STD)
            assert isinstance(
                std_node, Literal
            ), f"Normal distrib '{self.id}' does not have a Literal 'standard-deviation': {std_node}"
            std = _get_float_from_literal(std_node)
            self.set_attr(key=URI_DISTRIB_PRED_STD, val=std)
            return

        cov_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_COV)
        assert isinstance(
            cov_node, BNode
        ), f"Normal distrib '{self.id}': 'covariance' property not a container, type={type(cov_node)}"
        cov_vals = load_list_re(graph=graph, first_node=cov_node, parse_uri=False, quiet=False)
        try:
            cov_mat = np.array(cov_vals, dtype=float)
        except ValueError as e:
            # chain the numpy error so the original traceback is preserved
            raise ValueError(
                f"Normal distrib '{self.id}', can't convert covariance to float numpy array: {e}\n{cov_vals}"
            ) from e
        assert cov_mat.shape == (
            dim,
            dim,
        ), f"Normal distrib '{self.id}': dimension='{dim}' doesn't match 'covariance' shape'{cov_mat.shape}'"
        self.set_attr(key=URI_DISTRIB_PRED_COV, val=cov_mat)


def sample_from_distrib(
    distrib: DistributionModel, size: Optional[int | tuple[int, ...]] = None
) -> Any:
    """Draw sample(s) from the distribution described by a model.

    The distribution is selected from ``distrib.types``, with UniformRotation
    tested first, then Uniform, then Normal.

    :param distrib: distribution model providing the types and attributes to sample with
    :param size: number or shape of samples. For the uniform rotation it is the
        number of rotations (an int or a 1-tuple). For the other distributions
        it is forwarded unchanged to the numpy sampling function; note that for
        the uniform distribution numpy requires it to be broadcast-compatible
        with the list of bounds.
    :return: a scipy ``Rotation`` for the uniform rotation, otherwise whatever
        the numpy sampling function returns
    :raises RuntimeError: if scipy is needed but not installed, or the model has
        no handled distribution type
    :raises ValueError: if a multi-dimensional ``size`` is requested for the
        uniform rotation
    :raises AssertionError: if the model's attributes do not have the expected types
    """
    if URI_DISTRIB_TYPE_UNIFORM_ROT in distrib.types:
        # scipy is imported lazily so it is only required for rotation sampling
        try:
            from scipy.spatial.transform import Rotation
        except ImportError as e:
            raise RuntimeError("to sample random rotations, 'scipy' must be installed") from e

        if size is None:
            return Rotation.random()

        # Rotation.random only takes a flat count, so honour an int or a 1-tuple
        # and reject anything that cannot be expressed as a single count
        # instead of silently ignoring the request.
        if isinstance(size, tuple):
            if len(size) != 1:
                raise ValueError(
                    f"Distrib '{distrib.id}': uniform rotation sampling only supports"
                    f" a one-dimensional 'size', got {size}"
                )
            (size,) = size
        return Rotation.random(num=size)

    if URI_DISTRIB_TYPE_UNIFORM in distrib.types:
        lower_bounds = distrib.get_attr(key=URI_DISTRIB_PRED_LOWER)
        upper_bounds = distrib.get_attr(key=URI_DISTRIB_PRED_UPPER)
        assert isinstance(lower_bounds, list) and isinstance(
            upper_bounds, list
        ), f"Uniform distrib '{distrib.id}' does not have valid lower & upper bounds"
        return np.random.uniform(lower_bounds, upper_bounds, size=size)

    if URI_DISTRIB_TYPE_NORMAL in distrib.types:
        dim = distrib.get_attr(key=URI_DISTRIB_PRED_DIM)
        assert (
            isinstance(dim, int) and dim > 0
        ), f"Normal distrib '{distrib.id}' does not have valid dimension: {dim}"

        mean = distrib.get_attr(key=URI_DISTRIB_PRED_MEAN)
        assert (
            isinstance(mean, list) and len(mean) == dim
        ), f"Normal distrib '{distrib.id}' does not have valid mean: {mean}"

        if dim == 1:
            # univariate normal: the mean is stored as a one-element list
            std = distrib.get_attr(key=URI_DISTRIB_PRED_STD)
            assert isinstance(
                std, float
            ), f"Normal distrib '{distrib.id}' does not have valid standard deviation: {std}"
            return np.random.normal(loc=mean[0], scale=std, size=size)

        # multivariate normal
        cov = distrib.get_attr(key=URI_DISTRIB_PRED_COV)
        assert isinstance(
            cov, np.ndarray
        ), f"Normal distrib '{distrib.id}' does not have valid covariance: {cov}"
        return np.random.multivariate_normal(mean=mean, cov=cov, size=size)

    raise RuntimeError(f"Distrib '{distrib.id}' has unhandled types: {distrib.types}")


class SampledQuantityModel(ModelBase):
    """Model of a quantity whose value is sampled from a linked distribution.

    The node must point to a distribution node through the 'from-distribution'
    predicate; that node is parsed into a ``DistributionModel``. The most
    recently drawn sample is kept so it can be returned again on request.
    """

    # distribution this quantity is sampled from
    distribution: DistributionModel
    # last drawn sample, None until the first call to sample()
    _sampled_value: Optional[Any]

    def __init__(self, quantity_id: URIRef, graph: Graph) -> None:
        """
        :param quantity_id: URI of the sampled quantity node
        :param graph: rdflib.Graph containing the quantity and its distribution
        :raises AssertionError: if the node does not link to a URIRef distribution node
        """
        super().__init__(node_id=quantity_id, graph=graph)

        linked_node = graph.value(subject=self.id, predicate=URI_DISTRIB_PRED_FROM_DISTRIB)
        assert isinstance(
            linked_node, URIRef
        ), f"SampledQuantity '{self.id}' does not link to a distribution node: {linked_node}"

        self.distribution = DistributionModel(distrib_id=linked_node, graph=graph)
        self._sampled_value = None

    def sample(self, resample: bool = True) -> Any:
        """Return a sample of this quantity.

        :param resample: if true (default) always draw a new sample; if false
            return the cached sample, drawing one only when none is cached yet
        :return: the (possibly cached) sampled value
        """
        value = self._sampled_value
        if resample or value is None:
            value = sample_from_distrib(distrib=self.distribution)
            self._sampled_value = value
        return value
3 changes: 2 additions & 1 deletion src/rdf_utils/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from rdflib import Namespace
from rdf_utils.uri import (
URI_MM_AGN,
URI_MM_DISTRIB,
URI_MM_GEOM,
URI_MM_GEOM_REL,
URI_MM_GEOM_COORD,
Expand All @@ -17,8 +18,8 @@
NS_MM_GEOM_COORD = Namespace(URI_MM_GEOM_COORD)

NS_MM_PYTHON = Namespace(URI_MM_PYTHON)

NS_MM_ENV = Namespace(URI_MM_ENV)
NS_MM_AGN = Namespace(URI_MM_AGN)
NS_MM_TIME = Namespace(URI_MM_TIME)
NS_MM_EL = Namespace(URI_MM_EL)
NS_MM_DISTRIB = Namespace(URI_MM_DISTRIB)
4 changes: 4 additions & 0 deletions src/rdf_utils/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
URL_MM_EL_JSON = f"{URL_SECORO_MM}/behaviour/event_loop.json"
URL_MM_EL_SHACL = f"{URL_SECORO_MM}/behaviour/event_loop.shacl.ttl"

URI_MM_DISTRIB = f"{URL_SECORO_MM}/probability/distribution#"
URL_MM_DISTRIB_JSON = f"{URL_SECORO_MM}/probability/distribution.json"
URL_MM_DISTRIB_SHACL = f"{URL_SECORO_MM}/probability/distribution.shacl.ttl"


def try_expand_curie(
ns_manager: NamespaceManager, curie_str: str, quiet: bool = False
Expand Down
Loading

0 comments on commit 89e3ee9

Please sign in to comment.