diff --git a/docs/API/linear.rst b/docs/API/linear.rst
index 933a652..31df293 100644
--- a/docs/API/linear.rst
+++ b/docs/API/linear.rst
@@ -5,4 +5,6 @@ Linear
 .. autoclass:: Linear
 .. autoclass:: Identity
 .. autoclass:: Embedding
-.. autoclass:: MLP
\ No newline at end of file
+.. autoclass:: MLP
+
+.. autofunction:: linear
\ No newline at end of file
diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py
index 15b4fa7..24fa043 100644
--- a/serket/_src/nn/linear.py
+++ b/serket/_src/nn/linear.py
@@ -48,13 +48,23 @@ def is_lazy_init(_, in_features, *__, **___) -> bool:
     return in_features is None
 
 
-def general_linear(
+def linear(
     input: jax.Array,
     weight: jax.Array,
     bias: jax.Array | None,
     in_axis: tuple[int, ...],
     out_axis: int,
 ) -> jax.Array:
+    """Compute ``A @ B + C``: a linear transformation over arbitrary input axes.
+
+    Args:
+        input: input array.
+        weight: weight array of shape ``(out_features, in_feature_1, in_feature_2, ...)``.
+        bias: bias array of shape ``(out_features,)``, or ``None`` for no bias.
+        in_axis: axes of the input to apply the linear layer to; a ``tuple`` of ``int``
+            corresponding to ``(in_feature_1, in_feature_2, ...)``.
+        out_axis: the axis in which to place the result; accepts ``int`` values.
+    """
     in_axis = [axis if axis >= 0 else axis + input.ndim for axis in in_axis]
     lhs = "".join(str(axis) for axis in range(input.ndim))  # 0, 1, 2, 3
     rhs = "F" + "".join(str(axis) for axis in in_axis)  # F, 1, 2, 3
@@ -194,7 +204,7 @@ def __init__(
     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
     def __call__(self, input: jax.Array) -> jax.Array:
         """Apply a linear transformation to the input."""
-        return general_linear(
+        return linear(
             input=input,
             weight=self.weight,
             bias=self.bias,
@@ -369,7 +379,7 @@ def batched_linear(key: jax.Array) -> Batched[Linear]:
             return sk.tree_mask(layer)
 
         self.in_linear = Linear(in_features, hidden_features, key=keys[0], **kwargs)
-        self.mid_linear = sk.tree_unmask(batched_linear(keys[1:-1]))
+        self.mid_linear: Batched[Linear] = sk.tree_unmask(batched_linear(keys[1:-1]))
         self.out_linear = Linear(hidden_features, out_features, key=keys[-1], **kwargs)
 
     def __call__(self, input: jax.Array) -> jax.Array:
diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py
index 61c07ae..74b369f 100644
--- a/serket/nn/__init__.py
+++ b/serket/nn/__init__.py
@@ -98,7 +98,7 @@
     dropout_nd,
     random_cutout_nd,
 )
-from serket._src.nn.linear import MLP, Embedding, Identity, Linear
+from serket._src.nn.linear import MLP, Embedding, Identity, Linear, linear
 from serket._src.nn.normalization import (
     BatchNorm,
     EvalBatchNorm,
@@ -262,10 +262,9 @@
     "dropout_nd",
     "random_cutout_nd",
     # linear
-    "FNN",
     "MLP",
+    "linear",
     "Embedding",
-    "GeneralLinear",
     "Identity",
     "Linear",
     # norms
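
Usage sketch (not part of the diff above): after this change, the einsum-based helper previously named ``general_linear`` is exported as ``serket.nn.linear`` with the signature shown in the second hunk. The example below is a minimal illustration with made-up shapes; it assumes ``out_axis`` normalizes negative indices the same way the function normalizes ``in_axis``.

import jax.numpy as jnp

import serket as sk

# (batch, in_feature_1, in_feature_2): an input with two feature axes.
input = jnp.ones([4, 3, 5])
# Weight shaped (out_features, in_feature_1, in_feature_2), per the new docstring.
weight = jnp.ones([2, 3, 5])
# Bias shaped (out_features,); pass None to skip the bias addition.
bias = jnp.zeros([2])

# Contract input axes 1 and 2 against the weight and place the
# out_features axis last (assumed negative-index handling).
output = sk.nn.linear(input, weight, bias, in_axis=(1, 2), out_axis=-1)
assert output.shape == (4, 2)

This is the same code path ``Linear.__call__`` now dispatches to in the third hunk, so a plain ``Linear`` layer and a direct ``linear`` call with the layer's ``weight`` and ``bias`` should agree.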