Skip to content

Commit

Permalink
nicer errors when called on zero inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 4, 2024
1 parent 5fae6a9 commit ff73b44
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ end

function _parallel_check(layers, xs)
nl = length(layers)
nx = length(xs)
@assert nl > 1 # dispatch handles nl==1 cases
nx = length(xs)
if (nl != nx)
throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
end
Expand All @@ -591,6 +592,8 @@ end
(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted
(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity

(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
Expand Down
7 changes: 7 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,15 @@ using Flux: activations
end

@testset "trivial cases" begin
# zero inputs, always an error
@test_throws ArgumentError Parallel(hcat)()
@test_throws ArgumentError Parallel(hcat, inv)()
@test_throws ArgumentError Parallel(hcat, inv, sqrt)()

# zero layers -- not useful... can we make this an error without a breaking change?
@test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
@test Parallel(hcat)(1) == hcat()

@test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
end

Expand Down

0 comments on commit ff73b44

Please sign in to comment.