Skip to content

Commit

Permalink
remove def_init_entry
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 31, 2024
1 parent dbd0735 commit ff3b193
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 72 deletions.
1 change: 0 additions & 1 deletion docs/API/common.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

.. autofunction:: tree_state
.. autofunction:: tree_eval
.. autofunction:: def_init_entry
.. autofunction:: def_act_entry

.. autoclass:: Sequential
Expand Down
2 changes: 0 additions & 2 deletions serket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from serket._src.containers import RandomChoice, Sequential
from serket._src.custom_transform import tree_eval, tree_state
from serket._src.nn.activation import def_act_entry
from serket._src.nn.initialization import def_init_entry

from . import cluster, image, nn

Expand Down Expand Up @@ -66,7 +65,6 @@
"image",
"tree_eval",
"tree_state",
"def_init_entry",
"def_act_entry",
# containers
"Sequential",
Expand Down
37 changes: 0 additions & 37 deletions serket/_src/nn/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,40 +85,3 @@ def _(init: None):
@resolve_init.def_type(ABCCallable)
def _(init: Callable):
return jtu.Partial(init)


def def_init_entry(key: str, init_func: InitFuncType) -> None:
"""Register a custom initialization function key for use in ``serket`` layers.
Args:
key: The key to register the function under.
init_func: The function to register. must take three arguments: a key,
a shape, and a dtype, and return an array of the given shape and dtype.
Note:
By design initialization function can be passed directly to ``serket`` layers
without registration. This function is useful if you want to
represent initialization functions as a string in a configuration file.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import serket as sk
>>> import jax.random as jr
>>> import math
>>> def my_init_func(key, shape, dtype=jnp.float32):
... return jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
>>> sk.def_init_entry("my_init", my_init_func)
"""
import inspect

signature = inspect.signature(init_func)

if key in init_map:
raise ValueError(f"`init_key` {key=} already registered")

if len(signature.parameters) != 3:
# verify its a three argument function
raise ValueError(f"`init_func` {len(signature.parameters)=} != 3")

init_map[key] = init_func
32 changes: 0 additions & 32 deletions tests/test_init.py

This file was deleted.

0 comments on commit ff3b193

Please sign in to comment.