From ff73b44f5af732b2617bff0ab0b7c3f0b390c26c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 10:28:38 -0400 Subject: [PATCH] nicer errors when called on zero inputs --- src/layers/basic.jl | 5 ++++- test/layers/basic.jl | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 3fb03e1621..a8cb0a9d1c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -572,7 +572,8 @@ end function _parallel_check(layers, xs) nl = length(layers) - nx = length(xs) + @assert nl > 1 # dispatch handles nl==1 cases + nx = length(xs) if (nl != nx) throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs")) end @@ -591,6 +592,8 @@ end (m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted (m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity +(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs")) + Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]) Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) = diff --git a/test/layers/basic.jl b/test/layers/basic.jl index cd7918e487..d09579c257 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -261,8 +261,15 @@ using Flux: activations end @testset "trivial cases" begin + # zero inputs, always an error + @test_throws ArgumentError Parallel(hcat)() + @test_throws ArgumentError Parallel(hcat, inv)() + @test_throws ArgumentError Parallel(hcat, inv, sqrt)() + + # zero layers -- not useful... can we make this an error without a breaking change? @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple @test Parallel(hcat)(1) == hcat() + @test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once. end