From 26978a345cc15067185237632604af83880a44a2 Mon Sep 17 00:00:00 2001 From: Mayalen Etcheverry Date: Wed, 22 Feb 2023 10:35:03 +0100 Subject: [PATCH] minor fix in ClampModule to deal with low, high when given as floats --- autodiscjax/modules/misc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/autodiscjax/modules/misc.py b/autodiscjax/modules/misc.py index c8b419d..544ee9a 100644 --- a/autodiscjax/modules/misc.py +++ b/autodiscjax/modules/misc.py @@ -12,12 +12,14 @@ def __init__(self, out_treedef, out_shape, out_dtype, low=None, high=None): super().__init__(out_treedef, out_shape, out_dtype) if isinstance(low, float): - self.low = self.out_treedef.unflatten([low]*self.out_treedef.num_leaves) + self.low = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype), + low, self.out_shape, self.out_dtype) else: self.low = low if isinstance(high, float): - self.high = self.out_treedef.unflatten([high]*self.out_treedef.num_leaves) + self.high = jtu.tree_map(lambda val, shape, dtype: val * jnp.ones(shape=shape, dtype=dtype), + high, self.out_shape, self.out_dtype) else: self.high = high