Skip to content

Commit

Permalink
change implementation to dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Mar 13, 2024
1 parent aa3433f commit afcbac9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
23 changes: 11 additions & 12 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
layers::T
end

_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}

Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
Expand All @@ -573,20 +575,17 @@ function _parallel_check(layers, xs)
end
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

(m::Parallel)(xs::Tuple) = m(xs...)

function (m::Parallel)(xs...)
if length(m.layers) == 1
f = only(m.layers)
m.connection(map(x -> f(x), xs)...) # multiple arguments, one layer
else
_parallel_check(m.layers, xs)
m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
end
function (m::Parallel)(x, ys...)
xs = (x, ys...)
_parallel_check(m.layers, xs)
m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
end

# (m::Parallel{<:Any, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}})(xs...) =
# m.connection(map(x -> only(m.layers)(x), xs)...) # multiple arguments, one layer
(m::_ParallelONE)(x, ys...) =
m.connection(map(z -> only(m.layers)(z), (x, ys...))...) # multiple arguments, one layer

(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted
(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Expand Down
2 changes: 1 addition & 1 deletion test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ using Flux: activations
end

@testset "concat size" begin
input = randn(10, 2)
input = randn32(10, 2)
@test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
@test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4)
end
Expand Down

0 comments on commit afcbac9

Please sign in to comment.