Skip to content

Commit

Permalink
minimal working example
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Dec 12, 2024
1 parent 9719f97 commit 66e79a9
Showing 1 changed file with 75 additions and 57 deletions.
132 changes: 75 additions & 57 deletions docs/examples/attrs_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

from os import PathLike
from pathlib import Path
from typing import Any, Literal, Optional, get_origin
from typing import Literal, Optional, get_origin
from warnings import warn

import attrs
import numpy as np
from attr import define, field, fields_dict
from cattr import Converter
Expand Down Expand Up @@ -35,23 +34,28 @@ def _try_resolve_dim(self, name) -> int | str:
return name


def _to_array(value: ArrayLike) -> Optional[NDArray]:
def _try_resolve_shape(self, field) -> tuple[int | str]:
dim_names = _parse_dim_names(field.metadata["shape"])
return tuple([_try_resolve_dim(self, n) for n in dim_names])


def _to_array(value: Optional[ArrayLike]) -> Optional[NDArray]:
return None if value is None else np.array(value)


def _to_shaped_array(
value: ArrayLike | str | PathLike, self_, field
value: Optional[ArrayLike | str | PathLike], self_, field
) -> Optional[NDArray]:
if isinstance(value, (str, PathLike)):
# TODO
# TODO handle external arrays
pass

value = _to_array(value)
if value is None:
return None
dim_names = _parse_dim_names(field.metadata["shape"])
shape = tuple([_try_resolve_dim(self_, n) for n in dim_names])
unresolved = [d for d in shape if not isinstance(d, int)]

shape = _try_resolve_shape(self_, field)
unresolved = [dim for dim in shape if not isinstance(dim, int)]
if any(unresolved):
warn(f"Failed to resolve dimension names: {', '.join(unresolved)}")
return value
Expand All @@ -69,20 +73,10 @@ def _to_path(value) -> Optional[Path]:


def datatree(cls):
# TODO
# - determine whether data array, data set, or data tree DONE
# - shape check arrays (dynamic validator?)
# check for parent and update dimensions
# then try to realign existing packages?

old_post_init = getattr(cls, "__attrs_post_init__", None)

def __attrs_post_init__(self):
print(f"Running datatree on {cls.__name__}")

if old_post_init:
old_post_init(self)
post_init_name = "__attrs_post_init__"
post_init_prev = getattr(cls, post_init_name, None)

def _set_data_on_self(self, cls):
fields = fields_dict(cls)
arrays = {}
for n, f in fields.items():
Expand All @@ -91,59 +85,71 @@ def __attrs_post_init__(self):
value = getattr(self, n)
if value is None:
continue
arrays[n] = (_parse_dim_names(f.metadata["shape"]), value)
arrays[n] = (
_parse_dim_names(f.metadata["shape"]),
_to_shaped_array(value, self, f),
)
dataset = Dataset(arrays)
children = getattr(self, "children", None)
if children:
self.data = DataTree(
dataset, name=cls.__name__, children=[c.data for c in children]
self.data = (
DataTree(dataset, name=cls.__name__.lower()[3:])
if issubclass(cls, Model)
else dataset
)

def _set_self_on_model(self, cls):
model = getattr(self, "model", None)
if model:
self_name = cls.__name__.lower()[3:]
setattr(model, self_name, self)
model.data = model.data.assign(
{self_name: DataTree(self.data, name=self_name)}
)
else:
self.data = dataset

cls.__attrs_post_init__ = __attrs_post_init__
def __attrs_post_init__(self):
if post_init_prev:
post_init_prev(self)

_set_data_on_self(self, cls)
_set_self_on_model(self, cls)

# TODO: figure out why classes need to have a
# __attrs_post_init__ method for this to work
setattr(cls, post_init_name, __attrs_post_init__)
return cls


class Model:
pass


@datatree
@define(slots=False)
class GwfDis:
nlay: int = field(default=1, metadata={"block": "dimensions"})
ncol: int = field(default=2, metadata={"block": "dimensions"})
nrow: int = field(default=2, metadata={"block": "dimensions"})
delr: NDArray[np.floating] = field(
converter=attrs.Converter(
_to_shaped_array, takes_self=True, takes_field=True
),
converter=_to_array,
default=1.0,
metadata={"block": "griddata", "shape": "(ncol,)"},
)
delc: NDArray[np.floating] = field(
converter=attrs.Converter(
_to_shaped_array, takes_self=True, takes_field=True
),
converter=_to_array,
default=1.0,
metadata={"block": "griddata", "shape": "(nrow,)"},
)
top: NDArray[np.floating] = field(
converter=attrs.Converter(
_to_shaped_array, takes_self=True, takes_field=True
),
converter=_to_array,
default=1.0,
metadata={"block": "griddata", "shape": "(ncol, nrow)"},
)
botm: NDArray[np.floating] = field(
converter=attrs.Converter(
_to_shaped_array, takes_self=True, takes_field=True
),
converter=_to_array,
default=0.0,
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"},
)
idomain: Optional[NDArray[np.integer]] = field(
converter=attrs.Converter(
_to_shaped_array, takes_self=True, takes_field=True
),
converter=_to_array,
default=1,
metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"},
)
Expand All @@ -156,8 +162,7 @@ class GwfDis:
default=False, metadata={"block": "options"}
)
nodes: int = field(init=False)
data: Dataset = field(init=False)
model: Optional[Any] = field(default=None)
model: Optional[Model] = field(default=None)

def __attrs_post_init__(self):
self.nodes = self.nlay * self.ncol * self.nrow
Expand All @@ -167,9 +172,8 @@ def __attrs_post_init__(self):
@define(slots=False)
class GwfIc:
strt: NDArray[np.floating] = field(
converter=attrs.Converter(
_to_shaped_array, takes_self=True, takes_field=True
),
converter=_to_array,
default=1.0,
metadata={"block": "packagedata", "shape": "(nodes)"},
)
export_array_ascii: bool = field(
Expand All @@ -179,8 +183,11 @@ class GwfIc:
default=False,
metadata={"block": "options"},
)
data: Dataset = field(init=False)
model: Optional[Any] = field(default=None)
model: Optional[Model] = field(default=None)

def __attrs_post_init__(self):
# for some reason this is necessary..
pass


@datatree
Expand Down Expand Up @@ -208,17 +215,23 @@ class Format:
perioddata: Optional[list[list[tuple]]] = field(
default=None, metadata={"block": "perioddata"}
)
data: Dataset = field(init=False)
model: Optional[Any] = field(default=None)
model: Optional[Model] = field(default=None)

def __attrs_post_init__(self):
# for some reason this is necessary..
pass


@datatree
@define(slots=False)
class Gwf:
class Gwf(Model):
dis: Optional[GwfDis] = field(default=None)
ic: Optional[GwfIc] = field(default=None)
oc: Optional[GwfOc] = field(default=None)
data: DataTree = field(init=False)

def __attrs_post_init__(self):
# for some reason this is necessary..
pass


# We can define a package with some data.
Expand Down Expand Up @@ -268,7 +281,12 @@ class Gwf:
assert period[0] == ("print", "budget", "steps", 1, 3, 5)


# Creating a model by constructor.
# Create a model.


gwf = Gwf()
dis = GwfDis(model=gwf)
ic = GwfIc(model=gwf, strt=1)
oc.model = gwf

gwf = Gwf(dis=GwfDis(), ic=GwfIc(strt=1), oc=oc)
# View the data tree.

0 comments on commit 66e79a9

Please sign in to comment.