Skip to content

Commit

Permalink
docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 24, 2023
1 parent fa87ea7 commit 8e244df
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 60 deletions.
25 changes: 13 additions & 12 deletions serket/_src/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ class KMeans(KmeansBase):
>>> import serket as sk
>>> x = jr.uniform(jr.PRNGKey(0), shape=(500, 2))
>>> layer = sk.cluster.KMeans(clusters=5, tol=1e-6)
>>> # state initialization by array and key to initialize the centers
>>> state = sk.tree_state(layer, array=x, key=jr.PRNGKey(0))
>>> # state initialization by input array and key to initialize the centers
>>> state = sk.tree_state(layer, input=x, key=jr.PRNGKey(0))
>>> labels, state = layer(x, state)
>>> plt.scatter(x[:, 0], x[:, 1], c=labels[:, 0], cmap="jet_r") # doctest: +SKIP
>>> plt.scatter(state.centers[:, 0], state.centers[:, 1], c="r", marker="o", linewidths=4) # doctest: +SKIP
Expand All @@ -162,15 +162,15 @@ class KMeans(KmeansBase):
>>> import jax.random as jr
>>> features = 3
>>> clusters = 4
>>> x = jr.uniform(jr.PRNGKey(0), shape=(100, features))
>>> k1, k2 = jr.split(jr.PRNGKey(0))
>>> input = jr.uniform(k1, shape=(100, features))
>>> # layer definition
>>> layer = sk.cluster.KMeans(clusters=clusters, tol=1e-6)
>>> # state initialization
>>> state = sk.tree_state(layer, array=x, key=jr.PRNGKey(0))
>>> x, state = layer(x, state)
>>> state = sk.tree_state(layer, input=input, key=k2)
>>> output, state = layer(input, state)
>>> eval_layer = sk.tree_eval(layer)
>>> y = jr.uniform(jr.PRNGKey(0), shape=(1, features))
>>> y, eval_state = eval_layer(y, state)
>>> output, eval_state = eval_layer(input, state)
>>> # centers are not updated
>>> assert jnp.all(eval_state.centers == state.centers)
Expand All @@ -193,7 +193,7 @@ class KMeans(KmeansBase):
... layer = sk.cluster.KMeans(clusters=clusters, tol=tol)
... shape = image.shape
... image = image.reshape(-1, 1)
... state = sk.tree_state(layer, array=image, key=key)
... state = sk.tree_state(layer, input=image, key=key)
... labels, state = layer(image, state)
... return state.centers[labels].reshape(shape).astype(jnp.uint8)
>>> fig, axes = plt.subplots(2, 2, figsize=(10, 10))
Expand Down Expand Up @@ -234,8 +234,9 @@ def __call__(
>>> import serket as sk
>>> import jax.random as jr
>>> key = jr.PRNGKey(0)
>>> input = jr.uniform(key, shape=(100, 2))
>>> layer = sk.cluster.KMeans(clusters=5, tol=1e-6)
>>> state = sk.tree_state(layer, array=x, key=key)
>>> state = sk.tree_state(layer, input=input, key=key)
Returns:
A tuple containing the labels and a ``KMeansState``.
Expand Down Expand Up @@ -266,9 +267,9 @@ def __call__(


@tree_state.def_state(KMeans)
def _(layer: KMeans, *, array: jax.Array, key: jax.Array, **_) -> KMeansState:
minval, maxval = array.min(), array.max()
shape = (layer.clusters, array.shape[1])
def _(layer: KMeans, *, input: jax.Array, key: jax.Array, **_) -> KMeansState:
minval, maxval = input.min(), input.max()
shape = (layer.clusters, input.shape[1])
centers = jr.uniform(key, minval=minval, maxval=maxval, shape=shape)
return KMeansState(centers=centers, error=centers + jnp.inf, iters=0)

Expand Down
96 changes: 48 additions & 48 deletions serket/_src/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,12 +680,12 @@ class Conv1D(ConvND):
>>> import jax.random as jr
>>> layer = sk.nn.Conv1D(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5))
>>> print(layer(input).shape)
(2, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5)
Note:
Expand Down Expand Up @@ -774,12 +774,12 @@ class Conv2D(ConvND):
>>> import jax.random as jr
>>> layer = sk.nn.Conv2D(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5)
Note:
Expand Down Expand Up @@ -868,12 +868,12 @@ class Conv3D(ConvND):
>>> import jax.random as jr
>>> layer = sk.nn.Conv3D(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5, 5)
Note:
Expand Down Expand Up @@ -996,12 +996,12 @@ class FFTConv1D(FFTConvND):
>>> import jax.random as jr
>>> layer = sk.nn.FFTConv1D(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5))
>>> print(layer(input).shape)
(2, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5)
Note:
Expand Down Expand Up @@ -1090,12 +1090,12 @@ class FFTConv2D(FFTConvND):
>>> import jax.random as jr
>>> layer = sk.nn.FFTConv2D(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5)
Note:
Expand Down Expand Up @@ -1184,12 +1184,12 @@ class FFTConv3D(FFTConvND):
>>> import jax.random as jr
>>> layer = sk.nn.FFTConv3D(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5, 5)
Note:
Expand Down Expand Up @@ -1363,12 +1363,12 @@ class Conv1DTranspose(ConvNDTranspose):
>>> import jax.random as jr
>>> layer = sk.nn.Conv1DTranspose(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5))
>>> print(layer(input).shape)
(2, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5)
Note:
Expand Down Expand Up @@ -1460,12 +1460,12 @@ class Conv2DTranspose(ConvNDTranspose):
>>> import jax
>>> layer = sk.nn.Conv2DTranspose(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5)
Note:
Expand Down Expand Up @@ -1558,12 +1558,12 @@ class Conv3DTranspose(ConvNDTranspose):
>>> import jax.random as jr
>>> layer = sk.nn.Conv3DTranspose(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5, 5)
Note:
Expand Down Expand Up @@ -1690,12 +1690,12 @@ class FFTConv1DTranspose(FFTConvNDTranspose):
>>> import jax.random as jr
>>> layer = sk.nn.FFTConv1DTranspose(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5))
>>> print(layer(input).shape)
(2, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5)
Note:
Expand Down Expand Up @@ -1788,12 +1788,12 @@ class FFTConv2DTranspose(FFTConvNDTranspose):
>>> import jax.random as jr
>>> layer = sk.nn.FFTConv2DTranspose(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5)
Note:
Expand Down Expand Up @@ -1886,12 +1886,12 @@ class FFTConv3DTranspose(FFTConvNDTranspose):
>>> import jax.random as jr
>>> layer = sk.nn.FFTConv3DTranspose(1, 2, 3, key=jr.PRNGKey(0))
>>> # single sample
>>> x = jnp.ones((1, 5, 5, 5))
>>> print(layer(x).shape)
>>> input = jnp.ones((1, 5, 5, 5))
>>> print(layer(input).shape)
(2, 5, 5, 5)
>>> # batch of samples
>>> x = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(x).shape)
>>> input = jnp.ones((2, 1, 5, 5, 5))
>>> print(jax.vmap(layer)(input).shape)
(2, 2, 5, 5, 5)
Note:
Expand Down

0 comments on commit 8e244df

Please sign in to comment.