From 74b7b9b49bd36f98800999021ef1e725681e9a53 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 10 Mar 2024 16:07:52 -0400
Subject: [PATCH] let Parallel(+, f)(x, y, z) work like broadcasting

---
 src/layers/basic.jl  | 26 +++++++++++++++++++-------
 test/layers/basic.jl |  9 +++++++--
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 018b19b31d..12600f7057 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -471,8 +471,11 @@ end
 Create a layer which passes an input array to each path in `layers`,
 before reducing the output with `connection`.
 
-Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
-If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Obeys similar rules to broadcasting:
+* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
+* With multiple inputs and multiple layers, one input is passed to each layer,
+  thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
 
 Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
 These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
@@ -524,23 +527,32 @@ end
 
 @layer :expand Parallel
 
-(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
-(m::Parallel)(xs::Tuple) = m(xs...)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)  # one argument
 
 function _parallel_check(layers, xs)
   nl = length(layers)
   nx = length(xs)
   if (nl != nx)
-    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+    throw(ArgumentError("Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
   end
 end
 ChainRulesCore.@non_differentiable _parallel_check(nl, nx)
 
+(m::Parallel)(xs::Tuple) = m(xs...)
+
 function (m::Parallel)(xs...)
-  _parallel_check(m.layers, xs)
-  m.connection(map(|>, xs, Tuple(m.layers))...)
+  if length(m.layers) == 1
+    f = only(m.layers)
+    m.connection(map(x -> f(x), xs)...)  # multiple arguments, one layer
+  else
+    _parallel_check(m.layers, xs)
+    m.connection(map(|>, xs, Tuple(m.layers))...)  # multiple arguments & multiple layers
+  end
 end
 
+# (m::Parallel{<:Any, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}})(xs...) =
+#   m.connection(map(x -> only(m.layers)(x), xs)...)  # multiple arguments, one layer
+
 Base.getindex(m::Parallel, i) = m.layers[i]
 Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
 Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 95da13f0c9..83795a09a0 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -234,11 +234,14 @@ using Flux: activations
   end
 
   @testset "vararg input" begin
-    inputs = randn(10), randn(5), randn(4)
+    inputs = randn32(10), randn32(5), randn32(4)
     @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
     @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
     @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3)  # wrong number of inputs
-    @test Parallel(+, sin, cos)(pi/2) ≈ 1
+    @test Parallel(+, sin, cos)(pi/2) ≈ 1  # one input, several layers
+    @test Parallel(/, abs)(3, -4) ≈ 3/4   # one layer, several inputs
+    @test Parallel(/, abs)((3, -4)) ≈ 3/4
+    @test Parallel(/; f=abs)(3, -4) ≈ 3/4
   end
 
   @testset "named access" begin
@@ -270,6 +273,8 @@ using Flux: activations
     @test CNT[] == 2
     Parallel(f_cnt, sin)(1)
     @test CNT[] == 3
+    Parallel(f_cnt, sin)(1,2,3)
+    @test CNT[] == 4
   end
 
   # Ref https://github.com/FluxML/Flux.jl/issues/1673
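
A minimal sketch of the resulting behaviour, assuming this patch is applied; the layers and sizes below are illustrative, not taken from the diff:

```julia
using Flux

x, y = randn32(2), randn32(3)

# Multiple layers, multiple inputs: one input per layer (unchanged behaviour).
m2 = Parallel(+, Dense(2 => 4), Dense(3 => 4))
m2(x, y) == m2[1](x) + m2[2](y)    # true

# One layer, multiple inputs: the layer is applied to each input,
# like broadcasting a function over a tuple (new in this patch).
m1 = Parallel(+, Dense(2 => 4))
m1(x, x, x) == 3 * m1[1](x)        # true: + reduces three identical outputs

# A tuple of inputs is splatted, so these agree:
Parallel(/, abs)((3, -4)) == Parallel(/, abs)(3, -4) == 3/4   # true

# Mismatched counts still throw, but only with more than one layer:
# Parallel(+, sin, cos)(1, 2, 3)   # would throw ArgumentError: 2 layers, 3 inputs
```

One consequence worth noting: with a single sub-layer, `_parallel_check` is skipped entirely, so any number of inputs is accepted. The commented-out method left in the diff sketches an alternative that would dispatch on one-layer `Parallel`s directly, rather than branching on `length(m.layers)` at runtime.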