Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 465426495
  • Loading branch information
learned_optimization authors committed Aug 5, 2022
1 parent 4447601 commit 5a95af5
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions learned_optimization/circular_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Generic, Tuple, TypeVar

import jax
from jax import tree_util
import jax.numpy as jnp

CircularBufferState = collections.namedtuple("CircularBufferState",
Expand Down Expand Up @@ -54,7 +55,7 @@ def build_one(x):
tiled = jnp.tile(expanded, [self.size] + [1] * len(x.shape))
return jnp.asarray(tiled, dtype=x.dtype)

empty_buffer = jax.tree_map(build_one, self.abstract_value)
empty_buffer = tree_util.tree_map(build_one, self.abstract_value)
return CircularBufferState(
idx=jnp.asarray(0, jnp.int64),
values=(empty_buffer,
Expand All @@ -71,7 +72,8 @@ def do_update(src, to_set):
else:
return src.at[idx, :].set(to_set)

new_jax_array = jax.tree_map(do_update, state.values, (value, state.idx))
new_jax_array = tree_util.tree_map(do_update, state.values,
(value, state.idx))
return CircularBufferState(idx=state.idx + 1, values=new_jax_array)

def _reorder(self, vals, idx):
Expand Down Expand Up @@ -100,13 +102,13 @@ def stack_reorder(self, state: CircularBufferState) -> Tuple[T, jnp.ndarray]: #
candidate = jnp.clip((state.values[1] - state.idx + self.size), -1,
self.size)
mask = self._reorder(jnp.where(candidate == -1, 0, 1), state.idx)
return jax.tree_map(lambda x: self._reorder(x, state.idx),
state.values[0]), mask
return tree_util.tree_map(lambda x: self._reorder(x, state.idx),
state.values[0]), mask

@functools.partial(jax.jit, static_argnums=(0,))
def gather_from_present(
self, state: CircularBufferState, idxs: jnp.ndarray) -> T: # pytype: disable=invalid-annotation
"""Get the values from for each idx in the past."""
offset = (idxs % self.size)
idx = (state.idx + offset) % self.size
return jax.tree_map(lambda x: x[idx], state.values[0])
return tree_util.tree_map(lambda x: x[idx], state.values[0])

0 comments on commit 5a95af5

Please sign in to comment.