diff --git a/autodiscjax/modules/misc.py b/autodiscjax/modules/misc.py index c8b419d..544ee9a 100644 --- a/autodiscjax/modules/misc.py +++ b/autodiscjax/modules/misc.py @@ -12,12 +12,14 @@ def __init__(self, out_treedef, out_shape, out_dtype, low=None, high=None): super().__init__(out_treedef, out_shape, out_dtype) if isinstance(low, float): - self.low = self.out_treedef.unflatten([low]*self.out_treedef.num_leaves) + self.low = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype), + low, self.out_shape, self.out_dtype) else: self.low = low if isinstance(high, float): - self.high = self.out_treedef.unflatten([high]*self.out_treedef.num_leaves) + self.high = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype), + high, self.out_shape, self.out_dtype) else: self.high = high