DA Cycler (#37)
* Initial dacycler with 3dvar

* Update location indices when timefilt for obsvector, add ability to store error_sd

* Observer now passes error_sd and error_bias to obs_vec

* Working 3D var

* Add one-stop cycle() method that runs DA all the way through (user-specified timesteps)

* ETKF, work in progress

* ETKF compute_analysis, work in progress

* Functional ETKF, but error seems high

* Change model_obj name to forecast_model

* Var3d in Jax

* Add store_as_jax option to observer

* ETKF in Jax, except real_if_close isn't implemented yet. Very slow to run the actual cycle

* Fully functioning (but slow) jax implementation of ETKF
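
For reference, the core of an ETKF analysis step can be sketched in a few lines of NumPy. This is the textbook formulation (ensemble-space inverse plus a symmetric square root of the weight covariance), not dabench's actual implementation; the function name and signature are illustrative:

```python
import numpy as np

def etkf_analysis(Xb, y, H, R):
    """Textbook ETKF analysis step (a sketch, not dabench's code).

    Xb: (system_dim, n_ens) background ensemble
    y:  (obs_dim,) observations
    H:  (obs_dim, system_dim) linear observation operator
    R:  (obs_dim, obs_dim) observation error covariance
    """
    n_ens = Xb.shape[1]
    xb_mean = Xb.mean(axis=1)
    Xb_pert = Xb - xb_mean[:, None]          # background perturbations
    Yb = H @ Xb_pert                         # perturbations in obs space
    Rinv = np.linalg.inv(R)
    # Analysis error covariance in ensemble (weight) space
    Pa_tilde = np.linalg.inv((n_ens - 1) * np.eye(n_ens) + Yb.T @ Rinv @ Yb)
    # Weights for the mean update
    w_mean = Pa_tilde @ Yb.T @ Rinv @ (y - H @ xb_mean)
    # Symmetric square root for the perturbation update
    evals, evecs = np.linalg.eigh((n_ens - 1) * Pa_tilde)
    W = evecs @ np.diag(np.sqrt(evals)) @ evecs.T
    return xb_mean[:, None] + Xb_pert @ (w_mean[:, None] + W)
```

With a nonlinear observation operator h, the same machinery applies after replacing `H @ Xb_pert` with perturbations of `h` applied to each ensemble member.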

* Remove input type checks

* Jax implementation (ish) of np.real_if_close
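
`jax.numpy` has no `real_if_close`, so a port is needed. A minimal sketch of how one might write it (names and tolerance handling are illustrative): because the output dtype depends on the data, such a function can only be used outside `jit`-compiled code.

```python
import jax.numpy as jnp

def real_if_close(a, tol=100):
    # Sketch of an np.real_if_close equivalent for JAX arrays: drop the
    # imaginary part if every component is within tol machine epsilons
    # of zero. Data-dependent output dtypes are not allowed under jit,
    # so this is only usable in eager (non-jit) code.
    a = jnp.asarray(a)
    if not jnp.iscomplexobj(a):
        return a
    eps = jnp.finfo(a.dtype).eps  # eps of the matching float dtype
    if bool(jnp.all(jnp.abs(a.imag) < tol * eps)):
        return a.real
    return a
```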

* Cleaned up dacycler base class, added documentation

* Cleaned up 3Dvar

* Faster Jax etkf, but requires identical number of observations in each analysis window (WIP)

* Add rest of modules to main __init__

* Rename forecast_model to model_obj in dacycler

* Rename some things in var3d

* ETKF support for non-linear observation operators (h)

* Remove some unneeded debugging printing

* Fix bug in 3dvar R calculation

* Fix test failure, no location_indices specified

* Tests for train/val/test split method

* Store M and times as jax when appropriate

* Basic code for backprop4d, running but not producing accurate results

* Add slicing to obs_vec and state_vector

* Fix time dim calculation for data slicing

* Working bp4d dacycler, but some quirks to it (need to add one to the n_steps you think you might want)

* Fix isclose to rtol=0, stops false equalities with large numbers
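
The false-equality problem is easy to reproduce: `np.isclose`'s default `rtol=1e-5` scales the tolerance by the magnitude of the inputs, so large time values compare "equal" even when far apart. Pinning `rtol=0` makes the comparison purely absolute:

```python
import numpy as np

t = 1.0e9  # e.g. a time coordinate with a large absolute value

# Default rtol=1e-5 lets the tolerance grow to ~1e4 at this magnitude:
assert np.isclose(t, t + 1.0e4)            # surprising "equality"

# With rtol=0, only the absolute tolerance atol applies:
assert not np.isclose(t, t + 1.0e4, rtol=0)
assert np.isclose(t, t + 1.0e-9, rtol=0)   # still true within atol=1e-8
```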

* Fully jaxified backprop4d, MUCH faster (200 10-step cycles in under a minute depending on num_epochs)
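
The idea behind backprop4d can be sketched with a toy model: integrate forward with `lax.scan`, measure the misfit to observations in the window, and let `jax.grad` supply the adjoint sensitivities for gradient descent on the initial condition. Everything below (the decaying dynamics, learning rate, function names) is illustrative, not dabench's Var4DBackprop:

```python
import jax
import jax.numpy as jnp

def forecast(x0, n_steps, dt=0.01):
    # Toy dynamics dx/dt = -x, integrated with forward Euler
    def step(x, _):
        x_next = x + dt * (-x)
        return x_next, x_next
    _, traj = jax.lax.scan(step, x0, None, length=n_steps)
    return traj  # shape (n_steps, system_dim)

def cost(x0, obs, obs_idx, n_steps):
    # Sum of squared misfits at the observation times in the window
    traj = forecast(x0, n_steps)
    return jnp.sum((traj[obs_idx] - obs) ** 2)

def update(x0, obs, obs_idx, n_steps, lr=0.1):
    # One gradient-descent step on the initial condition; jax.grad
    # differentiates back through the whole model integration
    return x0 - lr * jax.grad(cost)(x0, obs, obs_idx, n_steps)
```

Iterating `update` pulls the initial state toward the one that generated the observations, which is the 4D-Var estimate for this quadratic cost.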

* Remove outdated comment

* Edit for clarity when observing gridded values

* Rename backprop4d to var4d_backprop

* Dacycler base class docstrings

* Updated top-level docstrings for da classes, plus init args

* Clean up etkf, make step_cycle and step_forecast non-public methods

* Cleanup var4d backprop

* Update name of var4d backprop in init

* Make step_cycle and step_forecast protected methods

* Make _step_cycle and _step_forecast protected in var3d

* Add reservoir computing model as sub-class of dab.model.Model
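
A generic reservoir computer (echo state network) looks roughly like the sketch below: a fixed random recurrent reservoir, with only the linear readout trained by ridge regression. This is a toy illustration with made-up sizes and names, not dabench's Model subclass:

```python
import numpy as np

class ToyReservoir:
    def __init__(self, input_dim, res_dim=200, spectral_radius=0.9, seed=0):
        rng = np.random.default_rng(seed)
        self.W_in = rng.uniform(-0.5, 0.5, (res_dim, input_dim))
        W = rng.uniform(-0.5, 0.5, (res_dim, res_dim))
        # Scale recurrent weights to the target spectral radius
        W *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W)))
        self.W = W
        self.W_out = None

    def _states(self, U):
        # Drive the reservoir with input sequence U, collect states
        r = np.zeros(self.W.shape[0])
        states = []
        for u in U:
            r = np.tanh(self.W @ r + self.W_in @ u)
            states.append(r)
        return np.array(states)

    def fit(self, U, Y, ridge=1e-6):
        # Ridge-regress the readout weights; only W_out is trained
        R = self._states(U)
        self.W_out = np.linalg.solve(R.T @ R + ridge * np.eye(R.shape[1]),
                                     R.T @ Y).T

    def predict(self, U):
        return self._states(U) @ self.W_out.T
```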

* Use libmamba for workflow conda install

* Use libmamba for all conda installs in workflow, works very fast

* Set libmamba solver in config for workflow

* Only use libmamba for the big install

* Use environment for pytest workflow bc of libmamba problems

* Fix problem with datetimes not working with isclose
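
`np.isclose` cannot handle `datetime64` inputs (its internal `isfinite` call is undefined for them), so datetime comparisons need an explicit timedelta tolerance. A minimal illustration, with a hypothetical helper name and tolerance:

```python
import numpy as np

t1 = np.datetime64("2023-09-29T00:00:00")
t2 = t1 + np.timedelta64(1, "ns")

# np.isclose(t1, t2) raises TypeError for datetime64 inputs, so compare
# with an explicit timedelta tolerance instead:
def times_close(a, b, tol=np.timedelta64(1, "us")):
    return abs(b - a) <= tol
```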

* Create blank environment then populate

* Add full --file option to env update

* Revert to base environment for pytest workflow, add pyqg to environment.yml during workflow

* Install flake8 and pytest earlier in workflow

* Add saving and loading weights to rc model

* Fix for numpy arange floating point arithmetic error
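
The arange issue is the classic one: with a float step, rounding in `(stop - start) / step` can change how many elements fall below the stop value. A sketch of the symptom and one common remedy (not necessarily the commit's exact fix):

```python
import numpy as np

# (1.3 - 1.0) / 0.1 evaluates to slightly more than 3 in binary floating
# point, so arange emits a 4th element at ~1.3 even though stop=1.3 is
# supposed to be excluded:
a = np.arange(1.0, 1.3, 0.1)
print(len(a))  # 4

# Computing the count exactly and using linspace sidesteps the problem:
n = round((1.3 - 1.0) / 0.1)
b = np.linspace(1.0, 1.3, n, endpoint=False)
print(len(b))  # 3
```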

* Fix error in etkf time filter from last commit
kysolvik authored Sep 29, 2023
1 parent 842aa06 commit eeed4e1
Showing 15 changed files with 1,569 additions and 58 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/python-ci-conda.yml
@@ -33,19 +33,19 @@ jobs:

      - name: Update environment
        run: |
-         conda env update --file environment.yml --name base
-         conda install -c conda-forge pyqg
+         conda install conda-libmamba-solver flake8 pytest
+         # Slightly awkward way of adding an optional package to our environment
+         echo " - pyqg" >> environment.yml
+         conda env update --name base --file environment.yml --solver=libmamba
        if: steps.cache.outputs.cache-hit != 'true'

      - name: Lint with flake8
        run: |
-         conda install flake8
          # stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
      - name: Test with pytest
        run: |
-         conda install pytest
          python -m pytest
2 changes: 1 addition & 1 deletion dabench/__init__.py
@@ -1 +1 @@
-from . import data, _suppl_data
+from . import data, vector, model, observer, obsop, dacycler, _suppl_data
11 changes: 11 additions & 0 deletions dabench/dacycler/__init__.py
@@ -0,0 +1,11 @@
from ._dacycler import DACycler
from ._var3d import Var3D
from ._etkf import ETKF
from ._var4d_backprop import Var4DBackprop

__all__ = [
'DACycler',
'Var3D',
'ETKF',
'Var4DBackprop',
]
104 changes: 104 additions & 0 deletions dabench/dacycler/_dacycler.py
@@ -0,0 +1,104 @@
"""Base class for Data Assimilation Cycler object (DACycler)"""

from dabench import vector
import numpy as np


class DACycler():
    """Base class for DACycler object

    Attributes:
        system_dim (int): System dimension.
        delta_t (float): The timestep of the model (assumed uniform).
        model_obj (dabench.Model): Forecast model object.
        in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
            Default is False.
        ensemble (bool): True for ensemble-based data assimilation techniques
            (e.g. ETKF). Default is False.
        B (ndarray): Initial / static background error covariance. Shape:
            (system_dim, system_dim). If not provided, will be calculated
            automatically.
        R (ndarray): Observation error covariance matrix. Shape:
            (obs_dim, obs_dim). If not provided, will be calculated
            automatically.
        H (ndarray): Observation operator with shape (obs_dim, system_dim).
            If not provided, will be calculated automatically.
        h (function): Optional observation operator as a function. More
            flexible (allows for a more complex observation operator).
            Default is None.
    """

    def __init__(self,
                 system_dim=None,
                 delta_t=None,
                 model_obj=None,
                 in_4d=False,
                 ensemble=False,
                 B=None,
                 R=None,
                 H=None,
                 h=None,
                 ):

        self.h = h
        self.H = H
        self.R = R
        self.B = B
        self.in_4d = in_4d
        self.ensemble = ensemble
        self.system_dim = system_dim
        self.delta_t = delta_t
        self.model_obj = model_obj

    def cycle(self,
              input_state,
              start_time,
              obs_vector,
              timesteps,
              analysis_window,
              analysis_time_in_window=None):
        """Perform DA cycle repeatedly, including analysis and forecast

        Args:
            input_state (vector.StateVector): Input state.
            start_time (float or datetime-like): Starting time.
            obs_vector (vector.ObsVector): Observations vector.
            timesteps (int): Number of timesteps, in model time.
            analysis_window (float): Time window from which to gather
                observations for each DA cycle.
            analysis_time_in_window (float): Where within analysis_window
                to perform the analysis. For example, 0.0 is the start of
                the window. Default is None, which selects the middle of
                the window.

        Returns:
            vector.StateVector of analyses and times.
        """

        if analysis_time_in_window is None:
            analysis_time_in_window = analysis_window/2

        # For storing outputs
        all_analyses = []
        all_times = []
        cur_time = start_time + analysis_time_in_window
        cur_state = input_state

        for i in range(timesteps):
            # 1. Filter observations to within analysis_window/2 of cur_time
            obs_vec_timefilt = obs_vector.filter_times(
                cur_time - analysis_window/2, cur_time + analysis_window/2)

            if obs_vec_timefilt.values.shape[0] > 0:
                # 2. Calculate analysis
                analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt)
                # 3. Forecast next timestep
                cur_state = self._step_forecast(analysis)
                # 4. Save outputs
                all_analyses.append(analysis.values)
                all_times.append(cur_time)

            cur_time += self.delta_t

        return vector.StateVector(values=np.stack(all_analyses),
                                  times=np.array(all_times))
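
The cycling loop above can be sketched with plain arrays in place of StateVector/ObsVector, using trivial stand-ins for `_step_cycle` and `_step_forecast` (everything here is illustrative, not dabench's API):

```python
import numpy as np

def toy_cycle(x0, t0, obs_times, obs_vals, n_steps, window, dt,
              analysis_time_in_window=None):
    # Mirror of DACycler.cycle()'s structure with scalar state
    if analysis_time_in_window is None:
        analysis_time_in_window = window / 2
    analyses, times = [], []
    t = t0 + analysis_time_in_window
    x = x0
    for _ in range(n_steps):
        # Gather observations inside the analysis window around t
        mask = (obs_times >= t - window / 2) & (obs_times <= t + window / 2)
        if mask.sum() > 0:
            # "Analysis": nudge the state halfway toward the mean
            # observation (stand-in for _step_cycle); the "forecast"
            # model is the identity (stand-in for _step_forecast)
            x = 0.5 * (x + obs_vals[mask].mean())
            analyses.append(x)
            times.append(t)
        t += dt
    return np.array(analyses), np.array(times)
```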

