From 37c838de0fd3a31bdf1701f7474977faa09e0a1c Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 18 Jan 2021 12:30:28 +0100 Subject: [PATCH 1/2] one-arg Fix2 unsqueeze method --- src/layers/stateless.jl | 23 ++++++++++++++++++--- src/utils.jl | 44 +++++++++++++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 7395aa7a62..1a3a0df5ec 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,9 +1,26 @@ """ flatten(x::AbstractArray) -Reshape arbitrarly-shaped input into a matrix-shaped output -preserving the last dimension size. -Equivalent to `reshape(x, :, size(x)[end])`. +Reshape arbitrarly-shaped input into a matrix-shaped output, +preserving the size of the last dimension. + +See also [`unsqueeze`](@ref). + +# Examples +```jldoctest +julia> rand(3,4,5) |> Flux.flatten |> size +(12, 5) + +julia> xs = rand(Float32, 10,10,3,7); + +julia> m = Chain(Conv((3,3), 3=>4, pad=1), Flux.flatten, Dense(400,33)); + +julia> xs |> m[1] |> size +(10, 10, 4, 7) + +julia> xs |> m |> size +(33, 7) +``` """ function flatten(x::AbstractArray) return reshape(x, :, size(x)[end]) diff --git a/src/utils.jl b/src/utils.jl index 33955870ef..cfb4478f84 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -241,11 +241,23 @@ create_bias(x, ::Any...) = x """ unsqueeze(xs, dim) -Return `xs` reshaped into an `Array` one dimensionality higher than `xs`, +Return `xs` reshaped into an array one dimensionality higher than `xs`, where `dim` indicates in which dimension `xs` is extended. +See also [`flatten`](@ref), [`stack`](@ref). + # Examples ```jldoctest +julia> Flux.unsqueeze([1 2; 3 4], 2) +2×1×2 Array{Int64,3}: +[:, :, 1] = + 1 + 3 + +[:, :, 2] = + 2 + 4 + julia> xs = [[1, 2], [3, 4], [5, 6]] 3-element Array{Array{Int64,1},1}: [1, 2] @@ -255,19 +267,31 @@ julia> xs = [[1, 2], [3, 4], [5, 6]] julia> Flux.unsqueeze(xs, 1) 1×3 Array{Array{Int64,1},2}: [1, 2] [3, 4] [5, 6] +``` +""" +unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) -julia> Flux.unsqueeze([1 2; 3 4], 2) -2×1×2 Array{Int64,3}: -[:, :, 1] = - 1 - 3 +""" + unsqueeze(dim) -[:, :, 2] = - 2 - 4 +Returns a function which, acting on an array, inserts a dimension of size 1 at `dim`. + +# Examples +```jldoctest +julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size +(21, 1, 22, 23) + +julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad())); + +julia> rand(Float32, 10, 10) |> m |> size +(10, 10, 7, 1) ``` """ -unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) +unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim) + +Base.show(io::IO, u::Base.Fix2{typeof(unsqueeze)}) = print(io, "unsqueeze(", u.x, ")") +Base.show(io::IO, ::MIME"text/plain", u::Base.Fix2{typeof(unsqueeze)}) = show(io, u) # at top level +Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = show(io, u) # within Chain etc. """ stack(xs, dim) From 0619f15cd3e50cb25d945924568ebe503be7658e Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 18 Jan 2021 12:52:35 +0100 Subject: [PATCH 2/2] outputsize test, and simplify printing --- src/utils.jl | 4 +--- test/outputsize.jl | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index cfb4478f84..297e3c72e1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -289,9 +289,7 @@ julia> rand(Float32, 10, 10) |> m |> size """ unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim) -Base.show(io::IO, u::Base.Fix2{typeof(unsqueeze)}) = print(io, "unsqueeze(", u.x, ")") -Base.show(io::IO, ::MIME"text/plain", u::Base.Fix2{typeof(unsqueeze)}) = show(io, u) # at top level -Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = show(io, u) # within Chain etc. +Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = print(io, "unsqueeze(", u.x, ")") """ stack(xs, dim) diff --git a/test/outputsize.jl b/test/outputsize.jl index ba183a985c..b3e531f0f4 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -22,6 +22,9 @@ m = flatten @test outputsize(m, (5, 5, 3, 10)) == (75, 10) + m = Flux.unsqueeze(3) + @test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13) + m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) @test outputsize(m, (10, 10, 3, 50)) == (10, 50) @test outputsize(m, (10, 10, 3, 2)) == (10, 2)