Renaming r to x in var4d backprop to make it more generalizable
kysolvik committed Oct 5, 2023
1 parent fb053df commit cdced26
Showing 1 changed file with 21 additions and 30 deletions.

dabench/dacycler/_var4d_backprop.py
@@ -5,7 +5,7 @@
 import numpy as np
 import jax.numpy as jnp
 import jax.scipy as jscipy
-from jax import grad
+from jax import grad, value_and_grad
 import jax
 
 from dabench import dacycler, vector
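
The import change above supports the new epoch update further down: value_and_grad evaluates a function and its gradient in a single pass. A minimal standalone sketch (toy loss, not dabench code):

import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)

# grad returns only the gradient of the scalar loss
dx = jax.grad(loss)(x)                   # [0., 2., 4.]

# value_and_grad returns (loss, gradient) without a second forward pass
val, dx = jax.value_and_grad(loss)(x)    # 5.0, [0., 2., 4.]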
@@ -83,26 +83,26 @@ def _calc_default_R(self, obs_values, obs_error_sd):
     def _calc_default_B(self):
         return jnp.identity(self.system_dim)
 
-    def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, rb0, n_steps):
+    def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, xb0, n_steps):
         """Define loss function based on 4dvar cost"""
         Rinv = jscipy.linalg.inv(R)
         Binv = jscipy.linalg.inv(B)
 
-        def loss_4dvarcost(r0):
+        def loss_4dvarcost(x0):
             # Make prediction based on current x0
-            pred_r = self.step_forecast(
-                vector.StateVector(values=r0, store_as_jax=True),
+            pred_x = self.step_forecast(
+                vector.StateVector(values=x0, store_as_jax=True),
                 n_steps=n_steps).values
 
             # Apply observation operator to map to obs space
-            pred_obs = time_sel_matrix @ pred_r @ H
+            pred_obs = time_sel_matrix @ pred_x @ H
 
             # Calculate observation term of J_0
             resid = pred_obs.ravel() - obs_vals.ravel()
             obs_term = 0.5*np.sum(resid.T @ Rinv @ resid)
 
             # Calculate initial departure term of J_0 based on original x0
-            db0 = pred_r[0].ravel() - rb0.ravel()
+            db0 = pred_x[0].ravel() - xb0.ravel()
             initial_term = 0.5*(db0.T @ Binv @ db0)
 
             # Cost is the sum of the two terms
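
For reference, the quantity assembled in loss_4dvarcost is the standard strong-constraint 4D-Var cost (standard notation for the method, not spelled out in the diff itself):

J(x_0) = \frac{1}{2}\,(x_0 - x_b)^{\top} B^{-1} (x_0 - x_b)
       + \frac{1}{2} \sum_{i \in \text{obs times}} (H x_i - y_i)^{\top} R^{-1} (H x_i - y_i),
\qquad x_{i+1} = M(x_i),

where x_b is the background state (xb0 above), M is the forecast model applied by step_forecast, and B and R are the background- and observation-error covariances. initial_term and obs_term are the two summands; the code computes the observation sum on one stacked (raveled) residual.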
@@ -118,10 +118,10 @@ def _calc_time_sel_matrix(self, obs_steps_inds, n_pred_steps):
 
     def _make_backprop_epoch(self, loss_func):
 
-        def _backprop_epoch(i, r0):
-            dr0 = grad(loss_func, argnums=0)(r0)
-            r0 -= self.learning_rate*dr0
-            return r0
+        def _backprop_epoch(i, x0):
+            loss, dx0 = value_and_grad(loss_func, argnums=0)(x0)
+            x0 -= self.learning_rate*dx0
+            return x0
 
         return _backprop_epoch
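
The epoch function above has the (index, carry) signature that jax.lax.fori_loop requires, which is how _cycle_obsop drives it below; the unused index i is part of that contract. A self-contained sketch of the same pattern on a toy quadratic (illustrative names, not dabench code):

import jax
import jax.numpy as jnp

def toy_loss(x0):
    return jnp.sum((x0 - 3.0) ** 2)

def backprop_epoch(i, x0):
    # Loss value and gradient in one pass; the value is not needed for the
    # update itself but can be surfaced (e.g. with jax.debug.print) for
    # monitoring, since a fori_loop body cannot use an ordinary print.
    loss_val, dx0 = jax.value_and_grad(toy_loss)(x0)
    return x0 - 0.1 * dx0

x0 = jnp.zeros(4)
x_final = jax.lax.fori_loop(0, 100, backprop_epoch, x0)  # approaches [3., 3., 3., 3.]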

@@ -147,32 +147,21 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd,
         else:
             B = self.B
 
-        r0 = xb
+        xb = xb
         loss_func = self._make_loss(obs_values, H, B, R, time_sel_matrix,
-                                    rb0=r0, n_steps=n_steps)
+                                    xb0=xb, n_steps=n_steps)
         backprop_epoch_func = self._make_backprop_epoch(loss_func)
-        r0 = jax.lax.fori_loop(0, self.num_epochs, backprop_epoch_func, r0)
+        xb = jax.lax.fori_loop(0, self.num_epochs, backprop_epoch_func, xb)
 
-        ra = self.step_forecast(
-            vector.StateVector(values=r0, store_as_jax=True),
+        xa = self.step_forecast(
+            vector.StateVector(values=xb, store_as_jax=True),
             n_steps=n_steps)
 
-        return ra, None
+        return xa, None
 
     def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
                    obs_window_indices=[0]):
-        """Perform one step of DA Cycle
-
-        Args:
-            xb:
-            yo:
-            H
-
-        Returns:
-            vector.StateVector containing analysis results
-        """
+        """Perform one step of DA Cycle"""
         time_sel_matrix = self._calc_time_sel_matrix(obs_window_indices,
                                                      n_steps)
         if H is not None or h is None:
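
The body of _calc_time_sel_matrix is not shown in this diff, but the usage pred_obs = time_sel_matrix @ pred_x @ H suggests a 0/1 selection matrix that picks the forecast states at the observation times out of the n_steps trajectory. A hedged guess at an equivalent standalone construction (illustrative only, not the dabench implementation):

import jax.numpy as jnp

def calc_time_sel_matrix(obs_steps_inds, n_pred_steps):
    # One-hot row per observation time: shape (n_obs_times, n_pred_steps)
    return jnp.identity(n_pred_steps)[jnp.array(obs_steps_inds)]

# Observations at steps 0 and 2 of a 4-step forecast:
# calc_time_sel_matrix([0, 2], 4) ->
# [[1., 0., 0., 0.],
#  [0., 0., 1., 0.]]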
@@ -181,9 +170,11 @@ def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
                 H, R, B, time_sel_matrix=time_sel_matrix, n_steps=n_steps)
         else:
             return self._cycle_obsop(
-                xb, yo, h, R, B, time_sel_matrix=time_sel_matrix, n_steps=n_steps)
+                xb, yo, h, R, B, time_sel_matrix=time_sel_matrix,
+                n_steps=n_steps)
 
     def step_forecast(self, xa, n_steps=1):
+        """Perform forecast using model object"""
         if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args:
             return self.model_obj.forecast(xa, n_steps=n_steps)
         else:
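The getfullargspec check in step_forecast lets the cycler work with model classes whose forecast methods may or may not accept an n_steps argument. A standalone sketch of that dispatch (ToyModelA/ToyModelB are hypothetical stand-ins, not dabench classes):

import inspect

class ToyModelA:
    def forecast(self, state, n_steps=1):
        return state  # placeholder dynamics

class ToyModelB:
    def forecast(self, state):
        return state

def run_forecast(model, state, n_steps=1):
    # Pass n_steps only if the model's forecast signature accepts it
    if 'n_steps' in inspect.getfullargspec(model.forecast).args:
        return model.forecast(state, n_steps=n_steps)
    return model.forecast(state)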
