From f207403d664126fb80499c2c0ae50cd6ba5af491 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Sat, 25 Nov 2023 03:45:23 +0900 Subject: [PATCH] Update custom_transform.py --- serket/_src/custom_transform.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/serket/_src/custom_transform.py b/serket/_src/custom_transform.py index 0f9e971..d3915d5 100644 --- a/serket/_src/custom_transform.py +++ b/serket/_src/custom_transform.py @@ -89,6 +89,15 @@ def tree_state(tree: T, **kwargs) -> T: running_var=f32[5](μ=1.00, σ=0.00, ∈[1.00,1.00]) )] """ + # tree_state handles state initialization for different layers + # like RNN cells, BatchNorm, KMeans, etc. + # one challenge is that the state initialization rule for a layer + # may depend only on the layer itself, or may depend on the layer + # and the input. For example, the state initialization rule for + # BatchNorm depends on the layer and sample input, but the state initialization + # rule for some RNN cells (e.g. LSTM) does not depend on the input. + # This poses a challenge for the user to pass the correct input + # to the state initialization rule. types = tuple(set(tree_state.state_dispatcher.registry) - {object}) @@ -96,7 +105,28 @@ def is_leaf(x: Any) -> bool: return isinstance(x, types) def dispatch_func(leaf): - return tree_state.state_dispatcher(leaf, **kwargs) + try: + return tree_state.state_dispatcher(leaf, **kwargs) + # handle error from wrong signature + except TypeError as e: + # check if the leaf has a state rule + mro_sans_object = type(leaf).__mro__[:-1] + registry = tree_state.state_dispatcher.registry + + for mro in mro_sans_object: + if mro in registry: + func = tree_state.state_dispatcher.registry[mro] + break + else: + # not registered to `tree_state` + raise type(e)(e) + + # maybe wrong signature + raise type(e)( + f"For {type(leaf)=} with the registered state rule signature {sk.tree_str(func)}, " + f"The following error occurred:\n{e}\n" + f"Pass the correct inputs to the registered state rule as arguments to `tree_state`." + ) return jax.tree_map(dispatch_func, tree, is_leaf=is_leaf)