Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Dec 12, 2024
1 parent 89e9ed7 commit 4ea91b4
Showing 1 changed file with 128 additions and 74 deletions.
202 changes: 128 additions & 74 deletions docs/examples/attrs_xarray_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

# This example demonstrates a tentative `attrs`-based object model.

from datetime import datetime
from itertools import repeat
from os import PathLike
from pathlib import Path
from typing import Literal, Optional, get_origin
from typing import Iterable, Literal, Optional, get_origin
from warnings import warn

import numpy as np
from attr import define, field, fields_dict
from cattr import Converter
from attr import Factory, define, field, fields_dict
from numpy.typing import ArrayLike, NDArray
from xarray import Dataset, DataTree

Expand All @@ -31,6 +32,8 @@ def _try_resolve_dim(self, name) -> int | str:
return value
if hasattr(self, "model") and hasattr(self.model, "dis"):
return getattr(self.model.dis, name, name)
if hasattr(self, "sim") and hasattr(self.sim, "tdis"):
return getattr(self.sim.tdis, name, name)
return name


Expand Down Expand Up @@ -68,6 +71,29 @@ def _to_shaped_array(
return value


def _to_shaped_list(
    value: Optional[Iterable | str | PathLike], self_, field
) -> Optional[list]:
    """
    Fit `value` to the (at most 1-dimensional) shape declared for `field`.

    The shape comes from `_try_resolve_shape(self_, field)`. A scalar
    value is repeated to fill the resolved dimension; a sized value must
    already have the matching length.

    Parameters
    ----------
    value
        Scalar or list-like value to fit to the field's shape.
    self_
        Object passed to `_try_resolve_shape` to resolve dimensions.
    field
        Field descriptor passed to `_try_resolve_shape`.

    Returns
    -------
    `value` unchanged when the shape is empty, when the dimension could
    not be resolved to an integer, or when the length already matches;
    otherwise a new list repeating the scalar `value`.

    Raises
    ------
    ValueError
        If the shape has more than one dimension, or if the length of
        `value` differs from the resolved dimension.
    """
    if isinstance(value, (str, PathLike)):
        # TODO handle external lists
        pass

    shape = _try_resolve_shape(self_, field)
    if len(shape) > 1:
        raise ValueError(f"Expected at most 1 dimension, got {len(shape)}")
    if not shape:
        # Zero dimensions pass the check above; there is no size to
        # expand to or validate against, so avoid indexing shape[0].
        return value

    # Test the list itself, not the truthiness of its elements: a falsy
    # unresolved entry (e.g. an empty name) is still unresolved.
    unresolved = [dim for dim in shape if not isinstance(dim, int)]
    if unresolved:
        names = ", ".join(str(dim) for dim in unresolved)
        warn(f"Failed to resolve dimension names: {names}")
        return value

    size = shape[0]
    if np.array(value).shape == ():
        # A 0-d array means `value` is a scalar: broadcast it.
        return list(repeat(value, size))
    if len(value) != size:
        raise ValueError(
            f"Length mismatch, got {len(value)}, expected {size}"
        )
    return value


def _to_path(value) -> Optional[Path]:
return Path(value) if value else None

Expand All @@ -80,37 +106,62 @@ def _set_data_on_self(self, cls):
fields = fields_dict(cls)
arrays = {}
for n, f in fields.items():
if get_origin(f.type) is not np.ndarray:
continue
value = getattr(self, n)
if value is None:
continue
arrays[n] = (
_parse_dim_names(f.metadata["shape"]),
_to_shaped_array(value, self, f),
)
if get_origin(f.type) is np.ndarray:
value = getattr(self, n)
if value is None:
continue
arrays[n] = (
_parse_dim_names(f.metadata["shape"]),
_to_shaped_array(value, self, f),
)
elif get_origin(f.type) is list:
value = getattr(self, n)
if value is None:
continue
arrays[n] = (
_parse_dim_names(f.metadata["shape"]),
_to_shaped_list(value, self, f),
)

dataset = Dataset(arrays)
self.data = (
DataTree(dataset, name=cls.__name__.lower()[3:])
if issubclass(cls, Model)
DataTree(dataset, name=cls.__name__.lower())
if cls is Sim or issubclass(cls, Model)
else dataset
)

def _set_self_on_model(self, cls):
def _set_self_on_parent(self, cls):
self_name = cls.__name__.lower()
model = getattr(self, "model", None)
if model:
self_name = cls.__name__.lower()[3:]
setattr(model, self_name, self)
model.data = model.data.assign(
{self_name: DataTree(self.data, name=self_name)}
data = (
DataTree(self.data, name=self_name)
if not isinstance(self.data, DataTree)
else self.data
)
model.data = model.data.assign({self_name: data})
sim = getattr(model, "sim", None)
if sim:
model_name = type(model).__name__.lower()
setattr(sim, model_name, model)
sim.data = sim.data.assign({model_name: model.data})
sim = getattr(self, "sim", None)
if sim:
setattr(sim, self_name, self)
data = (
DataTree(self.data, name=self_name)
if not isinstance(self.data, DataTree)
else self.data
)
sim.data = sim.data.assign({self_name: data})

def __attrs_post_init__(self):
if post_init_prev:
post_init_prev(self)

_set_data_on_self(self, cls)
_set_self_on_model(self, cls)
_set_self_on_parent(self, cls)

# TODO: figure out why classes need to have a
# __attrs_post_init__ method for this to work
Expand All @@ -124,7 +175,7 @@ class Model:

@datatree
@define(slots=False)
class GwfDis:
class Dis:
nlay: int = field(default=1, metadata={"block": "dimensions"})
ncol: int = field(default=2, metadata={"block": "dimensions"})
nrow: int = field(default=2, metadata={"block": "dimensions"})
Expand Down Expand Up @@ -170,7 +221,7 @@ def __attrs_post_init__(self):

@datatree
@define(slots=False)
class GwfIc:
class Ic:
strt: NDArray[np.floating] = field(
converter=_to_array,
default=1.0,
Expand All @@ -192,14 +243,22 @@ def __attrs_post_init__(self):

@datatree
@define(slots=False)
class GwfOc:
@define
class Oc:
@define(slots=False)
class Format:
columns: int
width: int
digits: int
format: Literal["exponential", "fixed", "general", "scientific"]

@define(slots=False)
class Steps:
first: Optional[Literal["first"]] = field(default="first")
last: Optional[Literal["last"]] = field(default=None)
all: Optional[Literal["all"]] = field(default=None)
frequency: Optional[int] = field(default=None)
steps: Optional[list[int]] = field(default=None)

budget_file: Optional[Path] = field(
converter=_to_path, default=None, metadata={"block": "options"}
)
Expand All @@ -212,8 +271,9 @@ class Format:
printhead: Optional[Format] = field(
default=None, metadata={"block": "options"}
)
perioddata: Optional[list[list[tuple]]] = field(
default=None, metadata={"block": "perioddata"}
perioddata: list[Steps] = field(
default=Factory(list),
metadata={"block": "perioddata", "shape": "(nper,)"},
)
model: Optional[Model] = field(default=None)

Expand All @@ -225,68 +285,62 @@ def __attrs_post_init__(self):
@datatree
@define(slots=False)
class Gwf(Model):
dis: Optional[GwfDis] = field(default=None)
ic: Optional[GwfIc] = field(default=None)
oc: Optional[GwfOc] = field(default=None)
dis: Optional[Dis] = field(default=None)
ic: Optional[Ic] = field(default=None)
oc: Optional[Oc] = field(default=None)
sim: Optional["Sim"] = field(default=None)

def __attrs_post_init__(self):
# for some reason this is necessary..
pass


# We can define a package with some data.


oc = GwfOc(
budget_file="some/file/path.cbc",
perioddata=[[("print", "budget", "steps", 1, 3, 5)]],
)
assert isinstance(oc.budget_file, Path)


# We now set up a `cattrs` converter to convert an unstructured
# representation of the package input data to a structured form.
@datatree
@define(slots=False)
class Tdis:
@define(slots=False)
class PeriodData:
perlen: float = field(default=1.0)
nstp: int = field(default=1)
tsmult: float = field(default=1.0)

nper: int = field(default=1, metadata={"block": "dimensions"})
perioddata: list[PeriodData] = field(
default=Factory(list),
metadata={"block": "perioddata", "shape": "(nper)"},
)
time_units: Optional[str] = field(
default=None, metadata={"block": "options"}
)
start_date_time: Optional[datetime] = field(
default=None, metadata={"block": "options"}
)
sim: Optional["Sim"] = field(default=None)

converter = Converter()
def __attrs_post_init__(self):
# for some reason this is necessary..
pass


# We can load the full package from an unstructured dictionary,
# as would be returned by a separate IO layer in the future.
# (Either hand-written or using e.g. lark.)
@datatree
@define(slots=False)
class Sim:
tdis: Optional[Tdis] = field(default=None)
gwf: Optional[Gwf] = field(default=None)

oc = converter.structure(
{
"budget_file": "some/file/path.cbc",
"head_file": "some/file/path.hds",
"printhead": {
"columns": 1,
"width": 10,
"digits": 8,
"format": "scientific",
},
"perioddata": [
[
("print", "budget", "steps", 1, 3, 5),
("save", "head", "frequency", 2),
]
],
},
GwfOc,
)
assert oc.budget_file == Path("some/file/path.cbc")
assert oc.printhead.width == 10
assert oc.printhead.format == "scientific"
period = oc.perioddata[0]
assert len(period) == 2
assert period[0] == ("print", "budget", "steps", 1, 3, 5)
def __attrs_post_init__(self):
# for some reason this is necessary..
pass


# Create a model.


gwf = Gwf()
dis = GwfDis(model=gwf)
ic = GwfIc(model=gwf, strt=1)
oc.model = gwf
sim = Sim()
tdis = Tdis(sim=sim, nper=1, perioddata=[Tdis.PeriodData()])
gwf = Gwf(sim=sim)
dis = Dis(model=gwf)
ic = Ic(model=gwf, strt=1)
oc = Oc(model=gwf, perioddata=[Oc.Steps()])

# View the data tree.
gwf.data

0 comments on commit 4ea91b4

Please sign in to comment.