Skip to content

Commit

Permalink
Add tests for output size
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Feb 20, 2024
1 parent d1d8592 commit b73c174
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
3 changes: 0 additions & 3 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ function parameterlength(d::Dense{use_bias}) where {use_bias}
end
# A `Dense` layer keeps no internal state, so its state length is always zero.
function statelength(::Dense)
    return 0
end

# Input/output sizes of a `Dense` layer, reported as 1-tuples built from the
# layer's declared `in_dims`/`out_dims` fields (feature dimension only; batch
# dimensions are not included).
function inputsize(layer::Dense)
    return (layer.in_dims,)
end
function outputsize(layer::Dense)
    return (layer.out_dims,)
end

Check warning on line 202 in src/layers/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/basic.jl#L202

Added line #L202 was not covered by tests

@inline function (d::Dense{false})(x::AbstractVecOrMat, ps, st::NamedTuple)
Expand Down Expand Up @@ -304,7 +303,6 @@ end
# Number of trainable parameters in a `Scale` layer: one weight entry per
# scaled element (`prod(dims)`), doubled when the `use_bias` type parameter
# adds an equally-sized bias array.
function parameterlength(layer::Scale{use_bias}) where {use_bias}
    return (1 + use_bias) * prod(layer.dims)
end

# `Scale` layers are stateless.
function statelength(::Scale)
    return 0
end

# Element-wise scaling preserves the array shape, so the input and output
# sizes are both the layer's `dims` tuple.
inputsize(layer::Scale) = layer.dims
outputsize(layer::Scale) = layer.dims

Check warning on line 306 in src/layers/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/basic.jl#L306

Added line #L306 was not covered by tests

function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
Expand Down Expand Up @@ -511,5 +509,4 @@ function Base.show(io::IO, e::Embedding)
return print(io, "Embedding(", e.in_dims, " => ", e.out_dims, ")")
end

# Input/output sizes of an `Embedding` layer as 1-tuples: `in_dims` is the
# vocabulary size consumed, `out_dims` the embedding dimension produced.
function inputsize(emb::Embedding)
    return (emb.in_dims,)
end
function outputsize(emb::Embedding)
    return (emb.out_dims,)
end

Check warning on line 512 in src/layers/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/basic.jl#L512

Added line #L512 was not covered by tests
9 changes: 9 additions & 0 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
x = randn(rng, 6, 3) |> aType

@test size(layer(x, ps, st)[1]) == (2, 3, 3)
@test Lux.outputsize(layer) == (2, 3)

@jet layer(x, ps, st)
__f = x -> sum(first(layer(x, ps, st)))
Expand Down Expand Up @@ -103,6 +104,8 @@ end

@test size(first(Lux.apply(layer, randn(10), ps, st))) == (5,)
@test size(first(Lux.apply(layer, randn(10, 2), ps, st))) == (5, 2)

@test LuxCore.outputsize(layer) == (5,)
end

@testset "zeros" begin
Expand Down Expand Up @@ -178,6 +181,8 @@ end
@test size(first(Lux.apply(layer, randn(10) |> aType, ps, st))) == (10, 5)
@test size(first(Lux.apply(layer, randn(10, 5, 2) |> aType, ps, st))) ==
(10, 5, 2)

@test LuxCore.outputsize(layer) == (10, 5)
end

@testset "zeros" begin
Expand Down Expand Up @@ -274,6 +279,8 @@ end
@test size(layer((x, y), ps, st)[1]) == (3, 1)
@test sum(abs2, layer((x, y), ps, st)[1]) == 0.0f0

@test LuxCore.outputsize(layer) == (3,)

@jet layer((x, y), ps, st)
__f = (x, y, ps) -> sum(first(layer((x, y), ps, st)))
@eval @test_gradients $__f $x $y $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu
Expand Down Expand Up @@ -316,6 +323,8 @@ end

@test size(ps.weight) == (embed_size, vocab_size)

@test LuxCore.outputsize(layer) == (4,)

x = rand(1:vocab_size, 1)[1]
y, st_ = layer(x, ps, st)
@test size(layer(x, ps, st)[1]) == (embed_size,)
Expand Down

0 comments on commit b73c174

Please sign in to comment.