diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index e65c1b4..5561168 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -89,29 +89,24 @@ def _calc_default_R(self, obs_values, obs_error_sd): def _calc_default_B(self): return jnp.identity(self.system_dim) - def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, n_steps): + def _make_loss(self, xb0, obs_vals, init_pred_obs, Ht, B, R, time_sel_matrix, M_obs, n_steps): """Define loss function based on 4dvar cost""" Rinv = jscipy.linalg.inv(R) Binv = jscipy.linalg.inv(B) @jax.jit def loss_4dvarcost(x0): - pred_x = self.step_forecast( - vector.StateVector(values=x0, store_as_jax=True), - n_steps=n_steps).values - - # Make prediction based on current r - xb0 = pred_x[0] + # Get initial departure + db0 = (x0.ravel() - xb0.ravel()) - # Apply observation operator to map to obs spcae - pred_obs = time_sel_matrix @ pred_x @ H + # Get approximate new observations + pred_obs = M_obs @ db0 + init_pred_obs # Calculate observation term of J_0 resid = (pred_obs.ravel() - obs_vals.ravel()) obs_term = 0.5*np.sum(resid.T @ Rinv @ resid) # Calculate initial departure term of J_0 based on original x0 - db0 = (x0.ravel() - xb0.ravel()) initial_term = 0.5*(db0.T @ Binv @ db0) # Cost is the sum of the two terms @@ -125,19 +120,26 @@ def _calc_time_sel_matrix(self, obs_steps_inds, n_pred_steps): jnp.arange(time_sel_matrix.shape[0]), obs_steps_inds].set(1) return time_sel_matrix - def _make_backprop_epoch(self, optimizer): + def _make_backprop_epoch(self, loss_func, optimizer): - @jax.jit + @ jax.jit def _backprop_epoch(x0_opt_state_tuple, i): - x0, dx0, i, opt_state = x0_opt_state_tuple + x0, i, opt_state = x0_opt_state_tuple + loss_val, dx0 = value_and_grad(loss_func, argnums=0)(x0) updates, opt_state = optimizer.update(dx0, opt_state) x0_new = optax.apply_updates(x0, updates) - return (x0_new, dx0, i+1, opt_state), x0_new + return (x0_new, i+1, opt_state), loss_val return _backprop_epoch - def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd, + def _gen_forecast_obs(self, x0, Ht, time_sel_matrix): + pred_x = self.step_forecast( + vector.StateVector(values=x0, store_as_jax=True), 11).values + pred_obs = time_sel_matrix @ pred_x @ Ht + return pred_obs, pred_obs + + def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, H=None, h=None, R=None, B=None, time_sel_matrix=None, n_steps=1): if H is None and h is None: @@ -148,19 +150,27 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd, h = self.h else: H = self.H + Ht = H.T + if R is None: if self.R is None: R = self._calc_default_R(obs_values, obs_error_sd) else: R = self.R + if B is None: if self.B is None: B = self._calc_default_B() else: B = self.B - x0 = xb - loss_func = self._make_loss(obs_values, H, B, R, time_sel_matrix, + # Get initial observations and jacobian + M_obs, pred_obs = jax.jacrev( + self._gen_forecast_obs, has_aux=True, argnums=0)( + x0, Ht, time_sel_matrix) + + loss_func = self._make_loss(x0, obs_values, pred_obs, Ht, B, R, + time_sel_matrix, M_obs, n_steps=n_steps) lr = optax.exponential_decay( @@ -170,21 +180,18 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd, opt_state = optimizer.init(x0) # Make initial forecast and calculate loss - loss_val, dx0 = value_and_grad(loss_func, argnums=0)(x0) - backprop_epoch_func = self._make_backprop_epoch(optimizer) - x0_opt_state_tuple, x0_vals = jax.lax.scan( - backprop_epoch_func, init=(x0, dx0, 0, opt_state), + backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer) + x0_opt_state_tuple, loss_vals = jax.lax.scan( + backprop_epoch_func, init=(x0, 0, opt_state), xs=None, length=self.num_epochs) - x0, dx0, i, opt_state = x0_opt_state_tuple + x0, i, opt_state = x0_opt_state_tuple - # Analysis - loss_val_end, dx0 = value_and_grad(loss_func, argnums=0)(x0) xa = self.step_forecast( vector.StateVector(values=x0, store_as_jax=True), n_steps=n_steps) - return xa, jnp.array([loss_val, loss_val_end]) + return xa, loss_vals def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1, obs_window_indices=[0]):