diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py
index 3237ca3..6b5b2cb 100644
--- a/dabench/dacycler/_var4d_backprop.py
+++ b/dabench/dacycler/_var4d_backprop.py
@@ -5,8 +5,9 @@
 import numpy as np
 import jax.numpy as jnp
 import jax.scipy as jscipy
-from jax import grad
+from jax import grad, value_and_grad
 import jax
+import optax
 
 from dabench import dacycler, vector
 
@@ -37,6 +38,8 @@ class Var4DBackprop(dacycler.DACycler):
         steps_per_window (int): Number of timesteps per analysis window.
         learning_rate (float): Learning rate for backpropagation. Default is
             1e-5, but DA results can be quite sensitive to this parameter.
+        lr_decay (float): Exponential learning rate decay. If set to 1,
+            no decay is applied. Default is 1.
         obs_window_indices (list): Timestep indices where observations
             fall within each analysis window. For example, if the analysis window is
             0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01,
@@ -53,6 +56,7 @@ def __init__(self,
                  H=None,
                  h=None,
                  learning_rate=1e-5,
+                 lr_decay=1.0,
                  num_epochs=20,
                  steps_per_window=1,
                  obs_window_indices=[0],
@@ -61,6 +65,7 @@ def __init__(self,
 
         self.num_epochs = num_epochs
         self.learning_rate = learning_rate
+        self.lr_decay = lr_decay
         self.steps_per_window = steps_per_window
         self.obs_window_indices = obs_window_indices
 
@@ -83,26 +88,27 @@ def _calc_default_R(self, obs_values, obs_error_sd):
     def _calc_default_B(self):
         return jnp.identity(self.system_dim)
 
-    def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, rb0, n_steps):
+    def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, n_steps):
         """Define loss function based on the 4D-Var cost"""
         Rinv = jscipy.linalg.inv(R)
         Binv = jscipy.linalg.inv(B)
 
-        def loss_4dvarcost(r0):
+        @jax.jit
+        def loss_4dvarcost(x0, xb0):
             # Make prediction based on current state x0
-            pred_r = self.step_forecast(
-                vector.StateVector(values=r0, store_as_jax=True),
+            pred_x = self.step_forecast(
+                vector.StateVector(values=x0, store_as_jax=True),
                 n_steps=n_steps).values
 
             # Apply observation operator to map to obs space
-            pred_obs = time_sel_matrix @ pred_r @ H
+            pred_obs = time_sel_matrix @ pred_x @ H
 
             # Calculate observation term of J_0
-            resid = pred_obs.ravel() - obs_vals.ravel()
+            resid = (pred_obs.ravel() - obs_vals.ravel())
             obs_term = 0.5*jnp.sum(resid.T @ Rinv @ resid)
 
             # Calculate initial departure term of J_0 based on original x0
-            db0 = pred_r[0].ravel() - rb0.ravel()
+            db0 = (x0.ravel() - xb0.ravel())
             initial_term = 0.5*(db0.T @ Binv @ db0)
 
             # Cost is the sum of the two terms
@@ -116,12 +122,17 @@ def _calc_time_sel_matrix(self, obs_steps_inds, n_pred_steps):
             jnp.arange(time_sel_matrix.shape[0]), obs_steps_inds].set(1)
         return time_sel_matrix
 
-    def _make_backprop_epoch(self, loss_func):
+    def _make_backprop_epoch(self, loss_func, optimizer):
 
-        def _backprop_epoch(i, r0):
-            dr0 = grad(loss_func, argnums=0)(r0)
-            r0 -= self.learning_rate*dr0
-            return r0
+        @jax.jit
+        def _backprop_epoch(x0_opt_state_tuple, i):
+            x0, xb0, i, opt_state = x0_opt_state_tuple
+            loss_val, dx0 = value_and_grad(loss_func, argnums=0)(x0, xb0)
+            updates, opt_state = optimizer.update(dx0, opt_state)
+            xb0 = x0
+            x0_new = optax.apply_updates(x0, updates)
+
+            return (x0_new, xb0, i+1, opt_state), loss_val
 
         return _backprop_epoch
 
@@ -147,32 +158,32 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd,
         else:
             B = self.B
 
-        r0 = xb
+        x0 = xb
         loss_func = self._make_loss(obs_values, H, B, R, time_sel_matrix,
-                                    rb0=r0, n_steps=n_steps)
-
-        backprop_epoch_func = self._make_backprop_epoch(loss_func)
-        r0 = jax.lax.fori_loop(0, self.num_epochs, backprop_epoch_func, r0)
-
-        ra = self.step_forecast(
-            vector.StateVector(values=r0, store_as_jax=True),
-            n_steps=n_steps)
+                                    n_steps=n_steps)
 
-        return ra, None
+        lr = optax.exponential_decay(
+            self.learning_rate,
+            self.num_epochs, self.lr_decay)
+        optimizer = optax.sgd(lr)
+        opt_state = optimizer.init(x0)
 
-    def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
-                   obs_window_indices=[0]):
-        """Perform one step of DA Cycle
+        backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer)
+        x0_opt_state_tuple, loss_vals = jax.lax.scan(
+            backprop_epoch_func, init=(x0, x0, 0, opt_state),
+            xs=None, length=self.num_epochs)
 
-        Args:
-            xb:
-            yo:
-            H
+        x0, xb0, i, opt_state = x0_opt_state_tuple
+        xa = self.step_forecast(
+            vector.StateVector(values=x0, store_as_jax=True),
+            n_steps=n_steps)
 
-        Returns:
-            vector.StateVector containing analysis results
+        return xa, loss_vals
 
-        """
+    def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
+                   obs_window_indices=[0]):
+        """Perform one step of the DA cycle"""
         time_sel_matrix = self._calc_time_sel_matrix(obs_window_indices,
                                                      n_steps)
         if H is not None or h is None:
@@ -181,9 +192,11 @@ def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
                 H, R, B, time_sel_matrix=time_sel_matrix, n_steps=n_steps)
         else:
             return self._cycle_obsop(
-                xb, yo, h, R, B, time_sel_matrix=time_sel_matrix, n_steps=n_steps)
+                xb, yo, h, R, B, time_sel_matrix=time_sel_matrix,
+                n_steps=n_steps)
 
     def step_forecast(self, xa, n_steps=1):
+        """Perform forecast using the model object"""
         if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args:
             return self.model_obj.forecast(xa, n_steps=n_steps)
         else:
@@ -197,19 +210,17 @@ def step_forecast(self, xa, n_steps=1):
             out.append(xi)
             return vector.StateVector(jnp.vstack(out), store_as_jax=True)
 
-    def _cycle_and_forecast(self, state_obs_tuple, filtered_idx):
-        cur_state_vals = state_obs_tuple[0]
-        obs_vals = state_obs_tuple[1]
-        obs_times = state_obs_tuple[2]
-        obs_loc_indices = state_obs_tuple[3]
-        obs_error_sd = state_obs_tuple[4]
+    def _cycle_and_forecast(self, cur_state_vals, filtered_idx):
+        obs_vals = self._obs_vector.values
+        obs_loc_indices = self._obs_vector.location_indices
+        obs_error_sd = self._obs_error_sd
         cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals,
                                                     filtered_idx[0],
                                                     len(filtered_idx))
         cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices,
                                                            filtered_idx[0],
                                                            len(filtered_idx))
-        analysis, kh = self.step_cycle(
+        analysis, loss_vals = self.step_cycle(
             vector.StateVector(values=cur_state_vals, store_as_jax=True),
             vector.ObsVector(values=cur_obs_vals,
                              location_indices=cur_obs_loc_indices,
@@ -218,8 +229,7 @@
             n_steps=self.steps_per_window,
             obs_window_indices=self.obs_window_indices)
 
-        return (analysis.values[-1], obs_vals, obs_times, obs_loc_indices,
-                obs_error_sd), analysis.values[:-1]
+        return analysis.values[-1], (analysis.values[:-1], loss_vals)
 
     def cycle(self,
               input_state,
@@ -272,11 +282,16 @@ def cycle(self,
                 rtol=0)
             )[0] for cur_time in all_times])
 
+        self._obs_vector = obs_vector
+        self._obs_error_sd = obs_error_sd
         cur_state, all_values = jax.lax.scan(
             self._cycle_and_forecast,
-            (input_state.values, obs_vector.values, obs_vector.times,
-             obs_vector.location_indices, obs_error_sd),
-            all_filtered_idx)
-
-        return vector.StateVector(values=jnp.vstack(all_values),
-                                  store_as_jax=True)
+            init=input_state.values,
+            xs=all_filtered_idx)
+        all_losses = all_values[1]
+        print(all_losses[:, -3:])
+        all_values = all_values[0]
+
+        return vector.StateVector(
+            values=jnp.vstack(all_values),
+            store_as_jax=True)
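For context, here is a minimal standalone sketch of the optimizer pattern this diff adopts: optax SGD with an optax.exponential_decay schedule, stepped via jax.lax.scan so that per-epoch loss values are collected. The quadratic toy cost, the names loss_fn and epoch, and the constants below are hypothetical stand-ins rather than dabench code; the sketch assumes only that jax and optax are installed.

```python
import jax
import jax.numpy as jnp
import optax

learning_rate, lr_decay, num_epochs = 1e-1, 0.9, 50

def loss_fn(x0, xb0):
    # Hypothetical quadratic stand-in for the 4D-Var cost:
    # observation misfit plus a background-departure penalty.
    return jnp.sum((x0 - 3.0) ** 2) + 0.5 * jnp.sum((x0 - xb0) ** 2)

# Same construction as in _cycle_obsop above: an exponentially
# decaying learning-rate schedule feeding plain SGD.
lr = optax.exponential_decay(learning_rate, num_epochs, lr_decay)
optimizer = optax.sgd(lr)

x0 = jnp.zeros(4)               # initial (background) state
opt_state = optimizer.init(x0)

@jax.jit
def epoch(carry, _):
    # One gradient-descent epoch; mirrors _backprop_epoch, including
    # resetting xb0 to the previous iterate on every step.
    x0, xb0, opt_state = carry
    loss_val, dx0 = jax.value_and_grad(loss_fn)(x0, xb0)
    updates, opt_state = optimizer.update(dx0, opt_state)
    return (optax.apply_updates(x0, updates), x0, opt_state), loss_val

(x0, xb0, opt_state), loss_vals = jax.lax.scan(
    epoch, init=(x0, x0, opt_state), xs=None, length=num_epochs)
print(loss_vals[-3:])           # losses should be decreasing
```

Replacing jax.lax.fori_loop with jax.lax.scan is what exposes the per-epoch loss history: scan stacks the per-step outputs into loss_vals, which _cycle_obsop can now return alongside the analysis.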