let Parallel(+, f)(x, y, z) work like broadcasting
mcabbott committed Mar 10, 2024
1 parent 5f84b68 commit 74b7b9b
Showing 2 changed files with 26 additions and 9 deletions.
26 changes: 19 additions & 7 deletions src/layers/basic.jl
@@ -471,8 +471,11 @@ end
Create a layer which passes an input array to each path in
`layers`, before reducing the output with `connection`.
-Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
-If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Obeys rules similar to broadcasting:
+* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
+* With multiple inputs and multiple layers, one input is passed to each layer,
+  thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
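For orientation, a minimal sketch of the three cases the new docstring describes (the first two calls mirror the tests added below; the third is an assumed extension of the same rule):

```julia
using Flux

Parallel(+, sin, cos)(pi/2)       # one input, several layers: sin(pi/2) + cos(pi/2) ≈ 1
Parallel(/, abs)(3, -4)           # several inputs, one layer: abs(3) / abs(-4) == 0.75
Parallel(+, sin, cos)(pi/2, 0.0)  # one input per layer: sin(pi/2) + cos(0.0) == 2.0
```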
@@ -524,23 +527,32 @@ end

@layer :expand Parallel

-(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
-(m::Parallel)(xs::Tuple) = m(xs...)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

function _parallel_check(layers, xs)
  nl = length(layers)
  nx = length(xs)
  if (nl != nx)
-    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+    throw(ArgumentError("Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
  end
end
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

+(m::Parallel)(xs::Tuple) = m(xs...)

function (m::Parallel)(xs...)
-  _parallel_check(m.layers, xs)
-  m.connection(map(|>, xs, Tuple(m.layers))...)
+  if length(m.layers) == 1
+    f = only(m.layers)
+    m.connection(map(x -> f(x), xs)...) # multiple arguments, one layer
+  else
+    _parallel_check(m.layers, xs)
+    m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
+  end
end

+# (m::Parallel{<:Any, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}})(xs...) =
+#   m.connection(map(x -> only(m.layers)(x), xs)...) # multiple arguments, one layer
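Taken together, a sketch of how these definitions dispatch (not part of the diff; assumes `using Flux`): a tuple argument splats into the vararg method, a single sub-layer is applied to every input, and a count mismatch with several sub-layers hits `_parallel_check`:

```julia
using Flux

m = Parallel(/, abs)   # one sub-layer, so every input goes through `abs`
m(3, -4)               # abs(3) / abs(-4) == 0.75
m((3, -4))             # a tuple splats into the vararg method: same result

# Two sub-layers but three inputs fails _parallel_check:
try
    Parallel(+, sin, cos)(1, 2, 3)
catch err
    @assert err isa ArgumentError  # "Parallel with 2 > 1 sub-layers can take one input or 2 inputs, but got 3 inputs"
end
```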

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
9 changes: 7 additions & 2 deletions test/layers/basic.jl
@@ -234,11 +234,14 @@ using Flux: activations
end

@testset "vararg input" begin
inputs = randn(10), randn(5), randn(4)
inputs = randn32(10), randn32(5), randn32(4)
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
@test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
@test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
@test Parallel(+, sin, cos)(pi/2) 1
@test Parallel(+, sin, cos)(pi/2) 1 # one input, several layers
@test Parallel(/, abs)(3, -4) 3/4 # one layer, several inputs
@test Parallel(/, abs)((3, -4)) 3/4
@test Parallel(/; f=abs)(3, -4) 3/4
end
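These tests also exercise the behaviour named in the commit title: with one sub-layer and `+` as the connection, `Parallel` applies the layer to each input and reduces, much like summing a broadcast. A hypothetical doubling function `f` (not from the commit) makes the equivalence concrete:

```julia
using Flux

f(x) = 2x                  # hypothetical sub-layer; any callable works
Parallel(+, f)(1, 2, 3)    # f(1) + f(2) + f(3) == 12
sum(f.([1, 2, 3]))         # == 12, the broadcasting analogue
```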

@testset "named access" begin
@@ -270,6 +273,8 @@ using Flux: activations
@test CNT[] == 2
Parallel(f_cnt, sin)(1)
@test CNT[] == 3
+Parallel(f_cnt, sin)(1,2,3)
+@test CNT[] == 4
end

# Ref https://github.com/FluxML/Flux.jl/issues/1673
