Skip to content

Commit

Permalink
Update custom_transform.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 24, 2023
1 parent ecd3bcb commit f207403
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion serket/_src/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,44 @@ def tree_state(tree: T, **kwargs) -> T:
running_var=f32[5](μ=1.00, σ=0.00, ∈[1.00,1.00])
)]
"""
# tree_state handles state initialization for different layers
# like RNN cells, BatchNorm, KMeans, etc.
# one challenge is that the state initialization rule for a layer
# may depend only on the layer itself, or may depend on the layer
# and the input. For example, the state initialization rule for
# BatchNorm depends on the layer and sample input, but the state initialization
# rule for some RNN cells (e.g. LSTM) does not depend on the input.
# This poses a challenge for the user to pass the correct input
# to the state initialization rule.

types = tuple(set(tree_state.state_dispatcher.registry) - {object})

def is_leaf(x: Any) -> bool:
return isinstance(x, types)

def dispatch_func(leaf):
return tree_state.state_dispatcher(leaf, **kwargs)
try:
return tree_state.state_dispatcher(leaf, **kwargs)
# handle error from wrong signature
except TypeError as e:

Check warning on line 111 in serket/_src/custom_transform.py

View check run for this annotation

Codecov / codecov/patch

serket/_src/custom_transform.py#L111

Added line #L111 was not covered by tests
# check if the leaf has a state rule
mro_sans_object = type(leaf).__mro__[:-1]
registry = tree_state.state_dispatcher.registry

Check warning on line 114 in serket/_src/custom_transform.py

View check run for this annotation

Codecov / codecov/patch

serket/_src/custom_transform.py#L113-L114

Added lines #L113 - L114 were not covered by tests

for mro in mro_sans_object:
if mro in registry:
func = tree_state.state_dispatcher.registry[mro]
break

Check warning on line 119 in serket/_src/custom_transform.py

View check run for this annotation

Codecov / codecov/patch

serket/_src/custom_transform.py#L116-L119

Added lines #L116 - L119 were not covered by tests
else:
# not registered to `tree_state`
raise type(e)(e)

Check warning on line 122 in serket/_src/custom_transform.py

View check run for this annotation

Codecov / codecov/patch

serket/_src/custom_transform.py#L122

Added line #L122 was not covered by tests

# maybe wrong signature
raise type(e)(

Check warning on line 125 in serket/_src/custom_transform.py

View check run for this annotation

Codecov / codecov/patch

serket/_src/custom_transform.py#L125

Added line #L125 was not covered by tests
f"For {type(leaf)=} with the registered state rule signature {sk.tree_str(func)}, "
f"The following error occurred:\n{e}\n"
f"Pass the correct inputs to the registered state rule as arguments to `tree_state`."
)

return jax.tree_map(dispatch_func, tree, is_leaf=is_leaf)

Expand Down

0 comments on commit f207403

Please sign in to comment.