Skip to content

Commit

Permalink
ok even better typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Feb 18, 2024
1 parent 64b7f05 commit e8a75c2
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 19 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,7 @@ ignore_missing_imports=true
[[tool.mypy.overrides]]
module="kalman.*"
ignore_missing_imports=true

[[tool.mypy.overrides]]
module="scipy.*"
ignore_missing_imports=true
35 changes: 18 additions & 17 deletions src/gen_experiments/data.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
from math import ceil
from pathlib import Path
from typing import Callable
from typing import Callable, Optional, cast
from warnings import warn

import mitosis
import numpy as np
import scipy

from gen_experiments.utils import GridsearchResultDetails
from gen_experiments.utils import Float1D, Float2D, GridsearchResultDetails

INTEGRATOR_KEYWORDS = {"rtol": 1e-12, "method": "LSODA", "atol": 1e-12}
TRIALS_FOLDER = Path(__file__).parent.absolute() / "trials"


def gen_data(
rhs_func,
n_coord,
seed=None,
n_trajectories=1,
x0_center=None,
ic_stdev=3,
noise_abs=None,
noise_rel=None,
nonnegative=False,
dt=0.01,
t_end=10,
):
rhs_func: Callable,
n_coord: int,
seed: Optional[int] = None,
n_trajectories: int = 1,
x0_center: Optional[Float1D] = None,
ic_stdev: float = 3,
noise_abs: Optional[float] = None,
noise_rel: Optional[float] = None,
nonnegative: bool = False,
dt: float = 0.01,
t_end: float = 10,
) -> tuple[float, Float1D, Float2D, Float2D, Float2D, Float2D]:
"""Generate random training and test data
Note that test data has no noise.
Expand Down Expand Up @@ -57,7 +57,7 @@ def gen_data(
rng = np.random.default_rng(seed)
if x0_center is None:
x0_center = np.zeros((n_coord))
t_train = np.arange(0, t_end, dt)
t_train = np.arange(0, t_end, dt, dtype=np.float_)
t_train_span = (t_train[0], t_train[-1])
if nonnegative:
shape = ((x0_center + 1) / ic_stdev) ** 2
Expand Down Expand Up @@ -123,10 +123,11 @@ def _alert_short(arr):
x_train_true = np.copy(x_train)
if noise_rel is not None:
noise_abs = np.sqrt(_signal_avg_power(x_test) * noise_rel)
x_train = x_train + noise_abs * rng.standard_normal(x_train.shape)
x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape)
x_train = list(x_train)
x_test = list(x_test)
x_dot_test = list(x_dot_test)
x_train_true = list(x_train_true)
return dt, t_train, x_train, x_test, x_dot_test, x_train_true


Expand Down Expand Up @@ -211,7 +212,7 @@ def gen_pde_data(
x_train_true = np.copy(x_train)
if noise_rel is not None:
noise_abs = _max_amplitude(x_test) * noise_rel
x_train = x_train + noise_abs * rng.standard_normal(x_train.shape)
x_train = x_train + cast(float, noise_abs) * rng.standard_normal(x_train.shape)
x_train = [np.moveaxis(x_train, 0, -2)]
x_train_true = np.moveaxis(x_train_true, 0, -2)
x_test = [np.moveaxis(x_test, [0, 1], [-1, -2])]
Expand Down
25 changes: 25 additions & 0 deletions src/gen_experiments/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Annotated, Generic, TypedDict, TypeVar

import numpy as np
from numpy.typing import DTypeLike, NBitBase, NDArray

# T = TypeVar("T")

# class Foo[T]:
# items: list[T]

# def __init__(self, thing: T):
# self.items = [thing, thing]

# Bar =


T = TypeVar("T", bound=np.generic)
Foo = NDArray[T]
Bar = Annotated[NDArray, "foobar"]

lil_foo = NDArray[np.void]


def baz(qux: Foo[np.void]):
pass
4 changes: 2 additions & 2 deletions src/gen_experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ class SavedData(TypedDict):


T = TypeVar("T", bound=np.generic)
GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"] # type: ignore
GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"]
SeriesData = Annotated[
list[
tuple[
Annotated[GridsearchResult, "metrics"],
Annotated[GridsearchResult, "arg_opts"],
Annotated[GridsearchResult[np.void], "arg_opts"],
]
],
"len=n_grid_axes",
Expand Down

0 comments on commit e8a75c2

Please sign in to comment.