Skip to content

Commit

Permalink
fix linear failing
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 8, 2024
1 parent 9d2c99b commit b6eb925
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
11 changes: 4 additions & 7 deletions serket/_src/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,10 @@ def __init__(
if len(out_axis) != len(out_features):
raise ValueError(f"{len(out_axis)=} != {len(out_features)=}")

_, in_f = zip(*sorted(enumerate(in_features), key=lambda ik: in_axis[ik[0]]))
_, out_f = zip(*sorted(enumerate(out_features), key=lambda ik: out_axis[ik[0]]))

self.in_features = in_f
self.out_features = out_f
self.in_axis = sorted(in_axis)
self.out_axis = sorted(out_axis)
self.in_features = in_features
self.out_features = out_features
self.in_axis = in_axis
self.out_axis = out_axis
self.weight_init = weight_init
self.bias_init = bias_init

Expand Down
2 changes: 1 addition & 1 deletion tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_general_linear_shape(
assert layer(x).shape == desired_shape


def test_linear_errors():
def test_linear():
x = jnp.ones([1, 2, 3, 4])
layer = sk.nn.Linear(
in_features=(1, 2),
Expand Down

0 comments on commit b6eb925

Please sign in to comment.