4dBP with jacobian computed wrt obs. Works, but very slow
kysolvik committed Nov 16, 2023
1 parent 4b9c256 commit ffb54c1
Showing 1 changed file with 32 additions and 25 deletions.
57 changes: 32 additions & 25 deletions dabench/dacycler/_var4d_backprop.py
@@ -89,29 +89,24 @@ def _calc_default_R(self, obs_values, obs_error_sd):
     def _calc_default_B(self):
         return jnp.identity(self.system_dim)

-    def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, n_steps):
+    def _make_loss(self, xb0, obs_vals, init_pred_obs, Ht, B, R, time_sel_matrix, M_obs, n_steps):
         """Define loss function based on 4dvar cost"""
         Rinv = jscipy.linalg.inv(R)
         Binv = jscipy.linalg.inv(B)

         @jax.jit
         def loss_4dvarcost(x0):
-            pred_x = self.step_forecast(
-                vector.StateVector(values=x0, store_as_jax=True),
-                n_steps=n_steps).values

-            # Make prediction based on current r
-            xb0 = pred_x[0]
+            # Get initial departure
+            db0 = (x0.ravel() - xb0.ravel())

-            # Apply observation operator to map to obs spcae
-            pred_obs = time_sel_matrix @ pred_x @ H
+            # Get approximate new observations
+            pred_obs = M_obs @ db0 + init_pred_obs

             # Calculate observation term of J_0
             resid = (pred_obs.ravel() - obs_vals.ravel())
             obs_term = 0.5*np.sum(resid.T @ Rinv @ resid)

-            # Calculate initial departure term of J_0 based on original x0
-            db0 = (x0.ravel() - xb0.ravel())
             initial_term = 0.5*(db0.T @ Binv @ db0)

             # Cost is the sum of the two terms
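
Note: the rewritten loss is the incremental form of the 4D-Var cost. With initial departure db0 = x0 - xb0, it evaluates 0.5 * db0.T @ Binv @ db0 + 0.5 * resid.T @ Rinv @ resid, where resid = M_obs @ db0 + init_pred_obs - obs_vals, so the forecast model never runs inside the loss. A minimal standalone sketch of that quadratic cost, with toy shapes and values that are not part of the commit:

import jax.numpy as jnp

# Toy shapes (hypothetical): 3 state variables, 4 scalar observations
system_dim, obs_dim = 3, 4
xb0 = jnp.zeros(system_dim)                  # background initial state
x0 = jnp.array([0.1, -0.2, 0.05])            # current estimate of x0
M_obs = jnp.ones((obs_dim, system_dim))      # Jacobian of predicted obs wrt x0
init_pred_obs = jnp.zeros(obs_dim)           # obs predicted from the background
obs_vals = jnp.array([0.2, 0.1, -0.1, 0.3])  # assimilated observations
Binv = jnp.eye(system_dim)                   # inverse background covariance
Rinv = jnp.eye(obs_dim)                      # inverse observation covariance

db0 = x0 - xb0                               # initial departure
resid = M_obs @ db0 + init_pred_obs - obs_vals
cost = 0.5 * (db0 @ Binv @ db0) + 0.5 * (resid @ Rinv @ resid)
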
@@ -125,19 +120,26 @@ def _calc_time_sel_matrix(self, obs_steps_inds, n_pred_steps):
             jnp.arange(time_sel_matrix.shape[0]), obs_steps_inds].set(1)
         return time_sel_matrix

-    def _make_backprop_epoch(self, optimizer):
+    def _make_backprop_epoch(self, loss_func, optimizer):

-        @jax.jit
+        @ jax.jit
         def _backprop_epoch(x0_opt_state_tuple, i):
-            x0, dx0, i, opt_state = x0_opt_state_tuple
+            x0, i, opt_state = x0_opt_state_tuple
+            loss_val, dx0 = value_and_grad(loss_func, argnums=0)(x0)
             updates, opt_state = optimizer.update(dx0, opt_state)
             x0_new = optax.apply_updates(x0, updates)

-            return (x0_new, dx0, i+1, opt_state), x0_new
+            return (x0_new, i+1, opt_state), loss_val

         return _backprop_epoch

-    def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd,
+    def _gen_forecast_obs(self, x0, Ht, time_sel_matrix):
+        pred_x = self.step_forecast(
+            vector.StateVector(values=x0, store_as_jax=True), 11).values
+        pred_obs = time_sel_matrix @ pred_x @ Ht
+        return pred_obs, pred_obs
+
+    def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd,
                      H=None, h=None, R=None, B=None, time_sel_matrix=None,
                      n_steps=1):
         if H is None and h is None:
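
Note: the reworked _backprop_epoch now evaluates the loss and its gradient inside the scanned step, so jax.lax.scan can drive the entire optimization loop on-device and emit the per-epoch losses. A minimal self-contained sketch of this scan-plus-optax pattern, using a toy quadratic loss and illustrative names rather than the repository's API:

import jax
import jax.numpy as jnp
import optax
from jax import value_and_grad

# Toy stand-in for the 4D-Var cost (hypothetical)
def loss_func(x0):
    return jnp.sum((x0 - 1.0) ** 2)

optimizer = optax.sgd(learning_rate=0.1)
x0 = jnp.zeros(3)
opt_state = optimizer.init(x0)

@jax.jit
def backprop_epoch(carry, _):
    x0, i, opt_state = carry
    loss_val, dx0 = value_and_grad(loss_func)(x0)
    updates, opt_state = optimizer.update(dx0, opt_state)
    x0_new = optax.apply_updates(x0, updates)
    # Carry the updated state forward; emit this epoch's loss
    return (x0_new, i + 1, opt_state), loss_val

(x0, _, opt_state), loss_vals = jax.lax.scan(
    backprop_epoch, init=(x0, 0, opt_state), xs=None, length=50)
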
@@ -148,19 +150,27 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd,
                 h = self.h
             else:
                 H = self.H
+        Ht = H.T

         if R is None:
             if self.R is None:
                 R = self._calc_default_R(obs_values, obs_error_sd)
             else:
                 R = self.R

         if B is None:
             if self.B is None:
                 B = self._calc_default_B()
             else:
                 B = self.B

-        x0 = xb
-        loss_func = self._make_loss(obs_values, H, B, R, time_sel_matrix,
+        # Get initial observations and jacobian
+        M_obs, pred_obs = jax.jacrev(
+            self._gen_forecast_obs, has_aux=True, argnums=0)(
+                x0, Ht, time_sel_matrix)
+
+        loss_func = self._make_loss(x0, obs_values, pred_obs, Ht, B, R,
+                                    time_sel_matrix, M_obs,
                                     n_steps=n_steps)

         lr = optax.exponential_decay(
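
Note: _gen_forecast_obs returns pred_obs twice so that jax.jacrev(..., has_aux=True) differentiates the first copy and passes the second through unchanged, giving the Jacobian M_obs and the baseline prediction pred_obs in a single call. A small sketch of that has_aux idiom on a toy function:

import jax
import jax.numpy as jnp

def f(x):
    y = 2.0 * jnp.sin(x)
    return y, y        # first copy is differentiated, second passes through

jac, y = jax.jacrev(f, has_aux=True)(jnp.array([0.0, 0.5, 1.0]))
# jac has shape (3, 3): d y[i] / d x[j]; y is the undifferentiated aux output
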
@@ -170,21 +180,18 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd,
         opt_state = optimizer.init(x0)

-        # Make initial forecast and calculate loss
-        loss_val, dx0 = value_and_grad(loss_func, argnums=0)(x0)
-        backprop_epoch_func = self._make_backprop_epoch(optimizer)
-        x0_opt_state_tuple, x0_vals = jax.lax.scan(
-            backprop_epoch_func, init=(x0, dx0, 0, opt_state),
+        backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer)
+        x0_opt_state_tuple, loss_vals = jax.lax.scan(
+            backprop_epoch_func, init=(x0, 0, opt_state),
             xs=None, length=self.num_epochs)

-        x0, dx0, i, opt_state = x0_opt_state_tuple
+        x0, i, opt_state = x0_opt_state_tuple

         # Analysis
-        loss_val_end, dx0 = value_and_grad(loss_func, argnums=0)(x0)
         xa = self.step_forecast(
             vector.StateVector(values=x0, store_as_jax=True),
             n_steps=n_steps)

-        return xa, jnp.array([loss_val, loss_val_end])
+        return xa, loss_vals

     def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
                    obs_window_indices=[0]):
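
Note: one plausible reading of the "very slow" remark in the commit message is that reverse-mode jax.jacrev computes one vector-Jacobian product per output row (vmapped over the output basis), so differentiating the forecast-to-observation map costs on the order of one backward pass through the model per observation value, every cycle. When observations outnumber state variables, forward-mode jax.jacfwd (one Jacobian-vector product per input dimension) is the conventional alternative; a toy comparison, not what the commit does:

import jax
import jax.numpy as jnp

def forecast_obs(x0):
    # Toy stand-in: many observations driven by a few state variables
    y = jnp.tanh(x0).sum() * jnp.ones(1000)
    return y, y

x0 = jnp.zeros(3)
M_rev, pred = jax.jacrev(forecast_obs, has_aux=True)(x0)  # ~1000 VJP rows
M_fwd, pred = jax.jacfwd(forecast_obs, has_aux=True)(x0)  # ~3 JVP columns
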
