Commit fa87ea7
docstring nits
ASEM000 committed Nov 24, 2023
1 parent 723ad2e commit fa87ea7
Showing 5 changed files with 165 additions and 173 deletions.
serket/_src/nn/convolution.py (2 changes: 1 addition & 1 deletion)
@@ -3775,4 +3775,4 @@ class Conv3DLocal(ConvNDLocal):
        - https://github.com/google/flax/blob/main/flax/linen/linear.py
    """

-   spatial_ndim: int = 3
+   spatial_ndim: int = 3
serket/_src/nn/dropout.py (65 changes: 22 additions & 43 deletions)
@@ -124,8 +124,9 @@ class Dropout(sk.TreeClass):
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> layer = sk.nn.Dropout(0.5)
-       >>> print(layer(jnp.ones([10]), key=jr.PRNGKey(0)))
-       [2. 0. 2. 2. 2. 2. 2. 2. 0. 0.]
+       >>> input = jnp.ones(10)
+       >>> key = jr.PRNGKey(0)
+       >>> output = layer(input, key=key)

    Note:
        Use :func:`.tree_eval` to turn off dropout during evaluation by converting
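The note above refers to serket's evaluation transform; a minimal sketch of the intended usage, assuming `sk.tree_eval` swaps `Dropout` for an identity-like layer that no longer needs a key (the key-free call below is part of that assumption):

    import serket as sk
    import jax.numpy as jnp
    import jax.random as jr

    layer = sk.nn.Dropout(0.5)
    input = jnp.ones(10)
    # training: a key is required; kept values are scaled by 1/(1 - drop_rate)
    train_output = layer(input, key=jr.PRNGKey(0))
    # evaluation: tree_eval converts dropout layers to no-ops, no key needed
    eval_layer = sk.tree_eval(layer)
    eval_output = eval_layer(input)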
@@ -193,8 +194,9 @@ class Dropout1D(DropoutND):
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> layer = sk.nn.Dropout1D(0.5)
-       >>> print(layer(jnp.ones((1, 10)), key=jr.PRNGKey(0)))
-       [[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
+       >>> input = jnp.ones((1, 10))
+       >>> key = jr.PRNGKey(0)
+       >>> output = layer(input, key=key)

    Note:
        Use :func:`.tree_eval` to turn off dropout during evaluation by converting
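The doctest output removed above was a full row of identical 2.0 values, which is consistent with channel-wise (spatial) dropout; a hedged sketch of that reading — the per-channel masking here is an assumption inferred from those outputs, not something the diff states:

    import serket as sk
    import jax.numpy as jnp
    import jax.random as jr

    layer = sk.nn.Dropout1D(0.5)
    # assumption: the mask is drawn per channel, so each of the 4 channels
    # comes out either all zeros or all 2.0 (scaled by 1/(1 - drop_rate))
    output = layer(jnp.ones((4, 10)), key=jr.PRNGKey(0))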
@@ -228,12 +230,9 @@ class Dropout2D(DropoutND):
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> layer = sk.nn.Dropout2D(0.5)
-       >>> print(layer(jnp.ones((1, 5, 5)), key=jr.PRNGKey(0)))
-       [[[2. 2. 2. 2. 2.]
-         [2. 2. 2. 2. 2.]
-         [2. 2. 2. 2. 2.]
-         [2. 2. 2. 2. 2.]
-         [2. 2. 2. 2. 2.]]]
+       >>> input = jnp.ones((1, 5, 5))
+       >>> key = jr.PRNGKey(0)
+       >>> output = layer(input, key=key)

    Note:
        Use :func:`.tree_eval` to turn off dropout during evaluation by converting
@@ -267,12 +266,9 @@ class Dropout3D(DropoutND):
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> layer = sk.nn.Dropout3D(0.5)
-       >>> print(layer(jnp.ones((1, 2, 2, 2)), key=jr.PRNGKey(0)))  # doctest: +NORMALIZE_WHITESPACE
-       [[[[2. 2.]
-          [2. 2.]]
-       <BLANKLINE>
-         [[2. 2.]
-          [2. 2.]]]]
+       >>> input = jnp.ones((1, 2, 2, 2))
+       >>> key = jr.PRNGKey(0)
+       >>> output = layer(input, key=key)

    Note:
        Use :func:`.tree_eval` to turn off dropout during evaluation by converting
@@ -342,8 +338,10 @@ class RandomCutout1D(RandomCutoutND):
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> import jax.random as jr
-       >>> print(sk.nn.RandomCutout1D(5)(jnp.ones((1, 10)) * 100, key=jr.PRNGKey(0)))
-       [[100. 100. 100. 100. 100.   0.   0.   0.   0.   0.]]
+       >>> layer = sk.nn.RandomCutout1D(5)
+       >>> input = jnp.ones((1, 10)) * 100
+       >>> key = jr.PRNGKey(0)
+       >>> output = layer(input, key=key)

    Reference:
        - https://arxiv.org/abs/1708.04552
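The removed doctest output showed exactly five zeroed positions, i.e. one cutout window of length 5 filled with 0; a small sketch that asserts the same property (the default `fill_value` of 0 is inferred from that removed output):

    import serket as sk
    import jax.numpy as jnp
    import jax.random as jr

    layer = sk.nn.RandomCutout1D(5)  # one cutout window of length 5
    output = layer(jnp.ones((1, 10)) * 100, key=jr.PRNGKey(0))
    # exactly one length-5 window is filled with zeros
    assert int((output == 0).sum()) == 5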
@@ -371,19 +369,10 @@ class RandomCutout2D(RandomCutoutND):
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
-       >>> x = jnp.arange(1,101).reshape(1, 10, 10)
+       >>> layer = sk.nn.RandomCutout2D(shape=(3,2), cutout_count=2, fill_value=0)
+       >>> input = jnp.arange(1,101).reshape(1, 10, 10)
        >>> key = jr.PRNGKey(0)
-       >>> print(sk.nn.RandomCutout2D(shape=(3,2), cutout_count=2, fill_value=0)(x,key=key))
-       [[[  1   2   3   4   5   6   7   8   9  10]
-         [ 11  12  13  14  15  16  17  18  19  20]
-         [ 21  22  23  24  25  26  27  28  29  30]
-         [ 31  32  33  34   0   0  37  38  39  40]
-         [ 41  42  43  44   0   0  47  48  49  50]
-         [ 51  52  53  54   0   0  57  58  59  60]
-         [ 61  62   0   0  65  66  67  68  69  70]
-         [ 71  72   0   0  75  76  77  78  79  80]
-         [ 81  82   0   0  85  86  87  88  89  90]
-         [ 91  92  93  94  95  96  97  98  99 100]]]
+       >>> output = layer(input, key=key)

    Reference:
        - https://arxiv.org/abs/1708.04552
@@ -409,20 +398,10 @@ class RandomCutout3D(RandomCutoutND):
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
-       >>> x = jnp.arange(1, 2 * 5 * 5 + 1).reshape(1, 2, 5, 5)
+       >>> layer = sk.nn.RandomCutout3D(shape=(2, 2, 2), cutout_count=2, fill_value=0)
+       >>> input = jnp.arange(1, 2 * 5 * 5 + 1).reshape(1, 2, 5, 5)
        >>> key = jr.PRNGKey(0)
-       >>> print(sk.nn.RandomCutout3D(shape=(2, 2, 2), cutout_count=2, fill_value=0)(x, key=key))
-       [[[[ 1  2  0  0  5]
-          [ 6  7  0  0 10]
-          [ 0  0 13 14 15]
-          [ 0  0 18 19 20]
-          [21 22 23 24 25]]
-       <BLANKLINE>
-         [[26 27  0  0 30]
-          [31 32  0  0 35]
-          [ 0  0 38 39 40]
-          [ 0  0 43 44 45]
-          [46 47 48 49 50]]]]
+       >>> output = layer(input, key=key)

    Reference:
        - https://arxiv.org/abs/1708.04552
serket/_src/nn/linear.py (95 changes: 52 additions & 43 deletions)
@@ -129,10 +129,11 @@ class Linear(sk.TreeClass):
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import serket as sk
-       >>> array_1 = jnp.ones((1, 5))
-       >>> array_2 = jnp.ones((1, 6))
-       >>> layer = sk.nn.Linear((5,6), 7, key=jr.PRNGKey(0))
-       >>> layer(array_1, array_2).shape
+       >>> input_1 = jnp.ones((1, 5))  # 5 features
+       >>> input_2 = jnp.ones((1, 6))  # 6 features
+       >>> key = jr.PRNGKey(0)
+       >>> layer = sk.nn.Linear((5, 6), 7, key=key)
+       >>> layer(input_1, input_2).shape
        (1, 7)

    Note:
@@ -148,18 +149,19 @@ class Linear(sk.TreeClass):
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import jax
-       >>> k1, k2 = jr.split(jr.PRNGKey(0))
-       >>> @sk.autoinit
-       ... class Linears(sk.TreeClass):
-       ...     l1: sk.nn.Linear = sk.nn.Linear(None, 32, key=k1)
-       ...     l2: sk.nn.Linear = sk.nn.Linear((32,), 10, key=k2)
+       >>> class Linears(sk.TreeClass):
+       ...     def __init__(self, *, key: jax.Array):
+       ...         k1, k2 = jr.split(key)
+       ...         self.l1 = sk.nn.Linear(None, 32, key=k1)
+       ...         self.l2 = sk.nn.Linear(32, 10, key=k2)
        ...     def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
        ...         return self.l2(jax.nn.relu(self.l1(x, y)))
-       >>> lazy_linears = Linears()
-       >>> x = jnp.ones([100, 28])
-       >>> y = jnp.ones([100, 56])
-       >>> _, material_linears = lazy_linears.at["__call__"](x, y)
-       >>> material_linears.l1.in_features
+       >>> key = jr.PRNGKey(0)
+       >>> lazy_layer = Linears(key=key)
+       >>> input_1 = jnp.ones([100, 28])
+       >>> input_2 = jnp.ones([100, 56])
+       >>> _, material_layer = lazy_layer.at["__call__"](input_1, input_2)
+       >>> material_layer.l1.in_features
        (28, 56)

    Note:
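A possible continuation of the example above (inferred, not quoted from the diff): once materialized, the layer is called like any other, and the output shape follows from `l2`'s 10 output features:

    >>> output = material_layer(input_1, input_2)
    >>> output.shape
    (100, 10)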
@@ -174,8 +176,8 @@ def __init__(
        self,
        in_features: int | tuple[int, ...] | None,
        out_features: int,
-       key: jax.Array,
        *,
+       key: jax.Array,
        weight_init: InitType = "glorot_uniform",
        bias_init: InitType = "zeros",
        dtype: DType = jnp.float32,
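The only change in this hunk is moving `key` behind the bare `*`, which makes it keyword-only; a sketch of the call-site effect:

    import serket as sk
    import jax.random as jr

    layer = sk.nn.Linear(5, 7, key=jr.PRNGKey(0))  # ok: key passed by keyword
    # sk.nn.Linear(5, 7, jr.PRNGKey(0))  # TypeError: key can no longer be positional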
@@ -318,10 +320,13 @@ class Embedding(sk.TreeClass):
        >>> import serket as sk
        >>> import jax.random as jr
        >>> # 10 words in the vocabulary, each word is represented by a 3 dimensional vector
-       >>> table = sk.nn.Embedding(10, 3, key=jr.PRNGKey(0))
+       >>> key = jr.PRNGKey(0)
+       >>> table = sk.nn.Embedding(10, 3, key=key)
        >>> # take the last word in the vocab
-       >>> table(jnp.array([9]))
-       Array([[0.43810904, 0.35078037, 0.13254273]], dtype=float32)
+       >>> input = jnp.array([9])
+       >>> output = table(input)
+       >>> output.shape
+       (1, 3)
    """

    def __init__(self, in_features: int, out_features: int, key: jax.Array):
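A short extension of the doctest above (inferred from standard embedding semantics rather than quoted from the source): each integer index selects one row of the (10, 3) table, so a batch of indices returns one row per index:

    >>> batch = jnp.array([0, 3, 9])
    >>> table(batch).shape
    (3, 3)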
@@ -360,8 +365,10 @@ class FNN(sk.TreeClass):
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> import jax.random as jr
-       >>> fnn = sk.nn.FNN([10, 5, 2], key=jr.PRNGKey(0))
-       >>> fnn(jnp.ones((3, 10))).shape
+       >>> key = jr.PRNGKey(0)
+       >>> layer = sk.nn.FNN([10, 5, 2], key=key)
+       >>> input = jnp.ones((3, 10))
+       >>> layer(input).shape
        (3, 2)

    Note:
@@ -383,9 +390,10 @@ class FNN(sk.TreeClass):
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
-       >>> lazy_fnn = sk.nn.FNN([None, 10, 2, 1], key=jr.PRNGKey(0))
-       >>> _, material_fnn = lazy_fnn.at['__call__'](jnp.ones([1, 10]))
-       >>> material_fnn.linear_0.in_features
+       >>> key = jr.PRNGKey(0)
+       >>> lazy_layer = sk.nn.FNN([None, 10, 2, 1], key=key)
+       >>> _, material_layer = lazy_layer.at['__call__'](jnp.ones([1, 10]))
+       >>> material_layer.linear_0.in_features
        (10,)
    """

@@ -462,8 +470,10 @@ class MLP(sk.TreeClass):
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> import jax.random as jr
-       >>> mlp = sk.nn.MLP(1, 2, hidden_features=4, num_hidden_layers=2, key=jr.PRNGKey(0))
-       >>> mlp(jnp.ones((3, 1))).shape
+       >>> key = jr.PRNGKey(0)
+       >>> layer = sk.nn.MLP(1, 2, hidden_features=4, num_hidden_layers=2, key=key)
+       >>> input = jnp.ones((3, 1))
+       >>> layer(input).shape
        (3, 2)

    Note:
@@ -486,16 +496,14 @@ class MLP(sk.TreeClass):
        >>> import jax.random as jr
        >>> import serket as sk
        >>> import numpy.testing as npt
-       >>> fnn = sk.nn.FNN([1] + [4] * 100 + [2], key=jr.PRNGKey(0))
-       >>> mlp = sk.nn.MLP(1, 2, hidden_features=4, num_hidden_layers=100, key=jr.PRNGKey(0))
-       >>> x = jnp.ones((100, 1))
-       >>> fnn_jaxpr = jax.make_jaxpr(fnn)(x)
-       >>> mlp_jaxpr = jax.make_jaxpr(mlp)(x)
-       >>> npt.assert_allclose(fnn(x), mlp(x), atol=1e-6)
-       >>> len(fnn_jaxpr.jaxpr.eqns)
-       403
-       >>> len(mlp_jaxpr.jaxpr.eqns)
-       10
+       >>> key = jr.PRNGKey(0)
+       >>> fnn = sk.nn.FNN([1] + [4] * 100 + [2], key=key)
+       >>> mlp = sk.nn.MLP(1, 2, hidden_features=4, num_hidden_layers=100, key=key)
+       >>> input = jnp.ones((100, 1))
+       >>> fnn_jaxpr = jax.make_jaxpr(fnn)(input)
+       >>> mlp_jaxpr = jax.make_jaxpr(mlp)(input)
+       >>> npt.assert_allclose(fnn(input), mlp(input), atol=1e-6)
+       >>> assert len(fnn_jaxpr.jaxpr.eqns) > len(mlp_jaxpr.jaxpr.eqns)

    Note:
        :class:`.MLP` supports lazy initialization, meaning that the weights and
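The old doctest pinned exact equation counts (403 vs 10), which are brittle across JAX versions; the new assert keeps only the claim that matters. The gap itself arises because MLP applies one shared hidden-layer body over stacked parameters instead of tracing every layer separately; a generic plain-JAX sketch of that idea (not serket's exact internals):

    import jax
    import jax.numpy as jnp

    def apply_stack(stacked_weights, x):
        # stacked_weights: (num_layers, features, features); lax.scan traces
        # the body once, so jaxpr size stays constant as num_layers grows
        def body(h, w):
            return jax.nn.relu(h @ w), None
        h, _ = jax.lax.scan(body, x, stacked_weights)
        return h

    weights = jnp.ones((100, 4, 4))
    jaxpr = jax.make_jaxpr(apply_stack)(weights, jnp.ones(4))
    print(len(jaxpr.jaxpr.eqns))  # small, independent of the 100 layers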
@@ -509,9 +517,11 @@ class MLP(sk.TreeClass):
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
-       >>> lazy_mlp = sk.nn.MLP(None, 1, num_hidden_layers=2, hidden_features=10, key=jr.PRNGKey(0))
-       >>> _, material_mlp = lazy_mlp.at['__call__'](jnp.ones([1, 10]))
-       >>> material_mlp.linear_i.in_features
+       >>> key = jr.PRNGKey(0)
+       >>> lazy_layer = sk.nn.MLP(None, 1, num_hidden_layers=2, hidden_features=10, key=key)
+       >>> input = jnp.ones([1, 10])
+       >>> _, material_layer = lazy_layer.at['__call__'](input)
+       >>> material_layer.linear_i.in_features
        (10,)
    """

@@ -535,14 +545,13 @@ def __init__(
        self.act = resolve_activation(act)
        kwargs = dict(weight_init=weight_init, bias_init=bias_init, dtype=dtype)

-       @jax.vmap
        def batched_linear(key: jax.Array) -> Batched[Linear]:
-           return sk.tree_mask(
-               Linear(hidden_features, hidden_features, key=key, **kwargs)
-           )
+           layer = Linear(hidden_features, hidden_features, key=key, **kwargs)
+           # mask non-jaxtype on return
+           return sk.tree_mask(layer)

        self.linear_i = Linear(in_features, hidden_features, key=keys[0], **kwargs)
-       self.linear_h = sk.tree_unmask(batched_linear(keys[1:-1]))
+       self.linear_h = sk.tree_unmask(jax.vmap(batched_linear)(keys[1:-1]))
        self.linear_o = Linear(hidden_features, out_features, key=keys[-1], **kwargs)

    def __call__(self, input: jax.Array) -> jax.Array:
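Context for the `tree_mask`/`tree_unmask` pair around `jax.vmap` (a hedged reading of the code above): vmap requires every output leaf to be a JAX type, while a `Linear` pytree may hold non-jaxtype leaves such as metadata, so the layers are masked before stacking and unmasked afterwards. The same pattern in isolation:

    import jax
    import jax.random as jr
    import serket as sk

    keys = jr.split(jr.PRNGKey(0), 8)
    # masked layers contain only jaxtype leaves, so vmap can stack them into
    # one Linear whose weight and bias gain a leading axis of length 8
    stacked = jax.vmap(lambda k: sk.tree_mask(sk.nn.Linear(4, 4, key=k)))(keys)
    linear_h = sk.tree_unmask(stacked)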