Skip to content

Commit

Permalink
Use .at[].get() instead of dynamic slice to fix bug when padded indic…
Browse files Browse the repository at this point in the history
…es have len greater than available obs values
  • Loading branch information
kysolvik committed Jul 16, 2024
1 parent ccab1fd commit 544eeff
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
9 changes: 4 additions & 5 deletions dabench/dacycler/_var4d.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,9 @@ def _cycle_and_forecast(self, cur_state_vals, filtered_idx):
obs_mask = filtered_idx > 0
filtered_idx = filtered_idx - 1

cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0],
len(filtered_idx))
cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices,
filtered_idx[0],
len(filtered_idx))
cur_obs_vals = jnp.array(obs_vals).at[filtered_idx].get()
cur_obs_loc_indices = jnp.array(obs_loc_indices).at[filtered_idx].get()

analysis, kh = self.step_cycle(
vector.StateVector(values=cur_state_vals, store_as_jax=True),
vector.ObsVector(values=cur_obs_vals,
Expand Down Expand Up @@ -353,6 +351,7 @@ def cycle(self,

self._obs_vector = obs_vector
self._obs_error_sd = obs_error_sd

cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
init=input_state.values,
Expand Down
9 changes: 4 additions & 5 deletions dabench/dacycler/_var4d_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,9 @@ def _cycle_and_forecast(self, cur_state_vals, filtered_idx):
obs_mask = filtered_idx > 0
filtered_idx = filtered_idx - 1

cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0],
len(filtered_idx))
cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices,
filtered_idx[0],
len(filtered_idx))
cur_obs_vals = jnp.array(obs_vals).at[filtered_idx].get()
cur_obs_loc_indices = jnp.array(obs_loc_indices).at[filtered_idx].get()

analysis, loss_vals = self.step_cycle(
vector.StateVector(values=cur_state_vals, store_as_jax=True),
vector.ObsVector(values=cur_obs_vals,
Expand Down Expand Up @@ -329,6 +327,7 @@ def cycle(self,

self._obs_vector = obs_vector
self._obs_error_sd = obs_error_sd

cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
init=input_state.values,
Expand Down

0 comments on commit 544eeff

Please sign in to comment.