Skip to content

Commit

Permalink
Fix Diagonal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Apr 2, 2022
1 parent 7465575 commit 9ab71f7
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,18 @@ import Flux: activations

@testset "Diagonal" begin
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1, 2]) == [1, 2]
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]

@test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
@test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
end

@testset "Maxout" begin
Expand Down

0 comments on commit 9ab71f7

Please sign in to comment.