From 66e79a905f0e7304223b25b64c1b2f85d1fd0943 Mon Sep 17 00:00:00 2001 From: wpbonelli Date: Wed, 11 Dec 2024 09:59:21 -0500 Subject: [PATCH] minimal working example --- docs/examples/attrs_demo.py | 132 ++++++++++++++++++++---------------- 1 file changed, 75 insertions(+), 57 deletions(-) diff --git a/docs/examples/attrs_demo.py b/docs/examples/attrs_demo.py index 1bb2676..fe267c1 100644 --- a/docs/examples/attrs_demo.py +++ b/docs/examples/attrs_demo.py @@ -4,10 +4,9 @@ from os import PathLike from pathlib import Path -from typing import Any, Literal, Optional, get_origin +from typing import Literal, Optional, get_origin from warnings import warn -import attrs import numpy as np from attr import define, field, fields_dict from cattr import Converter @@ -35,23 +34,28 @@ def _try_resolve_dim(self, name) -> int | str: return name -def _to_array(value: ArrayLike) -> Optional[NDArray]: +def _try_resolve_shape(self, field) -> tuple[int | str]: + dim_names = _parse_dim_names(field.metadata["shape"]) + return tuple([_try_resolve_dim(self, n) for n in dim_names]) + + +def _to_array(value: Optional[ArrayLike]) -> Optional[NDArray]: return None if value is None else np.array(value) def _to_shaped_array( - value: ArrayLike | str | PathLike, self_, field + value: Optional[ArrayLike | str | PathLike], self_, field ) -> Optional[NDArray]: if isinstance(value, (str, PathLike)): - # TODO + # TODO handle external arrays pass value = _to_array(value) if value is None: return None - dim_names = _parse_dim_names(field.metadata["shape"]) - shape = tuple([_try_resolve_dim(self_, n) for n in dim_names]) - unresolved = [d for d in shape if not isinstance(d, int)] + + shape = _try_resolve_shape(self_, field) + unresolved = [dim for dim in shape if not isinstance(dim, int)] if any(unresolved): warn(f"Failed to resolve dimension names: {', '.join(unresolved)}") return value @@ -69,20 +73,10 @@ def _to_path(value) -> Optional[Path]: def datatree(cls): - # TODO - # - determine whether data array, data set, or data tree DONE - # - shape check arrays (dynamic validator?) - # check for parent and update dimensions - # then try to realign existing packages? - - old_post_init = getattr(cls, "__attrs_post_init__", None) - - def __attrs_post_init__(self): - print(f"Running datatree on {cls.__name__}") - - if old_post_init: - old_post_init(self) + post_init_name = "__attrs_post_init__" + post_init_prev = getattr(cls, post_init_name, None) + def _set_data_on_self(self, cls): fields = fields_dict(cls) arrays = {} for n, f in fields.items(): @@ -91,21 +85,43 @@ def __attrs_post_init__(self): value = getattr(self, n) if value is None: continue - arrays[n] = (_parse_dim_names(f.metadata["shape"]), value) + arrays[n] = ( + _parse_dim_names(f.metadata["shape"]), + _to_shaped_array(value, self, f), + ) dataset = Dataset(arrays) - children = getattr(self, "children", None) - if children: - self.data = DataTree( - dataset, name=cls.__name__, children=[c.data for c in children] + self.data = ( + DataTree(dataset, name=cls.__name__.lower()[3:]) + if issubclass(cls, Model) + else dataset + ) + + def _set_self_on_model(self, cls): + model = getattr(self, "model", None) + if model: + self_name = cls.__name__.lower()[3:] + setattr(model, self_name, self) + model.data = model.data.assign( + {self_name: DataTree(self.data, name=self_name)} ) - else: - self.data = dataset - cls.__attrs_post_init__ = __attrs_post_init__ + def __attrs_post_init__(self): + if post_init_prev: + post_init_prev(self) + + _set_data_on_self(self, cls) + _set_self_on_model(self, cls) + # TODO: figure out why classes need to have a + # __attrs_post_init__ method for this to work + setattr(cls, post_init_name, __attrs_post_init__) return cls +class Model: + pass + + @datatree @define(slots=False) class GwfDis: @@ -113,37 +129,27 @@ class GwfDis: ncol: int = field(default=2, metadata={"block": "dimensions"}) nrow: int = field(default=2, metadata={"block": "dimensions"}) delr: NDArray[np.floating] = field( - converter=attrs.Converter( - _to_shaped_array, takes_self=True, takes_field=True - ), + converter=_to_array, default=1.0, metadata={"block": "griddata", "shape": "(ncol,)"}, ) delc: NDArray[np.floating] = field( - converter=attrs.Converter( - _to_shaped_array, takes_self=True, takes_field=True - ), + converter=_to_array, default=1.0, metadata={"block": "griddata", "shape": "(nrow,)"}, ) top: NDArray[np.floating] = field( - converter=attrs.Converter( - _to_shaped_array, takes_self=True, takes_field=True - ), + converter=_to_array, default=1.0, metadata={"block": "griddata", "shape": "(ncol, nrow)"}, ) botm: NDArray[np.floating] = field( - converter=attrs.Converter( - _to_shaped_array, takes_self=True, takes_field=True - ), + converter=_to_array, default=0.0, metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"}, ) idomain: Optional[NDArray[np.integer]] = field( - converter=attrs.Converter( - _to_shaped_array, takes_self=True, takes_field=True - ), + converter=_to_array, default=1, metadata={"block": "griddata", "shape": "(ncol, nrow, nlay)"}, ) @@ -156,8 +162,7 @@ class GwfDis: default=False, metadata={"block": "options"} ) nodes: int = field(init=False) - data: Dataset = field(init=False) - model: Optional[Any] = field(default=None) + model: Optional[Model] = field(default=None) def __attrs_post_init__(self): self.nodes = self.nlay * self.ncol * self.nrow @@ -167,9 +172,8 @@ def __attrs_post_init__(self): @define(slots=False) class GwfIc: strt: NDArray[np.floating] = field( - converter=attrs.Converter( - _to_shaped_array, takes_self=True, takes_field=True - ), + converter=_to_array, + default=1.0, metadata={"block": "packagedata", "shape": "(nodes)"}, ) export_array_ascii: bool = field( @@ -179,8 +183,11 @@ class GwfIc: default=False, metadata={"block": "options"}, ) - data: Dataset = field(init=False) - model: Optional[Any] = field(default=None) + model: Optional[Model] = field(default=None) + + def __attrs_post_init__(self): + # for some reason this is necessary.. + pass @datatree @@ -208,17 +215,23 @@ class Format: perioddata: Optional[list[list[tuple]]] = field( default=None, metadata={"block": "perioddata"} ) - data: Dataset = field(init=False) - model: Optional[Any] = field(default=None) + model: Optional[Model] = field(default=None) + + def __attrs_post_init__(self): + # for some reason this is necessary.. + pass @datatree @define(slots=False) -class Gwf: +class Gwf(Model): dis: Optional[GwfDis] = field(default=None) ic: Optional[GwfIc] = field(default=None) oc: Optional[GwfOc] = field(default=None) - data: DataTree = field(init=False) + + def __attrs_post_init__(self): + # for some reason this is necessary.. + pass # We can define a package with some data. @@ -268,7 +281,12 @@ class Gwf: assert period[0] == ("print", "budget", "steps", 1, 3, 5) -# Creating a model by constructor. +# Create a model. + +gwf = Gwf() +dis = GwfDis(model=gwf) +ic = GwfIc(model=gwf, strt=1) +oc.model = gwf -gwf = Gwf(dis=GwfDis(), ic=GwfIc(strt=1), oc=oc) +# View the data tree.