Skip to content

Commit

Permalink
Merge #1469
Browse files Browse the repository at this point in the history
1469: One-arg unsqueeze method r=CarloLucibello a=mcabbott

This makes `unsqeeze(3)` return a function.

In fact a `Base.Fix2` so that it has a predictable type, although overloading `show` turns out to be a pain for `<:Function` (perhaps it would be simpler to give it its own `struct`?), right now this is pretty within `Chain` but not when used by itself.

The only tests of the existing method seem to be doctests, so this adds a few more. And cross-links to `flatten`.

Plus a test for `outputsize`, now. 

Co-authored-by: Michael Abbott <me@pseudomac>
  • Loading branch information
bors[bot] and Michael Abbott authored Jan 19, 2021
2 parents 1ac78b5 + 0619f15 commit 02ea511
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
23 changes: 20 additions & 3 deletions src/layers/stateless.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
"""
flatten(x::AbstractArray)
Reshape arbitrarly-shaped input into a matrix-shaped output
preserving the last dimension size.
Equivalent to `reshape(x, :, size(x)[end])`.
Reshape arbitrarly-shaped input into a matrix-shaped output,
preserving the size of the last dimension.
See also [`unsqueeze`](@ref).
# Examples
```jldoctest
julia> rand(3,4,5) |> Flux.flatten |> size
(12, 5)
julia> xs = rand(Float32, 10,10,3,7);
julia> m = Chain(Conv((3,3), 3=>4, pad=1), Flux.flatten, Dense(400,33));
julia> xs |> m[1] |> size
(10, 10, 4, 7)
julia> xs |> m |> size
(33, 7)
```
"""
function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
Expand Down
42 changes: 32 additions & 10 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,23 @@ create_bias(x, ::Any...) = x
"""
unsqueeze(xs, dim)
Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
Return `xs` reshaped into an array one dimensionality higher than `xs`,
where `dim` indicates in which dimension `xs` is extended.
See also [`flatten`](@ref), [`stack`](@ref).
# Examples
```jldoctest
julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64,3}:
[:, :, 1] =
1
3
[:, :, 2] =
2
4
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Array{Array{Int64,1},1}:
[1, 2]
Expand All @@ -255,19 +267,29 @@ julia> xs = [[1, 2], [3, 4], [5, 6]]
julia> Flux.unsqueeze(xs, 1)
1×3 Array{Array{Int64,1},2}:
[1, 2] [3, 4] [5, 6]
```
"""
unsqueeze(xs::AbstractArray, dim::Integer) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

julia> Flux.unsqueeze([1 2; 3 4], 2)
2×1×2 Array{Int64,3}:
[:, :, 1] =
1
3
"""
unsqueeze(dim)
[:, :, 2] =
2
4
Returns a function which, acting on an array, inserts a dimension of size 1 at `dim`.
# Examples
```jldoctest
julia> rand(21, 22, 23) |> Flux.unsqueeze(2) |> size
(21, 1, 22, 23)
julia> m = Chain(Flux.unsqueeze(3), Flux.unsqueeze(4), Conv((3,3), 1=>7, pad=SamePad()));
julia> rand(Float32, 10, 10) |> m |> size
(10, 10, 7, 1)
```
"""
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim)

Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = print(io, "unsqueeze(", u.x, ")")

"""
stack(xs, dim)
Expand Down
3 changes: 3 additions & 0 deletions test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
m = flatten
@test outputsize(m, (5, 5, 3, 10)) == (75, 10)

m = Flux.unsqueeze(3)
@test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13)

m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10))
@test outputsize(m, (10, 10, 3, 50)) == (10, 50)
@test outputsize(m, (10, 10, 3, 2)) == (10, 2)
Expand Down

0 comments on commit 02ea511

Please sign in to comment.