Updates to 4dbackprop from experiments
kysolvik committed Nov 7, 2023
1 parent 04afde5 commit c8031be
Showing 1 changed file with 65 additions and 48 deletions.
113 changes: 65 additions & 48 deletions dabench/dacycler/_var4d_backprop.py
@@ -5,8 +5,10 @@
import numpy as np
import jax.numpy as jnp
import jax.scipy as jscipy
from jax import grad
from jax import grad, value_and_grad
from jax.scipy import optimize
import jax
import optax

from dabench import dacycler, vector
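
optax is a new dependency here; the changes below use it through the usual schedule / init / update / apply_updates cycle. A minimal standalone sketch of that pattern (illustrative values only, not part of this file):

import jax.numpy as jnp
import optax

params = jnp.zeros(3)                              # stand-in for an initial state x0
lr_schedule = optax.exponential_decay(
    init_value=1e-3, transition_steps=20, decay_rate=0.9)
optimizer = optax.sgd(lr_schedule)
opt_state = optimizer.init(params)

grads = jnp.ones(3)                                # stand-in for a computed gradient
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)      # one gradient-descent step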

@@ -37,6 +39,8 @@ class Var4DBackprop(dacycler.DACycler):
steps_per_window (int): Number of timesteps per analysis window.
learning_rate (float): Learning rate for backpropagation. Default is 1e-5, but
DA results can be quite sensitive to this parameter.
lr_decay (float): Exponential learning rate decay. If set to 1,
no decay. Default is 1.
obs_window_indices (list): Timestep indices where observations fall
within each analysis window. For example, if analysis window is
0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01,
@@ -53,6 +57,7 @@ def __init__(self,
H=None,
h=None,
learning_rate=1e-5,
lr_decay=1.0,
num_epochs=20,
steps_per_window=1,
obs_window_indices=[0],
@@ -61,6 +66,7 @@

self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.lr_decay = lr_decay
self.steps_per_window = steps_per_window
self.obs_window_indices = obs_window_indices

@@ -83,26 +89,27 @@ def _calc_default_R(self, obs_values, obs_error_sd):
def _calc_default_B(self):
return jnp.identity(self.system_dim)

def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, rb0, n_steps):
def _make_loss(self, obs_vals, H, B, R, time_sel_matrix, n_steps):
"""Define loss function based on 4dvar cost"""
Rinv = jscipy.linalg.inv(R)
Binv = jscipy.linalg.inv(B)

def loss_4dvarcost(r0):
@jax.jit
def loss_4dvarcost(x0, xb0):
# Make prediction based on current initial state
pred_r = self.step_forecast(
vector.StateVector(values=r0, store_as_jax=True),
pred_x = self.step_forecast(
vector.StateVector(values=x0, store_as_jax=True),
n_steps=n_steps).values

# Apply observation operator to map to obs space
pred_obs = time_sel_matrix @ pred_r @ H
pred_obs = time_sel_matrix @ pred_x @ H

# Calculate observation term of J_0
resid = pred_obs.ravel() - obs_vals.ravel()
resid = (pred_obs.ravel() - obs_vals.ravel())
obs_term = 0.5*np.sum(resid.T @ Rinv @ resid)

# Calculate initial departure term of J_0 based on original x0
db0 = pred_r[0].ravel() - rb0.ravel()
db0 = (x0.ravel() - xb0.ravel())
initial_term = 0.5*(db0.T @ Binv @ db0)

# Cost is the sum of the two terms
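
The return statement is collapsed in this view, but the cost assembled above is the standard 4D-Var objective: an observation misfit weighted by R⁻¹ plus an initial-condition departure from the background weighted by B⁻¹. A self-contained sketch of that cost, assuming the collapsed lines simply sum and return the two terms:

import jax.numpy as jnp

def fourdvar_cost(pred_obs, obs_vals, x0, xb0, Rinv, Binv):
    # Observation term: misfit between predicted and actual observations
    resid = pred_obs.ravel() - obs_vals.ravel()
    obs_term = 0.5 * (resid.T @ Rinv @ resid)
    # Background term: departure of x0 from the background state xb0
    db0 = x0.ravel() - xb0.ravel()
    background_term = 0.5 * (db0.T @ Binv @ db0)
    return obs_term + background_term
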
@@ -116,12 +123,18 @@ def _calc_time_sel_matrix(self, obs_steps_inds, n_pred_steps):
jnp.arange(time_sel_matrix.shape[0]), obs_steps_inds].set(1)
return time_sel_matrix

def _make_backprop_epoch(self, loss_func):
def _make_backprop_epoch(self, loss_func, optimizer):

def _backprop_epoch(i, r0):
dr0 = grad(loss_func, argnums=0)(r0)
r0 -= self.learning_rate*dr0
return r0
@jax.jit
def _backprop_epoch(x0_opt_state_tuple, i):
x0, xb0, i, opt_state = x0_opt_state_tuple
# xb0 = jax.lax.cond(i % 5 == 0, lambda: x0, lambda: xb0)
loss_val, dx0 = value_and_grad(loss_func, argnums=0)(x0, xb0)
updates, opt_state = optimizer.update(dx0, opt_state)
xb0 = x0
x0_new = optax.apply_updates(x0, updates)

return (x0_new, xb0, i+1, opt_state), loss_val

return _backprop_epoch
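
The rewritten epoch function follows the (carry, x) -> (carry, y) signature that jax.lax.scan expects, which is what lets the optimization loop further down swap fori_loop for scan while also collecting the per-epoch loss. A small self-contained sketch of that contract with an optax optimizer and a toy quadratic loss (names here are illustrative, not from dabench):

import jax
import jax.numpy as jnp
import optax

def toy_loss(x):
    return jnp.sum(x ** 2)                         # toy quadratic loss

optimizer = optax.sgd(1e-1)

def epoch(carry, _):
    x, opt_state = carry
    loss_val, g = jax.value_and_grad(toy_loss)(x)
    updates, opt_state = optimizer.update(g, opt_state)
    x = optax.apply_updates(x, updates)
    return (x, opt_state), loss_val                # (new carry, per-epoch output)

x0 = jnp.array([1.0, -2.0])
(x_final, _), losses = jax.lax.scan(
    epoch, init=(x0, optimizer.init(x0)), xs=None, length=10)
# losses has shape (10,): one cost value per epoch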

@@ -147,32 +160,32 @@ def _cycle_obsop(self, xb, obs_values, obs_loc_indices, obs_error_sd,
else:
B = self.B

r0 = xb
loss_func = self._make_loss(obs_values, H, B, R, time_sel_matrix,
rb0=r0, n_steps=n_steps)
backprop_epoch_func = self._make_backprop_epoch(loss_func)
r0 = jax.lax.fori_loop(0, self.num_epochs, backprop_epoch_func, r0)

ra = self.step_forecast(
vector.StateVector(values=r0, store_as_jax=True),
n_steps=n_steps)

return ra, None

def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
obs_window_indices=[0]):
"""Perform one step of DA Cycle

Args:
xb:
yo:
H

Returns:
vector.StateVector containing analysis results

"""

x0 = xb
loss_func = self._make_loss(obs_values, H, B, R, time_sel_matrix,
n_steps=n_steps)

lr = optax.exponential_decay(
self.learning_rate,
self.num_epochs, self.lr_decay)
optimizer = optax.sgd(lr)
opt_state = optimizer.init(x0)

backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer)
x0_opt_state_tuple, loss_vals = jax.lax.scan(
backprop_epoch_func, init=(x0, x0, 0, opt_state),
xs=None, length=self.num_epochs)

x0, xb0, i, opt_state = x0_opt_state_tuple

xa = self.step_forecast(
vector.StateVector(values=x0, store_as_jax=True),
n_steps=n_steps)

return xa, loss_vals

def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
obs_window_indices=[0]):
"""Perform one step of DA Cycle"""
time_sel_matrix = self._calc_time_sel_matrix(obs_window_indices,
n_steps)
if H is not None or h is None:
@@ -181,9 +194,11 @@ def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1,
H, R, B, time_sel_matrix=time_sel_matrix, n_steps=n_steps)
else:
return self._cycle_obsop(
xb, yo, h, R, B, time_sel_matrix=time_sel_matrix, n_steps=n_steps)
xb, yo, h, R, B, time_sel_matrix=time_sel_matrix,
n_steps=n_steps)

def step_forecast(self, xa, n_steps=1):
"""Perform forecast using model object"""
if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args:
return self.model_obj.forecast(xa, n_steps=n_steps)
else:
@@ -197,19 +212,17 @@ def step_forecast(self, xa, n_steps=1):
out.append(xi)
return vector.StateVector(jnp.vstack(out), store_as_jax=True)
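
The branch above uses introspection to decide whether the wrapped model's forecast accepts an n_steps argument, falling back to stepping one timestep at a time otherwise. A minimal sketch of that dispatch, using a hypothetical model class:

import inspect

class ToyModel:                                    # hypothetical stand-in for model_obj
    def forecast(self, x, n_steps=1):
        return x + n_steps

model = ToyModel()
if 'n_steps' in inspect.getfullargspec(model.forecast).args:
    result = model.forecast(0.0, n_steps=3)        # model steps internally
else:
    result = 0.0
    for _ in range(3):                             # otherwise advance one step at a time
        result = model.forecast(result)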

def _cycle_and_forecast(self, state_obs_tuple, filtered_idx):
cur_state_vals = state_obs_tuple[0]
obs_vals = state_obs_tuple[1]
obs_times = state_obs_tuple[2]
obs_loc_indices = state_obs_tuple[3]
obs_error_sd = state_obs_tuple[4]
def _cycle_and_forecast(self, cur_state_vals, filtered_idx):
obs_vals = self._obs_vector.values
obs_loc_indices = self._obs_vector.location_indices
obs_error_sd = self._obs_error_sd

cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0],
len(filtered_idx))
cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices,
filtered_idx[0],
len(filtered_idx))
analysis, kh = self.step_cycle(
analysis, loss_vals = self.step_cycle(
vector.StateVector(values=cur_state_vals, store_as_jax=True),
vector.ObsVector(values=cur_obs_vals,
location_indices=cur_obs_loc_indices,
@@ -218,8 +231,7 @@ def _cycle_and_forecast(self, state_obs_tuple, filtered_idx):
n_steps=self.steps_per_window,
obs_window_indices=self.obs_window_indices)

return (analysis.values[-1], obs_vals, obs_times, obs_loc_indices,
obs_error_sd), analysis.values[:-1]
return analysis.values[-1], (analysis.values[:-1], loss_vals)
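
Because _cycle_and_forecast now returns (carry, outputs) with outputs = (window values, loss values), jax.lax.scan in cycle() will stack each element of that outputs tuple across analysis windows along a new leading axis. A minimal sketch of that stacking behavior, independent of this class:

import jax
import jax.numpy as jnp

def window_step(carry, idx):
    new_carry = carry + 1.0
    outputs = (jnp.full(3, carry), jnp.full(2, idx))   # e.g. (window states, losses)
    return new_carry, outputs

final_carry, (all_states, all_losses) = jax.lax.scan(
    window_step, init=0.0, xs=jnp.arange(4.0))
# all_states.shape == (4, 3) and all_losses.shape == (4, 2): one row per window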

def cycle(self,
input_state,
@@ -272,11 +284,16 @@ def cycle(self,
rtol=0)
)[0] for cur_time in all_times])

cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
(input_state.values, obs_vector.values, obs_vector.times,
obs_vector.location_indices, obs_error_sd),
all_filtered_idx)

return vector.StateVector(values=jnp.vstack(all_values),
store_as_jax=True)

self._obs_vector = obs_vector
self._obs_error_sd = obs_error_sd
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
init=input_state.values,
xs=all_filtered_idx)
all_losses = all_values[1]
print(all_losses[:, -3:])
all_values = all_values[0]

return vector.StateVector(
values=jnp.vstack(all_values),
store_as_jax=True)
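
With this change, all_values is a pair: its first element stacks each window's analysis trajectory and its second stacks the per-epoch costs from every window, so all_losses should have shape (n_windows, num_epochs). The print call above then reports the final three epochs' costs for each window as a quick convergence check.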
