diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000000..ee89ad89b1 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,5 @@ +indent = 4 +whitespace_in_kwargs = true +normalize_line_endings = "unix" +style = "blue" +import_to_using = false \ No newline at end of file diff --git a/src/Flux.jl b/src/Flux.jl index 80d999de38..59aa858f15 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -10,24 +10,67 @@ using MacroTools: @forward using Zygote: Params, @adjoint, gradient, pullback, @nograd export gradient -export Chain, Dense, Maxout, SkipConnection, Parallel, flatten, - RNN, LSTM, GRU, GRUv3, - SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, - AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, - Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, - Upsample, PixelShuffle, - params, fmap, cpu, gpu, f32, f64, - testmode!, trainmode! +export Chain, + Dense, + Maxout, + SkipConnection, + Parallel, + flatten, + RNN, + LSTM, + GRU, + GRUv3, + SamePad, + Conv, + CrossCor, + ConvTranspose, + DepthwiseConv, + AdaptiveMaxPool, + AdaptiveMeanPool, + GlobalMaxPool, + GlobalMeanPool, + MaxPool, + MeanPool, + Dropout, + AlphaDropout, + LayerNorm, + BatchNorm, + InstanceNorm, + GroupNorm, + Upsample, + PixelShuffle, + params, + fmap, + cpu, + gpu, + f32, + f64, + testmode!, + trainmode! include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs using .Optimise: skip -export Descent, ADAM, Momentum, Nesterov, RMSProp, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM, - ADAMW, RADAM, AdaBelief, InvDecay, ExpDecay, - WeightDecay, ClipValue, ClipNorm - +export Descent, + ADAM, + Momentum, + Nesterov, + RMSProp, + ADAGrad, + AdaMax, + ADADelta, + AMSGrad, + NADAM, + OADAM, + ADAMW, + RADAM, + AdaBelief, + InvDecay, + ExpDecay, + WeightDecay, + ClipValue, + ClipNorm using CUDA const use_cuda = Ref{Union{Nothing,Bool}}(nothing) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 0ef3c65308..b1b71fe100 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -5,7 +5,8 @@ using ..CUDA import ..Flux: Flux import Zygote using Zygote: @adjoint -import NNlib, NNlibCUDA +using NNlib: NNlib +using NNlibCUDA: NNlibCUDA include("cudnn.jl") diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 4a3b2618c8..9ddff9efc9 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,20 +1,34 @@ import NNlibCUDA: batchnorm, ∇batchnorm -function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, - cache=nothing) where T<:Union{Float32, Float64} - - @assert BN.affine "BatchNorm: only affine=true supported on gpu" - @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu" - @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wronng number of channels" - return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; - cache=cache, alpha=1, beta=0, eps=BN.ϵ, - training=Flux._isactive(BN))) +function (BN::Flux.BatchNorm)( + x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing +) where {T<:Union{Float32,Float64}} + @assert BN.affine "BatchNorm: only affine=true supported on gpu" + @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu" + @assert length(BN.β) == size(x, ndims(x) - 1) "BatchNorm: input has wronng number of channels" + return BN.λ.( + batchnorm( + BN.γ, + BN.β, + x, + BN.μ, + BN.σ², + BN.momentum; + cache = cache, + alpha = 1, + beta = 0, + eps = BN.ϵ, + training = Flux._isactive(BN), + ), + ) end @adjoint function batchnorm(g, b, 
x, running_mean, running_var, momentum; kw...) - y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...) - function batchnorm_pullback(Δ) - ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing - end - y, batchnorm_pullback + y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...) + function batchnorm_pullback(Δ) + return ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., + nothing, nothing, + nothing + end + return y, batchnorm_pullback end diff --git a/src/data/Data.jl b/src/data/Data.jl index d00aeb709c..48c8b93d5c 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -10,17 +10,19 @@ export DataLoader ## Also remove the following deps: ## AbstractTrees, ZipFiles, CodecZLib -import ..Flux -import SHA +using ..Flux: Flux +using SHA: SHA -deprecation_message() = @warn("Flux's datasets are deprecated, please use the package MLDatasets.jl") +function deprecation_message() + @warn("Flux's datasets are deprecated, please use the package MLDatasets.jl") +end function deps(path...) - if isnothing(@__DIR__) # sysimages - joinpath("deps", path...) - else - joinpath(@__DIR__, "..", "..", "deps", path...) - end + if isnothing(@__DIR__) # sysimages + joinpath("deps", path...) + else + joinpath(@__DIR__, "..", "..", "deps", path...) + end end function download_and_verify(url, path, hash) @@ -30,16 +32,16 @@ function download_and_verify(url, path, hash) bytes2hex(SHA.sha256(f)) end if hash_download !== hash - msg = "Hash Mismatch!\n" + msg = "Hash Mismatch!\n" msg *= " Expected sha256: $hash\n" msg *= " Calculated sha256: $hash_download" error(msg) end - mv(tmppath, path; force=true) + return mv(tmppath, path; force = true) end function __init__() - mkpath(deps()) + return mkpath(deps()) end include("mnist.jl") @@ -50,7 +52,8 @@ export FashionMNIST include("cmudict.jl") export CMUDict -using .CMUDict; export cmudict +using .CMUDict +export cmudict include("tree.jl") include("sentiment.jl") diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl index d096727c2a..a301711670 100644 --- a/src/data/cmudict.jl +++ b/src/data/cmudict.jl @@ -8,20 +8,25 @@ const version = "0.7b" const cache_prefix = "https://cache.julialang.org" function load() - suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"), - (".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"), - (".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")] - if isdir(deps("cmudict")) - if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes) - return + suffixes_and_hashes = [ + ("", "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"), + (".phones", "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"), + (".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027"), + ] + if isdir(deps("cmudict")) + if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes) + return nothing + end + end + @info "Downloading CMUDict dataset" + mkpath(deps("cmudict")) + for (x, hash) in suffixes_and_hashes + download_and_verify( + "$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", + deps("cmudict", "cmudict$x"), + hash, + ) end - end - @info "Downloading CMUDict dataset" - mkpath(deps("cmudict")) - for (x, hash) in suffixes_and_hashes - download_and_verify("$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", - deps("cmudict", 
"cmudict$x"), hash) - end end """ @@ -29,10 +34,18 @@ end Return a `Vector` containing the phones used in the CMU Pronouncing Dictionary. """ function phones() - deprecation_message() - load() - Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String), - "\n", keepempty = false), "\t"))) + deprecation_message() + load() + return Symbol.( + first.( + split.( + split( + read(deps("cmudict", "cmudict.phones"), String), "\n"; keepempty = false + ), + "\t", + ), + ), + ) end """ @@ -42,10 +55,11 @@ A symbol is a phone with optional auxiliary symbols, indicating for example the amount of stress on the phone. """ function symbols() - deprecation_message() - load() - Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String), - "\n", keepempty = false)) + deprecation_message() + load() + return Symbol.( + split(read(deps("cmudict", "cmudict.symbols"), String), "\n"; keepempty = false) + ) end """ @@ -53,10 +67,12 @@ end Return the unfiltered CMU Pronouncing Dictionary. """ function rawdict() - deprecation_message() - load() - Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in - filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n")))) + deprecation_message() + load() + return Dict( + String(xs[1]) => Symbol.(xs[2:end]) for xs in + filter(!isempty, split.(split(read(deps("cmudict", "cmudict"), String), "\n"))) + ) end validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s) @@ -68,8 +84,8 @@ It is filtered so each word contains only ASCII characters and a combination of word characters (as determined by the regex engine using `\\w`), '-' and '.'. """ function cmudict() - deprecation_message() - filter(p -> validword(p.first), rawdict()) + deprecation_message() + return filter(p -> validword(p.first), rawdict()) end alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.'] diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl index 422747776c..408bc2277a 100644 --- a/src/data/dataloader.jl +++ b/src/data/dataloader.jl @@ -74,7 +74,7 @@ julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize= 10×4 Matrix{Int8} ``` """ -function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG) +function DataLoader(data; batchsize = 1, shuffle = false, partial = true, rng = GLOBAL_RNG) batchsize > 0 || throw(ArgumentError("Need positive batchsize")) n = _nobs(data) @@ -83,39 +83,43 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_R batchsize = n end imax = partial ? n : n - batchsize + 1 - DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng) + return DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng) end -@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize] +@propagate_inbounds function Base.iterate(d::DataLoader, i = 0) # returns data in d.indices[i+1:i+batchsize] i >= d.imax && return nothing if d.shuffle && i == 0 shuffle!(d.rng, d.indices) end nexti = min(i + d.batchsize, d.nobs) - ids = d.indices[i+1:nexti] + ids = d.indices[(i + 1):nexti] batch = _getobs(d.data, ids) return (batch, nexti) end function Base.length(d::DataLoader) n = d.nobs / d.batchsize - d.partial ? ceil(Int,n) : floor(Int,n) + return d.partial ? 
ceil(Int, n) : floor(Int, n) end _nobs(data::AbstractArray) = size(data)[end] -function _nobs(data::Union{Tuple, NamedTuple}) +function _nobs(data::Union{Tuple,NamedTuple}) length(data) > 0 || throw(ArgumentError("Need at least one data input")) n = _nobs(data[1]) for i in keys(data) ni = _nobs(data[i]) - n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. " * - "But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni.")) + n == ni || throw( + DimensionMismatch( + "All data inputs should have the same number of observations, i.e. size in the last dimension. " * + "But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni.", + ), + ) end return n end _getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i] -_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data) +_getobs(data::Union{Tuple,NamedTuple}, i) = map(Base.Fix2(_getobs, i), data) -Base.eltype(::DataLoader{D}) where D = D +Base.eltype(::DataLoader{D}) where {D} = D diff --git a/src/data/fashion-mnist.jl b/src/data/fashion-mnist.jl index 18999e9d5e..3bb1eaf2dc 100644 --- a/src/data/fashion-mnist.jl +++ b/src/data/fashion-mnist.jl @@ -5,25 +5,43 @@ using ..Data: download_and_verify, deprecation_message const dir = if isnothing(@__DIR__) joinpath("deps", "fashion-mnist") - else +else joinpath(@__DIR__, "../../deps/fashion-mnist") end function load() - mkpath(dir) - cd(dir) do - for (file, hash) in [("train-images-idx3-ubyte", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"), - ("train-labels-idx1-ubyte", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"), - ("t10k-images-idx3-ubyte" , "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"), - ("t10k-labels-idx1-ubyte" , "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5")] - isfile(file) && continue - @info "Downloading Fashion-MNIST dataset" - download_and_verify("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz", hash) - open(file, "w") do io - write(io, gzopen(read, "$file.gz")) - end + mkpath(dir) + cd(dir) do + for (file, hash) in [ + ( + "train-images-idx3-ubyte", + "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84", + ), + ( + "train-labels-idx1-ubyte", + "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845", + ), + ( + "t10k-images-idx3-ubyte", + "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073", + ), + ( + "t10k-labels-idx1-ubyte", + "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5", + ), + ] + isfile(file) && continue + @info "Downloading Fashion-MNIST dataset" + download_and_verify( + "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", + "$file.gz", + hash, + ) + open(file, "w") do io + write(io, gzopen(read, "$file.gz")) + end + end end - end end const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte") @@ -41,11 +59,11 @@ Return the 60,000 training images by default; pass `:test` to retrieve the 10,000 test images. """ function images(set = :train) - deprecation_message() - load() - io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES)) - _, N, nrows, ncols = imageheader(io) - [rawimage(io) for _ in 1:N] + deprecation_message() + load() + io = IOBuffer(read(set == :train ? 
TRAINIMAGES : TESTIMAGES)) + _, N, nrows, ncols = imageheader(io) + return [rawimage(io) for _ in 1:N] end """ @@ -57,11 +75,11 @@ Return the 60,000 training labels by default; pass `:test` to retrieve the 10,000 test labels. """ function labels(set = :train) - deprecation_message() - load() - io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) - _, N = labelheader(io) - [rawlabel(io) for _ = 1:N] + deprecation_message() + load() + io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) + _, N = labelheader(io) + return [rawlabel(io) for _ in 1:N] end end diff --git a/src/data/housing.jl b/src/data/housing.jl index 4202f4d822..c8abc245fa 100644 --- a/src/data/housing.jl +++ b/src/data/housing.jl @@ -40,12 +40,14 @@ using ..Data: deps, download_and_verify, deprecation_message const cache_prefix = "" function load() - isfile(deps("housing.data")) && return + isfile(deps("housing.data")) && return nothing @info "Downloading the Boston housing Dataset" - download_and_verify("$(cache_prefix)http://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data", - deps("housing.data"), - "baadf72995725d76efe787b664e1f083388c79ba21ef9a7990d87f774184735a") + download_and_verify( + "$(cache_prefix)http://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data", + deps("housing.data"), + "baadf72995725d76efe787b664e1f083388c79ba21ef9a7990d87f774184735a", + ) #@info "Download complete. Working on the files" path = deps() @@ -58,7 +60,7 @@ function load() end end end - mv(joinpath(path, "tempfile.data"), deps("housing.data"), force=true) + return mv(joinpath(path, "tempfile.data"), deps("housing.data"); force = true) end """ @@ -75,18 +77,30 @@ function targets() deprecation_message() load() housing = readdlm(deps("housing.data"), ',') - reshape(Vector{Float64}(housing[1:end,end]), (506, 1)) + return reshape(Vector{Float64}(housing[1:end, end]), (506, 1)) end - """ Gets the names of the features provided in the dataset """ function feature_names() - ["crim","zn","indus","chas","nox","rm","age","dis","rad","tax","ptratio","b","lstat"] + return [ + "crim", + "zn", + "indus", + "chas", + "nox", + "rm", + "age", + "dis", + "rad", + "tax", + "ptratio", + "b", + "lstat", + ] end - """ Gets the features of the Boston Housing Dataset. This is a 506x13 Matrix of Float64 datatypes. The values are in the order ["crim","zn","indus","chas","nox","rm","age","dis","rad","tax","ptratio","b","lstat"]. @@ -113,8 +127,7 @@ function features() deprecation_message() load() housing = readdlm(deps("housing.data"), ',') - Matrix{Float64}(housing[1:end, 1:13]) + return Matrix{Float64}(housing[1:end, 1:13]) end - end \ No newline at end of file diff --git a/src/data/iris.jl b/src/data/iris.jl index 4529aa8a40..6c0f611820 100644 --- a/src/data/iris.jl +++ b/src/data/iris.jl @@ -15,12 +15,14 @@ using ..Data: deps, download_and_verify, deprecation_message const cache_prefix = "https://cache.julialang.org/" function load() - isfile(deps("iris.data")) && return + isfile(deps("iris.data")) && return nothing @info "Downloading iris dataset." 
- download_and_verify("$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data", - deps("iris.data"), - "6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0") + return download_and_verify( + "$(cache_prefix)https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data", + deps("iris.data"), + "6f608b71a7317216319b4d27b4d9bc84e6abd734eda7872b71a458569e2656c0", + ) end """ @@ -39,7 +41,7 @@ function labels() deprecation_message() load() iris = readdlm(deps("iris.data"), ',') - Vector{String}(iris[1:end, end]) + return Vector{String}(iris[1:end, end]) end """ @@ -63,7 +65,7 @@ function features() deprecation_message() load() iris = readdlm(deps("iris.data"), ',') - Matrix{Float64}(iris[1:end, 1:4]') + return Matrix{Float64}(iris[1:end, 1:4]') end end diff --git a/src/data/mnist.jl b/src/data/mnist.jl index 45e51178ae..634ab37d6b 100644 --- a/src/data/mnist.jl +++ b/src/data/mnist.jl @@ -7,31 +7,49 @@ const Gray = Colors.Gray{Colors.N0f8} const dir = if isnothing(@__DIR__) joinpath("deps", "mnist") - else +else joinpath(@__DIR__, "../../deps/mnist") end function gzopen(f, file) - open(file) do io - f(GzipDecompressorStream(io)) - end + open(file) do io + f(GzipDecompressorStream(io)) + end end function load() - mkpath(dir) - cd(dir) do - for (file, hash) in [("train-images-idx3-ubyte", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"), - ("train-labels-idx1-ubyte", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"), - ("t10k-images-idx3-ubyte" , "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"), - ("t10k-labels-idx1-ubyte" , "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6")] - isfile(file) && continue - @info "Downloading MNIST dataset" - download_and_verify("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz", hash) - open(file, "w") do io - write(io, gzopen(read, "$file.gz")) - end + mkpath(dir) + cd(dir) do + for (file, hash) in [ + ( + "train-images-idx3-ubyte", + "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609", + ), + ( + "train-labels-idx1-ubyte", + "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c", + ), + ( + "t10k-images-idx3-ubyte", + "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6", + ), + ( + "t10k-labels-idx1-ubyte", + "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6", + ), + ] + isfile(file) && continue + @info "Downloading MNIST dataset" + download_and_verify( + "https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", + "$file.gz", + hash, + ) + open(file, "w") do io + write(io, gzopen(read, "$file.gz")) + end + end end - end end const IMAGEOFFSET = 16 @@ -46,37 +64,37 @@ const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte") const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte") function imageheader(io::IO) - magic_number = bswap(read(io, UInt32)) - total_items = bswap(read(io, UInt32)) - nrows = bswap(read(io, UInt32)) - ncols = bswap(read(io, UInt32)) - return magic_number, Int(total_items), Int(nrows), Int(ncols) + magic_number = bswap(read(io, UInt32)) + total_items = bswap(read(io, UInt32)) + nrows = bswap(read(io, UInt32)) + ncols = bswap(read(io, UInt32)) + return magic_number, Int(total_items), Int(nrows), Int(ncols) end function labelheader(io::IO) - magic_number = bswap(read(io, UInt32)) - total_items = bswap(read(io, UInt32)) - return magic_number, Int(total_items) + magic_number = bswap(read(io, UInt32)) + total_items 
= bswap(read(io, UInt32)) + return magic_number, Int(total_items) end function rawimage(io::IO) - img = Array{Gray}(undef, NCOLS, NROWS) - for i in 1:NCOLS, j in 1:NROWS - img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8)) - end - return img + img = Array{Gray}(undef, NCOLS, NROWS) + for i in 1:NCOLS, j in 1:NROWS + img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8)) + end + return img end function rawimage(io::IO, index::Integer) - seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1)) - return rawimage(io) + seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1)) + return rawimage(io) end rawlabel(io::IO) = Int(read(io, UInt8)) function rawlabel(io::IO, index::Integer) - seek(io, LABELOFFSET + (index - 1)) - return rawlabel(io) + seek(io, LABELOFFSET + (index - 1)) + return rawlabel(io) end getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) @@ -91,11 +109,11 @@ Return the 60,000 training images by default; pass `:test` to retrieve the 10,000 test images. """ function images(set = :train) - deprecation_message() - load() - io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES)) - _, N, nrows, ncols = imageheader(io) - [rawimage(io) for _ in 1:N] + deprecation_message() + load() + io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES)) + _, N, nrows, ncols = imageheader(io) + return [rawimage(io) for _ in 1:N] end """ @@ -107,11 +125,11 @@ Return the 60,000 training labels by default; pass `:test` to retrieve the 10,000 test labels. """ function labels(set = :train) - deprecation_message() - load() - io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) - _, N = labelheader(io) - [rawlabel(io) for _ = 1:N] + deprecation_message() + load() + io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) + _, N = labelheader(io) + return [rawlabel(io) for _ in 1:N] end end # module diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl index aae8f70930..c1a9f37feb 100644 --- a/src/data/sentiment.jl +++ b/src/data/sentiment.jl @@ -5,19 +5,22 @@ using ZipFile using ..Data: deps, download_and_verify, deprecation_message function load() - isfile(deps("sentiment.zip")) && return - @info "Downloading sentiment treebank dataset" - download_and_verify("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", - deps("sentiment.zip"), "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc") + isfile(deps("sentiment.zip")) && return nothing + @info "Downloading sentiment treebank dataset" + return download_and_verify( + "https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", + deps("sentiment.zip"), + "5c613a4f673fc74097d523a2c83f38e0cc462984d847b82c7aaf36b01cbbbfcc", + ) end getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)] function getfile(name) - r = ZipFile.Reader(deps("sentiment.zip")) - text = read(getfile(r, "trees/$name"), String) - close(r) - return text + r = ZipFile.Reader(deps("sentiment.zip")) + text = read(getfile(r, "trees/$name"), String) + close(r) + return text end using ..Data: Tree @@ -27,17 +30,17 @@ totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b)) totree(t::Expr) = totree_(t.args...) 
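# A minimal, illustrative sketch (the sentence is invented) of what `parsetree`,
# defined just below, does to one Stanford Sentiment Treebank line before `totree`
# above folds it into a `Tree{Any}` of `(label, word-or-nothing)` nodes:
s = "(3 (2 It) (3 works))"
s = replace(s, r"[^ \n\(\)]+" => tok -> "\"$tok\"")   # now: ("3" ("2" "It") ("3" "works"))
s = replace(s, " " => ", ")                           # now: ("3", ("2", "It"), ("3", "works"))
ex = Meta.parse(s)                                    # a nested tuple Expr; totree(ex) builds the Tree{Any}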
function parsetree(s) - s = replace(s, "\\" => "") - s = replace(s, "\$" => "\\\$") - s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"") - s = replace(s, " " => ", ") - return totree(Meta.parse(s)) + s = replace(s, "\\" => "") + s = replace(s, "\$" => "\\\$") + s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"") + s = replace(s, " " => ", ") + return totree(Meta.parse(s)) end function gettrees(name) - load() - ss = split(getfile("$name.txt"), '\n', keepempty = false) - return parsetree.(ss) + load() + ss = split(getfile("$name.txt"), '\n'; keepempty = false) + return parsetree.(ss) end """ @@ -46,8 +49,8 @@ Return the train split of the Stanford Sentiment Treebank. The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format. """ function train() - deprecation_message() - gettrees("train") + deprecation_message() + return gettrees("train") end """ @@ -56,8 +59,8 @@ Return the test split of the Stanford Sentiment Treebank. The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format. """ function test() - deprecation_message() - gettrees("test") + deprecation_message() + return gettrees("test") end """ @@ -66,8 +69,8 @@ Return the dev split of the Stanford Sentiment Treebank. The data is in [treebank](https://en.wikipedia.org/wiki/Treebank) format. """ function dev() - deprecation_message() - gettrees("dev") + deprecation_message() + return gettrees("dev") end end \ No newline at end of file diff --git a/src/data/tree.jl b/src/data/tree.jl index 38a683f9c5..10ede92541 100644 --- a/src/data/tree.jl +++ b/src/data/tree.jl @@ -1,31 +1,33 @@ using AbstractTrees struct Tree{T} - value::T - children::Vector{Tree{T}} + value::T + children::Vector{Tree{T}} end -Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) -Tree{T}(x) where T = Tree(convert(T, x)) +Tree{T}(x::T, xs::Tree{T}...) where {T} = Tree{T}(x, [xs...]) +Tree{T}(x) where {T} = Tree(convert(T, x)) -Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...) +Tree(x::T, xs::Tree{T}...) where {T} = Tree{T}(x, xs...) AbstractTrees.children(t::Tree) = t.children AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) Base.show(io::IO, t::Type{Tree}) = print(io, "Tree") -Base.show(io::IO, t::Type{Tree{T}}) where T = print(io, "Tree{", @isdefined(T) ? T : :T, "}") +function Base.show(io::IO, t::Type{Tree{T}}) where {T} + return print(io, "Tree{", @isdefined(T) ? T : :T, "}") +end function Base.show(io::IO, t::Tree) - println(io, typeof(t)) - print_tree(io, t) + println(io, typeof(t)) + return print_tree(io, t) end using Juno @render Juno.Inline t::Tree begin - render(t) = Juno.Tree(t.value, render.(t.children)) - Juno.Tree(typeof(t), [render(t)]) + render(t) = Juno.Tree(t.value, render.(t.children)) + Juno.Tree(typeof(t), [render(t)]) end Base.getindex(t::Tree, i::Integer) = t.children[i] @@ -37,6 +39,9 @@ isleaf(t) = isempty(children(t)) leaves(xs::Tree) = map(x -> x.value, Leaves(xs)) -Base.map(f, t::Tree, ts::Tree...) = - Tree{Any}(f(map(t -> t.value, (t, ts...))...), - [map(f, chs...) for chs in zip(map(t -> t.children, (t, ts...))...)]...) +function Base.map(f, t::Tree, ts::Tree...) + return Tree{Any}( + f(map(t -> t.value, (t, ts...))...), + [map(f, chs...) 
for chs in zip(map(t -> t.children, (t, ts...))...)]..., + ) +end diff --git a/src/deprecations.jl b/src/deprecations.jl index 45e32a8420..6b09435b21 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,35 +1,65 @@ # v0.12 deprecations -@deprecate Dropout(p, dims) Dropout(p; dims=dims) -@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, active=nothing) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, active, length(β)) -@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, active=nothing) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, active, length(β)) -@deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, active=nothing) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, true, true, active, length(β)) +@deprecate Dropout(p, dims) Dropout(p; dims = dims) +@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, active = nothing) InstanceNorm( + λ, β, γ, μ, σ², ϵ, momentum, true, true, active, length(β) +) +@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, active = nothing) BatchNorm( + λ, β, γ, μ, σ², ϵ, momentum, true, true, active, length(β) +) +@deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, active = nothing) GroupNorm( + G, λ, β, γ, μ, σ², ϵ, momentum, true, true, active, length(β) +) @deprecate outdims(f, inputsize) outputsize(f, inputsize) -@deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...) -@deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...) -@deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...) +@deprecate Conv(; weight, bias, activation = identity, kws...) Conv( + weight, bias, activation; kws... +) +@deprecate ConvTranspose(; weight, bias, activation = identity, kws...) ConvTranspose( + weight, bias, activation; kws... +) +@deprecate DepthwiseConv(; weight, bias, activation = identity, kws...) DepthwiseConv( + weight, bias, activation; kws... +) function Base.getproperty(a::Dense, s::Symbol) - if s === :W - Base.depwarn("field name dense.W is deprecated in favour of dense.weight", :Dense) - return getfield(a, :weight) - elseif s === :b - Base.depwarn("field name dense.b is deprecated in favour of dense.bias", :Dense) - return getfield(a, :bias) - end - return getfield(a, s) + if s === :W + Base.depwarn("field name dense.W is deprecated in favour of dense.weight", :Dense) + return getfield(a, :weight) + elseif s === :b + Base.depwarn("field name dense.b is deprecated in favour of dense.bias", :Dense) + return getfield(a, :bias) + end + return getfield(a, s) end function ones(dims...) - Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones) - Base.ones(Float32, dims...) + Base.depwarn( + "Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", + :ones, + ) + return Base.ones(Float32, dims...) end ones(T::Type, dims...) = Base.ones(T, dims...) function zeros(dims...) - Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros) - Base.zeros(Float32, dims...) + Base.depwarn( + "Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", + :zeros, + ) + return Base.zeros(Float32, dims...) end zeros(T::Type, dims...) = Base.zeros(T, dims...) -ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type")) -zeros32(::Type, dims...) 
= throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type")) +function ones32(::Type, dims...) + return throw( + ArgumentError( + "Flux.ones32 is always Float32, use Base.ones to specify the element type" + ), + ) +end +function zeros32(::Type, dims...) + return throw( + ArgumentError( + "Flux.zeros32 is always Float32, use Base.zeros to specify the element type" + ), + ) +end diff --git a/src/functor.jl b/src/functor.jl index ee9eb1d543..125107152b 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,5 +1,5 @@ import Adapt: adapt, adapt_storage -using LinearAlgebra: Cholesky +using LinearAlgebra: Cholesky using Zygote: IdSet import Functors: Functors, @functor, functor, fmap, isleaf using SparseArrays: AbstractSparseArray @@ -41,25 +41,24 @@ trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) function params!(p::Params, x, seen = IdSet()) - x in seen && return - push!(seen, x) - for child in trainable(x) - params!(p, child, seen) - end + x in seen && return nothing + push!(seen, x) + for child in trainable(x) + params!(p, child, seen) + end end function params(m...) - ps = Params() - params!(ps, m) - return ps + ps = Params() + params!(ps, m) + return ps end function loadparams!(m, xs) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end + for (p, x) in zip(params(m), xs) + size(p) == size(x) || error("Expected param size $(size(p)), got $(size(x))") + copyto!(p, x) + end end struct FluxCUDAAdaptor end @@ -75,16 +74,20 @@ struct FluxCPUAdaptor end adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x -adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) +function adapt_storage( + to::FluxCPUAdaptor, x::T +) where {T<:CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix} + return adapt(Array, x) +end adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x Zygote.@adjoint function Array(x::CUDA.CuArray) - Array(x), d -> (CUDA.cu(d),) + return Array(x), d -> (CUDA.cu(d),) end Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray) - adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),) + return adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d)) end # CPU/GPU movement conveniences @@ -115,7 +118,7 @@ Matrix{Float32} cpu(x) = fmap(x -> adapt(FluxCPUAdaptor(), x), x) _isbitsarray(::AbstractArray{<:Number}) = true -_isbitsarray(::AbstractArray{T}) where T = isbitstype(T) +_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) _isbitsarray(x) = false """ @@ -142,17 +145,21 @@ CuArray{Float32, 2} ``` """ function gpu(x) - check_use_cuda() - use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isbitsarray) : x + check_use_cuda() + return if use_cuda[] + fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isbitsarray) + else + x + end end function check_use_cuda() - if use_cuda[] === nothing - use_cuda[] = CUDA.functional() - if use_cuda[] && !CUDA.has_cudnn() - @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available." 
+ if use_cuda[] === nothing + use_cuda[] = CUDA.functional() + if use_cuda[] && !CUDA.has_cudnn() + @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available." + end end - end end Zygote.@nograd check_use_cuda diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1c87ceb204..dfda0e4221 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -29,17 +29,19 @@ true ``` """ struct Chain{T} - layers::T - Chain(xs...) = new{typeof(xs)}(xs) - function Chain(; kw...) - :layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`")) - isempty(kw) && return new{Tuple{}}(()) - new{typeof(values(kw))}(values(kw)) - end + layers::T + Chain(xs...) = new{typeof(xs)}(xs) + function Chain(; kw...) + :layers in Base.keys(kw) && + throw(ArgumentError("a Chain cannot have a named layer called `layers`")) + isempty(kw) && return new{Tuple{}}(()) + return new{typeof(values(kw))}(values(kw)) + end end -@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, - Base.iterate, Base.lastindex, Base.keys +@forward Chain.layers Base.getindex, +Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, +Base.keys functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...) @@ -49,16 +51,19 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) (c::Chain)(x) = applychain(Tuple(c.layers), x) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) -Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = - Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...) +function Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) + return Chain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...) +end function Base.show(io::IO, c::Chain) - print(io, "Chain(") - _show_layers(io, c.layers) - print(io, ")") + print(io, "Chain(") + _show_layers(io, c.layers) + return print(io, ")") end _show_layers(io, layers::Tuple) = join(io, layers, ", ") -_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") +function _show_layers(io, layers::NamedTuple) + return join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ") +end # This is a temporary and naive implementation # it might be replaced in the future for better performance @@ -73,12 +78,11 @@ Calculate the forward results of each layers in Chain `c` with `input` as model activations(c::Chain, input) = extraChain(Tuple(c.layers), input) function extraChain(fs::Tuple, x) - res = first(fs)(x) - return (res, extraChain(Base.tail(fs), res)...) + res = first(fs)(x) + return (res, extraChain(Base.tail(fs), res)...) 
end extraChain(::Tuple{}, x) = () - """ Dense(in, out, σ=identity; bias=true, init=glorot_uniform) Dense(W::AbstractMatrix, [bias, σ]) @@ -120,52 +124,62 @@ julia> Flux.params(d1) # no trainable bias Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]]) ``` """ -struct Dense{F, M<:AbstractMatrix, B} - weight::M - bias::B - σ::F - function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F} - b = create_bias(W, bias, size(W,1)) - new{F,M,typeof(b)}(W, b, σ) - end +struct Dense{F,M<:AbstractMatrix,B} + weight::M + bias::B + σ::F + function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix,F} + b = create_bias(W, bias, size(W, 1)) + return new{F,M,typeof(b)}(W, b, σ) + end end -function Dense(in::Integer, out::Integer, σ = identity; - initW = nothing, initb = nothing, - init = glorot_uniform, bias=true) - - W = if initW !== nothing - Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", :Dense) - initW(out, in) - else - init(out, in) - end - - b = if bias === true && initb !== nothing - Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense) - initb(out) - else - bias - end - - return Dense(W, b, σ) +function Dense( + in::Integer, + out::Integer, + σ = identity; + initW = nothing, + initb = nothing, + init = glorot_uniform, + bias = true, +) + W = if initW !== nothing + Base.depwarn( + "keyword initW is deprecated, please use init (which similarly accepts a funtion like randn)", + :Dense, + ) + initW(out, in) + else + init(out, in) + end + + b = if bias === true && initb !== nothing + Base.depwarn( + "keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", + :Dense, + ) + initb(out) + else + bias + end + + return Dense(W, b, σ) end @functor Dense function (a::Dense)(x::AbstractVecOrMat) - W, b, σ = a.weight, a.bias, a.σ - return σ.(W*x .+ b) + W, b, σ = a.weight, a.bias, a.σ + return σ.(W * x .+ b) end -(a::Dense)(x::AbstractArray) = - reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...) +(a::Dense)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...) function Base.show(io::IO, l::Dense) - print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1)) - l.σ == identity || print(io, ", ", l.σ) - l.bias == Zeros() && print(io, "; bias=false") - print(io, ")") + print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1)) + l.σ == identity || print(io, ", ", l.σ) + l.bias == Zeros() && print(io, "; bias=false") + return print(io, ")") end """ @@ -182,24 +196,30 @@ The learnable arrays are initialised `α = ones(Float32, size)` and Used by [`LayerNorm`](@ref). """ struct Diagonal{T} - α::T - β::T + α::T + β::T end function Diagonal(sz::Integer...; initα = nothing, initβ = nothing) - α = if initα !== nothing - Base.depwarn("keyword initα is deprecated, please simply supply the desired vectors", :Diagonal) - initα(sz...) - else - ones32(sz...) - end - β = if initβ !== nothing - Base.depwarn("keyword initβ is deprecated, please simply supply the desired vectors", :Diagonal) - initβ(sz...) - else - zeros32(sz...) - end - Diagonal(α, β) + α = if initα !== nothing + Base.depwarn( + "keyword initα is deprecated, please simply supply the desired vectors", + :Diagonal, + ) + initα(sz...) + else + ones32(sz...) + end + β = if initβ !== nothing + Base.depwarn( + "keyword initβ is deprecated, please simply supply the desired vectors", + :Diagonal, + ) + initβ(sz...) + else + zeros32(sz...) 
+ end + return Diagonal(α, β) end @functor Diagonal @@ -207,7 +227,7 @@ end (a::Diagonal)(x) = a.α .* x .+ a.β function Base.show(io::IO, l::Diagonal) - print(io, "Diagonal(", join(size(l.α), ", "), ")") + return print(io, "Diagonal(", join(size(l.α), ", "), ")") end """ @@ -243,14 +263,14 @@ julia> Maxout(()->Dense(insize, outsize), 4); ``` """ function Maxout(f, n_alts) - over = Tuple(f() for _ in 1:n_alts) - return Maxout(over) + over = Tuple(f() for _ in 1:n_alts) + return Maxout(over) end @functor Maxout function (mo::Maxout)(input::AbstractArray) - mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) + return mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end """ @@ -279,18 +299,18 @@ true ``` """ struct SkipConnection{T,F} - layers::T - connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b + layers::T + connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b end @functor SkipConnection function (skip::SkipConnection)(input) - skip.connection(skip.layers(input), input) + return skip.connection(skip.layers(input), input) end function Base.show(io::IO, b::SkipConnection) - print(io, "SkipConnection(", b.layers, ", ", b.connection, ")") + return print(io, "SkipConnection(", b.layers, ", ", b.connection, ")") end """ @@ -340,50 +360,65 @@ Bilinear(8, 16, 4, tanh, bias=false) ``` """ struct Bilinear{F,A,B} - weight::A - bias::B - σ::F - function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F} - ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights")) - b = create_bias(W, bias, size(W,1)) - new{F,A,typeof(b)}(W, b, σ) - end + weight::A + bias::B + σ::F + function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray,F} + ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights")) + b = create_bias(W, bias, size(W, 1)) + return new{F,A,typeof(b)}(W, b, σ) + end end @functor Bilinear -function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; - init = glorot_uniform, bias = true) - Bilinear(init(out, in1, in2), bias, σ) +function Bilinear( + in1::Integer, + in2::Integer, + out::Integer, + σ = identity; + init = glorot_uniform, + bias = true, +) + return Bilinear(init(out, in1, in2), bias, σ) end function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix) - W, b, σ = a.weight, a.bias, a.σ - - d_z, d_x, d_y = size(W) - d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W")) - size(x,2) == size(y,2) || throw(DimensionMismatch("Data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))")) - - # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s] - Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :)) - - # @einsum Z[o,s] := Wy[o,i,s] * x[i,s] - Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :))) - Z = reshape(Wyx, (d_z, :)) - - # @einsum out[o,s] := σ(Z[o,i] + b[o]) - σ.(Z .+ b) + W, b, σ = a.weight, a.bias, a.σ + + d_z, d_x, d_y = size(W) + d_x == size(x, 1) && d_y == size(y, 1) || + throw(DimensionMismatch("number of rows in data must match W")) + size(x, 2) == size(y, 2) || throw( + DimensionMismatch( + "Data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))", + ), + ) + + # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s] + Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :)) + + # @einsum Z[o,s] := Wy[o,i,s] * x[i,s] + Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :))) + Z = reshape(Wyx, (d_z, :)) + + # @einsum out[o,s] := σ(Z[o,i] + b[o]) + return σ.(Z .+ b) end 
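# An illustrative shape sketch for the Bilinear forward pass above; the layer sizes
# (in1 = 3, in2 = 4, out = 2) and the batch of 5 are invented for the example.
using Flux
bl = Flux.Bilinear(3, 4, 2)              # weight stored as (out, in1, in2)
x = rand(Float32, 3, 5)                  # (in1, batch)
y = rand(Float32, 4, 5)                  # (in2, batch)
@assert size(bl.weight) == (2, 3, 4)     # the (d_z, d_x, d_y) unpacked above
@assert size(bl(x, y)) == (2, 5)         # one row per output feature, one column per batch sample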
(a::Bilinear)(x::AbstractVecOrMat) = a(x, x) -(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :,1), reshape(y, :,1))) -(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2]) +function (a::Bilinear)(x::AbstractVector, y::AbstractVector) + return vec(a(reshape(x, :, 1), reshape(y, :, 1))) +end +(a::Bilinear)(x::NTuple{2,AbstractArray}) = a(x[1], x[2]) function Base.show(io::IO, l::Bilinear) - print(io, "Bilinear(", size(l.weight, 2), ", ", size(l.weight, 3), ", ", size(l.weight, 1)) - l.σ == identity || print(io, ", ", l.σ) - l.bias == Flux.Zeros() && print(io, ", bias=false") - print(io, ")") + print( + io, "Bilinear(", size(l.weight, 2), ", ", size(l.weight, 3), ", ", size(l.weight, 1) + ) + l.σ == identity || print(io, ", ", l.σ) + l.bias == Flux.Zeros() && print(io, ", bias=false") + return print(io, ")") end """ @@ -426,20 +461,24 @@ julia> model2[:β] == model2[2] true ``` """ -struct Parallel{F, T} - connection::F - layers::T +struct Parallel{F,T} + connection::F + layers::T end Parallel(connection, layers...) = Parallel(connection, layers) function Parallel(connection; kw...) - layers = NamedTuple(kw) - if :layers in Base.keys(layers) || :connection in Base.keys(layers) - throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`")) - elseif isempty(layers) - Parallel(connection, ()) - end - Parallel(connection, layers) + layers = NamedTuple(kw) + if :layers in Base.keys(layers) || :connection in Base.keys(layers) + throw( + ArgumentError( + "a Parallel layer cannot have a named sub-layer called `connection` or `layers`", + ), + ) + elseif isempty(layers) + Parallel(connection, ()) + end + return Parallel(connection, layers) end @functor Parallel @@ -456,9 +495,9 @@ Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) trainable(m::Parallel) = (m.connection, m.layers...) function Base.show(io::IO, m::Parallel) - print(io, "Parallel(", m.connection, ", ") - _show_layers(io, m.layers) - print(io, ")") + print(io, "Parallel(", m.connection, ", ") + _show_layers(io, m.layers) + return print(io, ")") end """ @@ -497,23 +536,26 @@ julia> model(vocab_idxs) == model(x) true """ struct Embedding{W} - weight::W + weight::W end @functor Embedding Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in)) - (m::Embedding)(x::Integer) = m.weight[:, x] (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x) (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...) -function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L} - size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) - return m(onecold(x)) +function (m::Embedding)(x::Union{OneHotVector{T,L},OneHotMatrix{T,L}}) where {T,L} + size(m.weight, 2) == L || throw( + DimensionMismatch( + "Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L" + ), + ) + return m(onecold(x)) end - + function Base.show(io::IO, m::Embedding) - print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))") + return print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))") end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0b6372364d..6be6007947 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -17,16 +17,18 @@ See also [`Conv`](@ref), [`MaxPool`](@ref). 
""" struct SamePad end -calc_padding(lt, pad, k::NTuple{N,T}, dilation, stride) where {T,N}= expand(Val(2*N), pad) +function calc_padding(lt, pad, k::NTuple{N,T}, dilation, stride) where {T,N} + return expand(Val(2 * N), pad) +end function calc_padding(lt, ::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T} - #Ref: "A guide to convolution arithmetic for deep learning" https://arxiv.org/abs/1603.07285 + #Ref: "A guide to convolution arithmetic for deep learning" https://arxiv.org/abs/1603.07285 - # Effective kernel size, including dilation - k_eff = @. k + (k - 1) * (dilation - 1) - # How much total padding needs to be applied? - pad_amt = @. k_eff - 1 - # In case amount of padding is odd we need to apply different amounts to each side. - return Tuple(mapfoldl(i -> [cld(i, 2), fld(i,2)], vcat, pad_amt)) + # Effective kernel size, including dilation + k_eff = @. k + (k - 1) * (dilation - 1) + # How much total padding needs to be applied? + pad_amt = @. k_eff - 1 + # In case amount of padding is odd we need to apply different amounts to each side. + return Tuple(mapfoldl(i -> [cld(i, 2), fld(i, 2)], vcat, pad_amt)) end """ @@ -93,13 +95,13 @@ julia> Conv((5,5), 3 => 7; stride = 2, dilation = 4)(xs) |> size ``` """ struct Conv{N,M,F,A,V} - σ::F - weight::A - bias::V - stride::NTuple{N,Int} - pad::NTuple{M,Int} - dilation::NTuple{N,Int} - groups::Int + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} + groups::Int end """ @@ -125,20 +127,35 @@ julia> params(c1) |> length 2 ``` """ -function Conv(w::AbstractArray{T,N}, b = true, σ = identity; - stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N} - stride = expand(Val(N-2), stride) - dilation = expand(Val(N-2), dilation) - pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride) - bias = create_bias(w, b, size(w, N)) - return Conv(σ, w, bias, stride, pad, dilation, groups) -end - -function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groups = 1, - weight = convfilter(k, ch; init, groups), bias = true) where N - - Conv(weight, bias, σ; stride, pad, dilation, groups) +function Conv( + w::AbstractArray{T,N}, + b = true, + σ = identity; + stride = 1, + pad = 0, + dilation = 1, + groups = 1, +) where {T,N} + stride = expand(Val(N - 2), stride) + dilation = expand(Val(N - 2), dilation) + pad = calc_padding(Conv, pad, size(w)[1:(N - 2)], dilation, stride) + bias = create_bias(w, b, size(w, N)) + return Conv(σ, w, bias, stride, pad, dilation, groups) +end + +function Conv( + k::NTuple{N,Integer}, + ch::Pair{<:Integer,<:Integer}, + σ = identity; + init = glorot_uniform, + stride = 1, + pad = 0, + dilation = 1, + groups = 1, + weight = convfilter(k, ch; init, groups), + bias = true, +) where {N} + return Conv(weight, bias, σ; stride, pad, dilation, groups) end """ @@ -152,39 +169,50 @@ distribution. 
See also: [`depthwiseconvfilter`](@ref) """ -function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; - init = glorot_uniform, groups = 1) where N - cin, cout = ch - init(filter..., cin÷groups, cout) +function convfilter( + filter::NTuple{N,Integer}, + ch::Pair{<:Integer,<:Integer}; + init = glorot_uniform, + groups = 1, +) where {N} + cin, cout = ch + return init(filter..., cin ÷ groups, cout) end @functor Conv function (c::Conv)(x::AbstractArray) - σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1) - cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) - σ.(conv(x, c.weight, cdims) .+ b) -end - -_channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups + σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1) + cdims = DenseConvDims( + x, + c.weight; + stride = c.stride, + padding = c.pad, + dilation = c.dilation, + groups = c.groups, + ) + return σ.(conv(x, c.weight, cdims) .+ b) +end + +_channels_in(l::Conv) = size(l.weight, ndims(l.weight) - 1) * l.groups _channels_out(l::Conv) = size(l.weight, ndims(l.weight)) function Base.show(io::IO, l::Conv) - print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", _channels_in(l), " => ", _channels_out(l)) - _print_conv_opt(io, l) - print(io, ")") + print(io, "Conv(", size(l.weight)[1:(ndims(l.weight) - 2)]) + print(io, ", ", _channels_in(l), " => ", _channels_out(l)) + _print_conv_opt(io, l) + return print(io, ")") end function _print_conv_opt(io::IO, l) - l.σ == identity || print(io, ", ", l.σ) - all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad)) - all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride)) - all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation)) - if hasproperty(l, :groups) - (l.groups == 1) || print(io, ", groups=", l.groups) - end - (l.bias isa Zeros) && print(io, ", bias=false") + l.σ == identity || print(io, ", ", l.σ) + all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad)) + all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride)) + all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation)) + if hasproperty(l, :groups) + (l.groups == 1) || print(io, ", groups=", l.groups) + end + return (l.bias isa Zeros) && print(io, ", bias=false") end """ @@ -219,17 +247,17 @@ julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size ``` """ struct ConvTranspose{N,M,F,A,V} - σ::F - weight::A - bias::V - stride::NTuple{N,Int} - pad::NTuple{M,Int} - dilation::NTuple{N,Int} - groups::Int + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} + groups::Int end -_channels_in(l::ConvTranspose) = size(l.weight)[end] -_channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups +_channels_in(l::ConvTranspose) = size(l.weight)[end] +_channels_out(l::ConvTranspose) = size(l.weight)[end - 1] * l.groups """ ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation, groups]) @@ -237,62 +265,79 @@ _channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups Constructs a layer with the given weight and bias arrays. Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method. 
""" -function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity; - stride = 1, pad = 0, dilation = 1, groups=1) where {T,N} - stride = expand(Val(N-2), stride) - dilation = expand(Val(N-2), dilation) - pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride) - b = create_bias(w, bias, size(w, N-1) * groups) - return ConvTranspose(σ, w, b, stride, pad, dilation, groups) -end - -function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = glorot_uniform, stride = 1, pad = 0, dilation = 1, - groups = 1, - weight = convfilter(k, reverse(ch); init, groups), - bias = true, - ) where N - - ConvTranspose(weight, bias, σ; stride, pad, dilation, groups) +function ConvTranspose( + w::AbstractArray{T,N}, + bias = true, + σ = identity; + stride = 1, + pad = 0, + dilation = 1, + groups = 1, +) where {T,N} + stride = expand(Val(N - 2), stride) + dilation = expand(Val(N - 2), dilation) + pad = calc_padding(ConvTranspose, pad, size(w)[1:(N - 2)], dilation, stride) + b = create_bias(w, bias, size(w, N - 1) * groups) + return ConvTranspose(σ, w, b, stride, pad, dilation, groups) +end + +function ConvTranspose( + k::NTuple{N,Integer}, + ch::Pair{<:Integer,<:Integer}, + σ = identity; + init = glorot_uniform, + stride = 1, + pad = 0, + dilation = 1, + groups = 1, + weight = convfilter(k, reverse(ch); init, groups), + bias = true, +) where {N} + return ConvTranspose(weight, bias, σ; stride, pad, dilation, groups) end @functor ConvTranspose function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) - # Calculate size of "input", from ∇conv_data()'s perspective... - combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end]) - I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad - C_in = size(c.weight)[end-1] * c.groups - batch_size = size(x)[end] - # Create DenseConvDims() that looks like the corresponding conv() - w_size = size(c.weight) - return DenseConvDims((I..., C_in, batch_size), w_size; - stride=c.stride, - padding=c.pad, - dilation=c.dilation, - groups=c.groups, - ) + # Calculate size of "input", from ∇conv_data()'s perspective... 
+ combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end]) + I = + (size(x)[1:(end - 2)] .- 1) .* c.stride .+ 1 .+ + (size(c.weight)[1:(end - 2)] .- 1) .* c.dilation .- combined_pad + C_in = size(c.weight)[end - 1] * c.groups + batch_size = size(x)[end] + # Create DenseConvDims() that looks like the corresponding conv() + w_size = size(c.weight) + return DenseConvDims( + (I..., C_in, batch_size), + w_size; + stride = c.stride, + padding = c.pad, + dilation = c.dilation, + groups = c.groups, + ) end # TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900 @nograd conv_transpose_dims function (c::ConvTranspose)(x::AbstractArray) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - cdims = conv_transpose_dims(c, x) - σ.(∇conv_data(x, c.weight, cdims) .+ b) + σ, b = c.σ, reshape(c.bias, map(_ -> 1, c.stride)..., :, 1) + cdims = conv_transpose_dims(c, x) + return σ.(∇conv_data(x, c.weight, cdims) .+ b) end function Base.show(io::IO, l::ConvTranspose) - print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", _channels_in(l), " => ", _channels_out(l)) - _print_conv_opt(io, l) - print(io, ")") + print(io, "ConvTranspose(", size(l.weight)[1:(ndims(l.weight) - 2)]) + print(io, ", ", _channels_in(l), " => ", _channels_out(l)) + _print_conv_opt(io, l) + return print(io, ")") end - -function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T} - calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride) +function calc_padding( + ::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride +) where {N,T} + return calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride) end """ @@ -324,12 +369,12 @@ julia> DepthwiseConv((5,5), 3 => 9, stride=2, pad=2)(xs) |> size ``` """ struct DepthwiseConv{N,M,F,A,V} - σ::F - weight::A - bias::V - stride::NTuple{N,Int} - pad::NTuple{M,Int} - dilation::NTuple{N,Int} + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} end """ @@ -338,20 +383,29 @@ end Constructs a layer with the given weight and bias arrays. Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method. 
""" -function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity; - stride = 1, pad = 0, dilation = 1) where {T,N} - stride = expand(Val(N-2), stride) - dilation = expand(Val(N-2), dilation) - pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride) - b = create_bias(w, bias, prod(size(w)[N-1:end])) - return DepthwiseConv(σ, w, b, stride, pad, dilation) -end - -function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = glorot_uniform, stride = 1, pad = 0, dilation = 1, - weight = depthwiseconvfilter(k, ch, init = init), bias = true) where N - @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels" - return DepthwiseConv(weight, bias, σ; stride, pad, dilation) +function DepthwiseConv( + w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1 +) where {T,N} + stride = expand(Val(N - 2), stride) + dilation = expand(Val(N - 2), dilation) + pad = calc_padding(DepthwiseConv, pad, size(w)[1:(N - 2)], dilation, stride) + b = create_bias(w, bias, prod(size(w)[(N - 1):end])) + return DepthwiseConv(σ, w, b, stride, pad, dilation) +end + +function DepthwiseConv( + k::NTuple{N,Integer}, + ch::Pair{<:Integer,<:Integer}, + σ = identity; + init = glorot_uniform, + stride = 1, + pad = 0, + dilation = 1, + weight = depthwiseconvfilter(k, ch; init = init), + bias = true, +) where {N} + @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels" + return DepthwiseConv(weight, bias, σ; stride, pad, dilation) end @functor DepthwiseConv @@ -367,23 +421,27 @@ distribution. See also: [`convfilter`](@ref) """ -depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; - init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1]) +function depthwiseconvfilter( + filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init = glorot_uniform +) where {N} + return init(filter..., div(ch[2], ch[1]), ch[1]) +end function (c::DepthwiseConv)(x) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) - σ.(depthwiseconv(x, c.weight, cdims) .+ b) + σ, b = c.σ, reshape(c.bias, map(_ -> 1, c.stride)..., :, 1) + cdims = DepthwiseConvDims( + x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation + ) + return σ.(depthwiseconv(x, c.weight, cdims) .+ b) end function Base.show(io::IO, l::DepthwiseConv) - print(io, "DepthwiseConv(", size(l.weight)[1:end-2]) - print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end])) - _print_conv_opt(io, l) - print(io, ")") + print(io, "DepthwiseConv(", size(l.weight)[1:(end - 2)]) + print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[(end - 1):end])) + _print_conv_opt(io, l) + return print(io, ")") end - """ CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) @@ -412,12 +470,12 @@ julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size ``` """ struct CrossCor{N,M,F,A,V} - σ::F - weight::A - bias::V - stride::NTuple{N,Int} - pad::NTuple{M,Int} - dilation::NTuple{N,Int} + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} end """ @@ -426,40 +484,56 @@ end Constructs a layer with the given weight and bias arrays. Accepts the same keywords as the `CrossCor((4,4), 3 => 7, relu)` method. 
""" -function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity; - stride = 1, pad = 0, dilation = 1) where {T,N} - stride = expand(Val(N-2), stride) - dilation = expand(Val(N-2), dilation) - pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride) - b = create_bias(w, bias, size(w, N)) - return CrossCor(σ, w, b, stride, pad, dilation) -end - -function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = glorot_uniform, stride = 1, pad = 0, dilation = 1, - weight = convfilter(k, ch, init = init), bias = true) where N - - return CrossCor(weight, bias, σ; stride, pad, dilation) +function CrossCor( + w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1 +) where {T,N} + stride = expand(Val(N - 2), stride) + dilation = expand(Val(N - 2), dilation) + pad = calc_padding(CrossCor, pad, size(w)[1:(N - 2)], dilation, stride) + b = create_bias(w, bias, size(w, N)) + return CrossCor(σ, w, b, stride, pad, dilation) +end + +function CrossCor( + k::NTuple{N,Integer}, + ch::Pair{<:Integer,<:Integer}, + σ = identity; + init = glorot_uniform, + stride = 1, + pad = 0, + dilation = 1, + weight = convfilter(k, ch; init = init), + bias = true, +) where {N} + return CrossCor(weight, bias, σ; stride, pad, dilation) end @functor CrossCor function crosscor(x, w, ddims::DenseConvDims) - ddims = DenseConvDims(ddims, F=true) - return conv(x, w, ddims) + ddims = DenseConvDims(ddims; F = true) + return conv(x, w, ddims) end function (c::CrossCor)(x::AbstractArray) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) - σ.(crosscor(x, c.weight, cdims) .+ b) + σ, b = c.σ, reshape(c.bias, map(_ -> 1, c.stride)..., :, 1) + cdims = DenseConvDims( + x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation + ) + return σ.(crosscor(x, c.weight, cdims) .+ b) end function Base.show(io::IO, l::CrossCor) - print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2]) - print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight))) - _print_conv_opt(io, l) - print(io, ")") + print(io, "CrossCor(", size(l.weight)[1:(ndims(l.weight) - 2)]) + print( + io, + ", ", + size(l.weight, ndims(l.weight) - 1), + " => ", + size(l.weight, ndims(l.weight)), + ) + _print_conv_opt(io, l) + return print(io, ")") end """ @@ -484,23 +558,23 @@ julia> MaxPool((4,4))(xs) ≈ AdaptiveMaxPool((25, 25))(xs) true ``` """ -struct AdaptiveMaxPool{S, O} - out::NTuple{O, Int} - AdaptiveMaxPool(out::NTuple{O, Int}) where O = new{O + 2, O}(out) +struct AdaptiveMaxPool{S,O} + out::NTuple{O,Int} + AdaptiveMaxPool(out::NTuple{O,Int}) where {O} = new{O + 2,O}(out) end -function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T} - insize = size(x)[1:end-2] - outsize = a.out - stride = insize .÷ outsize - k = insize .- (outsize .- 1) .* stride - pad = 0 - pdims = PoolDims(x, k; padding=pad, stride=stride) - return maxpool(x, pdims) +function (a::AdaptiveMaxPool{S})(x::AbstractArray{T,S}) where {S,T} + insize = size(x)[1:(end - 2)] + outsize = a.out + stride = insize .÷ outsize + k = insize .- (outsize .- 1) .* stride + pad = 0 + pdims = PoolDims(x, k; padding = pad, stride = stride) + return maxpool(x, pdims) end function Base.show(io::IO, a::AdaptiveMaxPool) - print(io, "AdaptiveMaxPool(", a.out, ")") + return print(io, "AdaptiveMaxPool(", a.out, ")") end """ @@ -525,23 +599,23 @@ julia> MeanPool((4,4))(xs) ≈ AdaptiveMeanPool((25, 25))(xs) true 
``` """ -struct AdaptiveMeanPool{S, O} - out::NTuple{O, Int} - AdaptiveMeanPool(out::NTuple{O, Int}) where O = new{O + 2, O}(out) +struct AdaptiveMeanPool{S,O} + out::NTuple{O,Int} + AdaptiveMeanPool(out::NTuple{O,Int}) where {O} = new{O + 2,O}(out) end -function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T} - insize = size(x)[1:end-2] - outsize = a.out - stride = insize .÷ outsize - k = insize .- (outsize .- 1) .* stride - pad = 0 - pdims = PoolDims(x, k; padding=pad, stride=stride) - return meanpool(x, pdims) +function (a::AdaptiveMeanPool{S})(x::AbstractArray{T,S}) where {S,T} + insize = size(x)[1:(end - 2)] + outsize = a.out + stride = insize .÷ outsize + k = insize .- (outsize .- 1) .* stride + pad = 0 + pdims = PoolDims(x, k; padding = pad, stride = stride) + return meanpool(x, pdims) end function Base.show(io::IO, a::AdaptiveMeanPool) - print(io, "AdaptiveMeanPool(", a.out, ")") + return print(io, "AdaptiveMeanPool(", a.out, ")") end """ @@ -569,18 +643,18 @@ julia> GlobalMaxPool()(rand(3,5,7)) |> size # preserves 2 dimensions struct GlobalMaxPool end function (g::GlobalMaxPool)(x) - # Input size - x_size = size(x) - # Kernel size - k = x_size[1:end-2] - # Pooling dimensions - pdims = PoolDims(x, k) + # Input size + x_size = size(x) + # Kernel size + k = x_size[1:(end - 2)] + # Pooling dimensions + pdims = PoolDims(x, k) - return maxpool(x, pdims) + return maxpool(x, pdims) end function Base.show(io::IO, g::GlobalMaxPool) - print(io, "GlobalMaxPool()") + return print(io, "GlobalMaxPool()") end """ @@ -603,18 +677,18 @@ julia> m(xs) |> size struct GlobalMeanPool end function (g::GlobalMeanPool)(x) - # Input size - x_size = size(x) - # Kernel size - k = x_size[1:end-2] - # Pooling dimensions - pdims = PoolDims(x, k) + # Input size + x_size = size(x) + # Kernel size + k = x_size[1:(end - 2)] + # Pooling dimensions + pdims = PoolDims(x, k) - return meanpool(x, pdims) + return meanpool(x, pdims) end function Base.show(io::IO, g::GlobalMeanPool) - print(io, "GlobalMeanPool()") + return print(io, "GlobalMeanPool()") end """ @@ -657,31 +731,31 @@ julia> lay(rand(Float32, 100, 7, 50)) |> size ``` """ struct MaxPool{N,M} - k::NTuple{N,Int} - pad::NTuple{M,Int} - stride::NTuple{N,Int} + k::NTuple{N,Int} + pad::NTuple{M,Int} + stride::NTuple{N,Int} end -function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N - stride = expand(Val(N), stride) - pad = calc_padding(MaxPool ,pad, k, 1, stride) - return MaxPool(k, pad, stride) +function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where {N} + stride = expand(Val(N), stride) + pad = calc_padding(MaxPool, pad, k, 1, stride) + return MaxPool(k, pad, stride) end function (m::MaxPool)(x) - pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) - return maxpool(x, pdims) + pdims = PoolDims(x, m.k; padding = m.pad, stride = m.stride) + return maxpool(x, pdims) end function Base.show(io::IO, m::MaxPool) - print(io, "MaxPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", _maybetuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) - print(io, ")") + print(io, "MaxPool(", m.k) + all(==(0), m.pad) || print(io, ", pad=", _maybetuple_string(m.pad)) + m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) + return print(io, ")") end _maybetuple_string(pad) = string(pad) -_maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad) +_maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? 
string(pad[1]) : string(pad) """ MeanPool(window::NTuple; pad=0, stride=window) @@ -716,25 +790,25 @@ julia> m(xs) |> size ``` """ struct MeanPool{N,M} - k::NTuple{N,Int} - pad::NTuple{M,Int} - stride::NTuple{N,Int} + k::NTuple{N,Int} + pad::NTuple{M,Int} + stride::NTuple{N,Int} end -function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N - stride = expand(Val(N), stride) - pad = calc_padding(MeanPool, pad, k, 1, stride) - return MeanPool(k, pad, stride) +function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where {N} + stride = expand(Val(N), stride) + pad = calc_padding(MeanPool, pad, k, 1, stride) + return MeanPool(k, pad, stride) end function (m::MeanPool)(x) - pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) - return meanpool(x, pdims) + pdims = PoolDims(x, m.k; padding = m.pad, stride = m.stride) + return meanpool(x, pdims) end function Base.show(io::IO, m::MeanPool) - print(io, "MeanPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", _maybetuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) - print(io, ")") + print(io, "MeanPool(", m.k) + all(==(0), m.pad) || print(io, ", pad=", _maybetuple_string(m.pad)) + m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) + return print(io, ")") end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 146b7dba56..9507413ad2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -5,7 +5,7 @@ istraining() = false _isactive(m) = isnothing(m.active) ? istraining() : m.active _dropout_shape(s, ::Colon) = size(s) -_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) +_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...) _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) @@ -28,22 +28,22 @@ automatically managed using the [`Dropout`](@ref) layer instead of the The [`Dropout`](@ref) layer is what you should use in most scenarios. """ -function dropout(x, p; dims=:, active::Bool=true) - active || return x - y = dropout_mask(x, p, dims=dims) - return x .* y +function dropout(x, p; dims = :, active::Bool = true) + active || return x + y = dropout_mask(x, p; dims = dims) + return x .* y end -@adjoint function dropout(x, p; dims=:, active::Bool=true) - active || return x, Δ -> (Δ, nothing) - y = dropout_mask(x, p, dims=dims) - return x .* y, Δ -> (Δ .* y, nothing) +@adjoint function dropout(x, p; dims = :, active::Bool = true) + active || return x, Δ -> (Δ, nothing) + y = dropout_mask(x, p; dims = dims) + return x .* y, Δ -> (Δ .* y, nothing) end -function dropout_mask(x, p; dims=:) - y = rand!(similar(x, _dropout_shape(x, dims))) - y .= _dropout_kernel.(y, p, 1 - p) - return y +function dropout_mask(x, p; dims = :) + y = rand!(similar(x, _dropout_shape(x, dims))) + y .= _dropout_kernel.(y, p, 1 - p) + return y end """ @@ -58,28 +58,29 @@ e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input Does nothing to the input once [`Flux.testmode!`](@ref) is `true`. 
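A small behavioural sketch (assuming `Flux` is loaded; outside a gradient call the layer is inactive by default, so `trainmode!` is used to switch it on explicitly):

```julia
using Flux

d = Dropout(0.5)
x = ones(Float32, 4, 4)

Flux.trainmode!(d)
y = d(x)     # roughly half the entries are zero, the rest are 2.0f0 == 1 / (1 - p)

Flux.testmode!(d)
d(x) == x    # true: an inactive Dropout passes the input through unchanged
```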
""" mutable struct Dropout{F,D} - p::F - dims::D - active::Union{Bool, Nothing} + p::F + dims::D + active::Union{Bool,Nothing} end -function Dropout(p; dims=:) - @assert 0 ≤ p ≤ 1 - Dropout(p, dims, nothing) +function Dropout(p; dims = :) + @assert 0 ≤ p ≤ 1 + return Dropout(p, dims, nothing) end function (a::Dropout)(x) - _isactive(a) || return x - return dropout(x, a.p; dims=a.dims, active=true) + _isactive(a) || return x + return dropout(x, a.p; dims = a.dims, active = true) end -testmode!(m::Dropout, mode=true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +function testmode!(m::Dropout, mode = true) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +end function Base.show(io::IO, d::Dropout) - print(io, "Dropout(", d.p) - d.dims != (:) && print(io, ", dims = $(repr(d.dims))") - print(io, ")") + print(io, "Dropout(", d.p) + d.dims != (:) && print(io, ", dims = $(repr(d.dims))") + return print(io, ")") end """ @@ -93,30 +94,31 @@ remain the same as before. Does nothing to the input once [`testmode!`](@ref) is true. """ mutable struct AlphaDropout{F} - p::F - active::Union{Bool, Nothing} - function AlphaDropout(p, active = nothing) - @assert 0 ≤ p ≤ 1 - new{typeof(p)}(p, active) - end + p::F + active::Union{Bool,Nothing} + function AlphaDropout(p, active = nothing) + @assert 0 ≤ p ≤ 1 + return new{typeof(p)}(p, active) + end end -function (a::AlphaDropout)(x::AbstractArray{T}) where T - _isactive(a) || return x - p = a.p - iszero(p) && return x - isone(p) && return sign.(x) .* T(0) +function (a::AlphaDropout)(x::AbstractArray{T}) where {T} + _isactive(a) || return x + p = a.p + iszero(p) && return x + isone(p) && return sign.(x) .* T(0) - α′ = T(-1.7580993408473766) # selu(-Inf) == -λα - A = T(inv(sqrt((1 - p) * (1 + p * α′^2)))) - B = T(-A * α′ * p) + α′ = T(-1.7580993408473766) # selu(-Inf) == -λα + A = T(inv(sqrt((1 - p) * (1 + p * α′^2)))) + B = T(-A * α′ * p) - noise = rand!(similar(x)) - return A .* ifelse.(noise .> p, x, α′) .+ B + noise = rand!(similar(x)) + return A .* ifelse.(noise .> p, x, α′) .+ B end -testmode!(m::AlphaDropout, mode=true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +function testmode!(m::AlphaDropout, mode = true) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +end """ LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5) @@ -137,64 +139,66 @@ as in the [`Diagonal`](@ref) layer. Se also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref). """ struct LayerNorm{F,D,T,N} - λ::F - diag::D - ϵ::T - size::NTuple{N,Int} - affine::Bool + λ::F + diag::D + ϵ::T + size::NTuple{N,Int} + affine::Bool end -function LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5) - sz = sz isa Integer ? (sz,) : sz - diag = affine ? Diagonal(sz...) : nothing - return LayerNorm(λ, diag, ϵ, sz, affine) +function LayerNorm(sz, λ = identity; affine = true, ϵ = 1f-5) + sz = sz isa Integer ? (sz,) : sz + diag = affine ? Diagonal(sz...) : nothing + return LayerNorm(λ, diag, ϵ, sz, affine) end @functor LayerNorm function (a::LayerNorm)(x) - x = normalise(x, dims=1:length(a.size), ϵ=a.ϵ) - a.diag === nothing ? a.λ.(x) : a.λ.(a.diag(x)) + x = normalise(x; dims = 1:length(a.size), ϵ = a.ϵ) + return a.diag === nothing ? 
a.λ.(x) : a.λ.(a.diag(x)) end function Base.show(io::IO, l::LayerNorm) - print(io, "LayerNorm($(l.size)") - l.λ == identity || print(io, ", $(l.λ)") - hasaffine(l) || print(io, ", affine=false") - print(io, ")") + print(io, "LayerNorm($(l.size)") + l.λ == identity || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") + return print(io, ")") end # For InstanceNorm, GroupNorm, and BatchNorm. # Compute the statistics on the slices specified by reduce_dims. # reduce_dims=[1,...,N-2,N] for BatchNorm # reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm -function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N} - if !_isactive(l) && l.track_stats # testmode with tracked stats - stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) - μ = reshape(l.μ, stats_shape) - σ² = reshape(l.σ², stats_shape) - else # trainmode or testmode without tracked stats - μ = mean(x; dims=reduce_dims) - σ² = mean((x .- μ).^2; dims=reduce_dims) - if l.track_stats - ## update moving mean/std - Zygote.ignore() do - mtm = l.momentum - m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var - μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N)) - σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N)) - l.μ = (1-mtm) .* l.μ .+ mtm .* μnew - l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new - end +function _norm_layer_forward( + l, x::AbstractArray{T,N}; reduce_dims, affine_shape +) where {T,N} + if !_isactive(l) && l.track_stats # testmode with tracked stats + stats_shape = ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) + μ = reshape(l.μ, stats_shape) + σ² = reshape(l.σ², stats_shape) + else # trainmode or testmode without tracked stats + μ = mean(x; dims = reduce_dims) + σ² = mean((x .- μ) .^ 2; dims = reduce_dims) + if l.track_stats + ## update moving mean/std + Zygote.ignore() do + mtm = l.momentum + m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var + μnew = vec(N ∈ reduce_dims ? μ : mean(μ; dims = N)) + σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ²; dims = N)) + l.μ = (1 - mtm) .* l.μ .+ mtm .* μnew + l.σ² = (1 - mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new + end + end + end + if hasaffine(l) + γ = reshape(l.γ, affine_shape) + β = reshape(l.β, affine_shape) + return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β) + else + return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) end - end - if hasaffine(l) - γ = reshape(l.γ, affine_shape) - β = reshape(l.β, affine_shape) - return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β) - else - return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) - end end """ @@ -234,57 +238,59 @@ m = Chain( ``` """ mutable struct BatchNorm{F,V,N,W} - λ::F # activation function - β::V # bias - γ::V # scale - μ::W # moving mean - σ²::W # moving var - ϵ::N - momentum::N - affine::Bool - track_stats::Bool - active::Union{Bool, Nothing} - chs::Int # number of channels + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving var + ϵ::N + momentum::N + affine::Bool + track_stats::Bool + active::Union{Bool,Nothing} + chs::Int # number of channels end -function BatchNorm(chs::Int, λ=identity; - initβ=zeros32, initγ=ones32, - affine=true, track_stats=true, - ϵ=1f-5, momentum=0.1f0) - - β = affine ? initβ(chs) : nothing - γ = affine ? initγ(chs) : nothing - μ = track_stats ? zeros32(chs) : nothing - σ² = track_stats ? 
ones32(chs) : nothing - - return BatchNorm(λ, β, γ, - μ, σ², ϵ, momentum, - affine, track_stats, - nothing, chs) +function BatchNorm( + chs::Int, + λ = identity; + initβ = zeros32, + initγ = ones32, + affine = true, + track_stats = true, + ϵ = 1f-5, + momentum = 0.1f0, +) + β = affine ? initβ(chs) : nothing + γ = affine ? initγ(chs) : nothing + μ = track_stats ? zeros32(chs) : nothing + σ² = track_stats ? ones32(chs) : nothing + + return BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, affine, track_stats, nothing, chs) end @functor BatchNorm trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : () function (BN::BatchNorm)(x) - @assert size(x, ndims(x)-1) == BN.chs - N = ndims(x) - reduce_dims = [1:N-2; N] - affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) - return _norm_layer_forward(BN, x; reduce_dims, affine_shape) + @assert size(x, ndims(x) - 1) == BN.chs + N = ndims(x) + reduce_dims = [1:(N - 2); N] + affine_shape = ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) + return _norm_layer_forward(BN, x; reduce_dims, affine_shape) end -testmode!(m::BatchNorm, mode=true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +function testmode!(m::BatchNorm, mode = true) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +end function Base.show(io::IO, l::BatchNorm) - print(io, "BatchNorm($(l.chs)") - (l.λ == identity) || print(io, ", $(l.λ)") - hasaffine(l) || print(io, ", affine=false") - print(io, ")") + print(io, "BatchNorm($(l.chs)") + (l.λ == identity) || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") + return print(io, ")") end - """ InstanceNorm(channels::Integer, λ=identity; initβ=zeros32, initγ=ones32, @@ -310,55 +316,58 @@ that will be used to renormalize the input in test phase. in previous Flux versions (< v0.12). """ mutable struct InstanceNorm{F,V,N,W} - λ::F # activation function - β::V # bias - γ::V # scale - μ::W # moving mean - σ²::W # moving var - ϵ::N - momentum::N - affine::Bool - track_stats::Bool - active::Union{Bool, Nothing} - chs::Int # number of channels + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving var + ϵ::N + momentum::N + affine::Bool + track_stats::Bool + active::Union{Bool,Nothing} + chs::Int # number of channels end -function InstanceNorm(chs::Int, λ=identity; - initβ=zeros32, initγ=ones32, - affine=false, track_stats=false, - ϵ=1f-5, momentum=0.1f0) - - β = affine ? initβ(chs) : nothing - γ = affine ? initγ(chs) : nothing - μ = track_stats ? zeros32(chs) : nothing - σ² = track_stats ? ones32(chs) : nothing - - return InstanceNorm(λ, β, γ, - μ, σ², ϵ, momentum, - affine, track_stats, - nothing, chs) +function InstanceNorm( + chs::Int, + λ = identity; + initβ = zeros32, + initγ = ones32, + affine = false, + track_stats = false, + ϵ = 1f-5, + momentum = 0.1f0, +) + β = affine ? initβ(chs) : nothing + γ = affine ? initγ(chs) : nothing + μ = track_stats ? zeros32(chs) : nothing + σ² = track_stats ? ones32(chs) : nothing + + return InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, affine, track_stats, nothing, chs) end @functor InstanceNorm trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : () function (l::InstanceNorm)(x) - @assert ndims(x) > 2 - @assert size(x, ndims(x)-1) == l.chs - N = ndims(x) - reduce_dims = 1:N-2 - affine_shape = ntuple(i -> i == N-1 ? 
size(x, N-1) : 1, N) - return _norm_layer_forward(l, x; reduce_dims, affine_shape) + @assert ndims(x) > 2 + @assert size(x, ndims(x) - 1) == l.chs + N = ndims(x) + reduce_dims = 1:(N - 2) + affine_shape = ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) + return _norm_layer_forward(l, x; reduce_dims, affine_shape) end -testmode!(m::InstanceNorm, mode=true) = - (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +function testmode!(m::InstanceNorm, mode = true) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +end function Base.show(io::IO, l::InstanceNorm) - print(io, "InstanceNorm($(l.chs)") - l.λ == identity || print(io, ", $(l.λ)") - hasaffine(l) || print(io, ", affine=false") - print(io, ")") + print(io, "InstanceNorm($(l.chs)") + l.λ == identity || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") + return print(io, ")") end """ @@ -387,65 +396,68 @@ If `track_stats=true`, accumulates mean and var statistics in training phase that will be used to renormalize the input in test phase. """ mutable struct GroupNorm{F,V,N,W} - G::Int # number of groups - λ::F # activation function - β::V # bias - γ::V # scale - μ::W # moving mean - σ²::W # moving std - ϵ::N - momentum::N - affine::Bool - track_stats::Bool - active::Union{Bool, Nothing} - chs::Int # number of channels + G::Int # number of groups + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving std + ϵ::N + momentum::N + affine::Bool + track_stats::Bool + active::Union{Bool,Nothing} + chs::Int # number of channels end @functor GroupNorm trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : () -function GroupNorm(chs::Int, G::Int, λ=identity; - initβ=zeros32, initγ=ones32, - affine=true, track_stats=false, - ϵ=1f-5, momentum=0.1f0) - - chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)") - - β = affine ? initβ(chs) : nothing - γ = affine ? initγ(chs) : nothing - μ = track_stats ? zeros32(G) : nothing - σ² = track_stats ? ones32(G) : nothing - - return GroupNorm(G, λ, - β, γ, - μ, σ², - ϵ, momentum, - affine, track_stats, - nothing, chs) +function GroupNorm( + chs::Int, + G::Int, + λ = identity; + initβ = zeros32, + initγ = ones32, + affine = true, + track_stats = false, + ϵ = 1f-5, + momentum = 0.1f0, +) + chs % G == 0 || + error("The number of groups ($(G)) must divide the number of channels ($chs)") + + β = affine ? initβ(chs) : nothing + γ = affine ? initγ(chs) : nothing + μ = track_stats ? zeros32(G) : nothing + σ² = track_stats ? ones32(G) : nothing + + return GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, affine, track_stats, nothing, chs) end function (gn::GroupNorm)(x) - @assert ndims(x) > 2 - @assert size(x, ndims(x)-1) == gn.chs - N = ndims(x) - sz = size(x) - x = reshape(x, sz[1:N-2]..., sz[N-1]÷gn.G, gn.G, sz[N]) - N = ndims(x) - reduce_dims = 1:N-2 - affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x, i) : 1, N) - x = _norm_layer_forward(gn, x; reduce_dims, affine_shape) - return reshape(x, sz) + @assert ndims(x) > 2 + @assert size(x, ndims(x) - 1) == gn.chs + N = ndims(x) + sz = size(x) + x = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ gn.G, gn.G, sz[N]) + N = ndims(x) + reduce_dims = 1:(N - 2) + affine_shape = ntuple(i -> i ∈ (N - 1, N - 2) ? size(x, i) : 1, N) + x = _norm_layer_forward(gn, x; reduce_dims, affine_shape) + return reshape(x, sz) end -testmode!(m::GroupNorm, mode = true) = - (m.active = (isnothing(mode) || mode == :auto) ? 
nothing : !mode; m) +function testmode!(m::GroupNorm, mode = true) + (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +end function Base.show(io::IO, l::GroupNorm) - # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G) - print(io, "GroupNorm($(l.chs), $(l.G)") - l.λ == identity || print(io, ", ", l.λ) - hasaffine(l) || print(io, ", affine=false") - print(io, ")") + # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G) + print(io, "GroupNorm($(l.chs), $(l.G)") + l.λ == identity || print(io, ", ", l.λ) + hasaffine(l) || print(io, ", affine=false") + return print(io, ")") end """ @@ -456,4 +468,4 @@ scale parameters, `false` otherwise. See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref). """ -hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine +hasaffine(l::Union{BatchNorm,InstanceNorm,LayerNorm,GroupNorm}) = l.affine diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 38f17fbb31..4473d1d189 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,7 +1,7 @@ -gate(h, n) = (1:h) .+ h*(n-1) -gate(x::AbstractVector, h, n) = @view x[gate(h,n)] -gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] +gate(h, n) = (1:h) .+ h * (n - 1) +gate(x::AbstractVector, h, n) = @view x[gate(h, n)] +gate(x::AbstractMatrix, h, n) = x[gate(h, n), :] # Stateful recurrence @@ -39,13 +39,13 @@ rnn.state # 60 """ mutable struct Recur{T,S} - cell::T - state::S + cell::T + state::S end function (m::Recur)(x) - m.state, y = m.cell(m.state, x) - return y + m.state, y = m.cell(m.state, x) + return y end @functor Recur @@ -66,53 +66,62 @@ rnn.state = hidden(rnn.cell) reset!(m::Recur) = (m.state = m.cell.state0) reset!(m) = foreach(reset!, functor(m)[1]) - # TODO remove in v0.13 function Base.getproperty(m::Recur, sym::Symbol) - if sym === :init - Zygote.ignore() do - @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead." + if sym === :init + Zygote.ignore() do + @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead." + end + return getfield(m.cell, :state0) + else + return getfield(m, sym) end - return getfield(m.cell, :state0) - else - return getfield(m, sym) - end end flip(f, xs) = reverse(f.(reverse(xs))) -function (m::Recur)(x::AbstractArray{T, 3}) where T - h = [m(view(x, :, :, i)) for i in 1:size(x, 3)] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], sze[2], length(h)) +function (m::Recur)(x::AbstractArray{T,3}) where {T} + h = [m(view(x, :, :, i)) for i in 1:size(x, 3)] + sze = size(h[1]) + return reshape(reduce(hcat, h), sze[1], sze[2], length(h)) end # Vanilla RNN struct RNNCell{F,A,V,S} - σ::F - Wi::A - Wh::A - b::V - state0::S + σ::F + Wi::A + Wh::A + b::V + state0::S end -RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) = - RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1)) +function RNNCell( + in::Integer, + out::Integer, + σ = tanh; + init = Flux.glorot_uniform, + initb = zeros32, + init_state = zeros32, +) + return RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out, 1)) +end -function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T} - σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b - h = σ.(Wi*x .+ Wh*h .+ b) - sz = size(x) - return h, reshape(h, :, sz[2:end]...) 
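# The forward pass below is the plain recurrence h_new = σ.(Wi * x .+ Wh * h .+ b); the second
# return value reshapes h_new to the input's trailing dimensions, so a vector input yields a
# vector output while a matrix (features × batch) input is returned as-is.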
+function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})( + h, x::Union{AbstractVecOrMat{T},OneHotArray} +) where {F,A,V,T} + σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b + h = σ.(Wi * x .+ Wh * h .+ b) + sz = size(x) + return h, reshape(h, :, sz[2:end]...) end @functor RNNCell function Base.show(io::IO, l::RNNCell) - print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)) - l.σ == identity || print(io, ", ", l.σ) - print(io, ")") + print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)) + l.σ == identity || print(io, ", ", l.σ) + return print(io, ")") end """ @@ -177,51 +186,58 @@ Recur(m::RNNCell) = Recur(m, m.state0) # TODO remove in v0.13 function Base.getproperty(m::RNNCell, sym::Symbol) - if sym === :h - Zygote.ignore() do - @warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead." + if sym === :h + Zygote.ignore() do + @warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead." + end + return getfield(m, :state0) + else + return getfield(m, sym) end - return getfield(m, :state0) - else - return getfield(m, sym) - end end # LSTM struct LSTMCell{A,V,S} - Wi::A - Wh::A - b::V - state0::S + Wi::A + Wh::A + b::V + state0::S end -function LSTMCell(in::Integer, out::Integer; - init = glorot_uniform, - initb = zeros32, - init_state = zeros32) - cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1))) - cell.b[gate(out, 2)] .= 1 - return cell +function LSTMCell( + in::Integer, out::Integer; init = glorot_uniform, initb = zeros32, init_state = zeros32 +) + cell = LSTMCell( + init(out * 4, in), + init(out * 4, out), + initb(out * 4), + (init_state(out, 1), init_state(out, 1)), + ) + cell.b[gate(out, 2)] .= 1 + return cell end -function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T} - b, o = m.b, size(h, 1) - g = m.Wi*x .+ m.Wh*h .+ b - input = σ.(gate(g, o, 1)) - forget = σ.(gate(g, o, 2)) - cell = tanh.(gate(g, o, 3)) - output = σ.(gate(g, o, 4)) - c = forget .* c .+ input .* cell - h′ = output .* tanh.(c) - sz = size(x) - return (h′, c), reshape(h′, :, sz[2:end]...) +function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})( + (h, c), x::Union{AbstractVecOrMat{T},OneHotArray} +) where {A,V,T} + b, o = m.b, size(h, 1) + g = m.Wi * x .+ m.Wh * h .+ b + input = σ.(gate(g, o, 1)) + forget = σ.(gate(g, o, 2)) + cell = tanh.(gate(g, o, 3)) + output = σ.(gate(g, o, 4)) + c = forget .* c .+ input .* cell + h′ = output .* tanh.(c) + sz = size(x) + return (h′, c), reshape(h′, :, sz[2:end]...) end @functor LSTMCell -Base.show(io::IO, l::LSTMCell) = - print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")") +function Base.show(io::IO, l::LSTMCell) + return print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1) ÷ 4, ")") +end """ LSTM(in::Integer, out::Integer) @@ -261,55 +277,61 @@ Recur(m::LSTMCell) = Recur(m, m.state0) # TODO remove in v0.13 function Base.getproperty(m::LSTMCell, sym::Symbol) - if sym === :h - Zygote.ignore() do - @warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead." + if sym === :h + Zygote.ignore() do + @warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead." + end + return getfield(m, :state0)[1] + elseif sym === :c + Zygote.ignore() do + @warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead." 
+ end + return getfield(m, :state0)[2] + else + return getfield(m, sym) end - return getfield(m, :state0)[1] - elseif sym === :c - Zygote.ignore() do - @warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead." - end - return getfield(m, :state0)[2] - else - return getfield(m, sym) - end end # GRU function _gru_output(Wi, Wh, b, x, h) - o = size(h, 1) - gx, gh = Wi*x, Wh*h - r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1)) - z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2)) + o = size(h, 1) + gx, gh = Wi * x, Wh * h + r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1)) + z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2)) - return gx, gh, r, z + return gx, gh, r, z end struct GRUCell{A,V,S} - Wi::A - Wh::A - b::V - state0::S + Wi::A + Wh::A + b::V + state0::S end -GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) = - GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1)) +function GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) + return GRUCell( + init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out, 1) + ) +end -function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T} - b, o = m.b, size(h, 1) - gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h) - h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3)) - h′ = (1 .- z) .* h̃ .+ z .* h - sz = size(x) - return h′, reshape(h′, :, sz[2:end]...) +function (m::GRUCell{A,V,<:AbstractMatrix{T}})( + h, x::Union{AbstractVecOrMat{T},OneHotArray} +) where {A,V,T} + b, o = m.b, size(h, 1) + gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h) + h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3)) + h′ = (1 .- z) .* h̃ .+ z .* h + sz = size(x) + return h′, reshape(h′, :, sz[2:end]...) end @functor GRUCell -Base.show(io::IO, l::GRUCell) = - print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")") +function Base.show(io::IO, l::GRUCell) + return print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1) ÷ 3, ")") +end """ GRU(in::Integer, out::Integer) @@ -350,44 +372,52 @@ Recur(m::GRUCell) = Recur(m, m.state0) # TODO remove in v0.13 function Base.getproperty(m::GRUCell, sym::Symbol) - if sym === :h - Zygote.ignore() do - @warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead." + if sym === :h + Zygote.ignore() do + @warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead." + end + return getfield(m, :state0) + else + return getfield(m, sym) end - return getfield(m, :state0) - else - return getfield(m, sym) - end end - # GRU v3 struct GRUv3Cell{A,V,S} - Wi::A - Wh::A - b::V - Wh_h̃::A - state0::S + Wi::A + Wh::A + b::V + Wh_h̃::A + state0::S +end + +function GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) + return GRUv3Cell( + init(out * 3, in), + init(out * 2, out), + initb(out * 3), + init(out, out), + init_state(out, 1), + ) end -GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) = - GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), - init(out, out), init_state(out,1)) - -function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T} - b, o = m.b, size(h, 1) - gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h) - h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3)) - h′ = (1 .- z) .* h̃ .+ z .* h - sz = size(x) - return h′, reshape(h′, :, sz[2:end]...) 
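# As in GRUCell above, the forward pass below computes reset and update gates r and z via
# _gru_output and blends h′ = (1 .- z) .* h̃ .+ z .* h; the v3 variant differs only in the
# candidate h̃, applying the separate matrix Wh_h̃ to r .* h instead of reusing a slice of Wh * h.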
+function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})( + h, x::Union{AbstractVecOrMat{T},OneHotArray} +) where {A,V,T} + b, o = m.b, size(h, 1) + gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h) + h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3)) + h′ = (1 .- z) .* h̃ .+ z .* h + sz = size(x) + return h′, reshape(h′, :, sz[2:end]...) end @functor GRUv3Cell -Base.show(io::IO, l::GRUv3Cell) = - print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")") +function Base.show(io::IO, l::GRUv3Cell) + return print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1) ÷ 3, ")") +end """ GRUv3(in::Integer, out::Integer) @@ -426,7 +456,6 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...)) Recur(m::GRUv3Cell) = Recur(m, m.state0) - @adjoint function Broadcast.broadcasted(f::Recur, args...) - Zygote.∇map(__context__, f, args...) + return Zygote.∇map(__context__, f, args...) end diff --git a/src/layers/show.jl b/src/layers/show.jl index 94afdb92d0..72a2d8d890 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,46 +1,49 @@ for T in [ - :Chain, :Parallel, :SkipConnection, :Recur # container types - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL - _big_show(io, x) - elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix - _layer_show(io, x) - else - show(io, x) + :Chain, + :Parallel, + :SkipConnection, + :Recur, # container types +] + @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end end - end end -function _big_show(io::IO, obj, indent::Int=0, name=nothing) - children = trainable(obj) - if all(_show_leaflike, children) - _layer_show(io, obj, indent, name) - else - println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(") - if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) - # then we insert names -- can this be done more generically? - for k in Base.keys(obj) - _big_show(io, obj[k], indent+2, k) - end - elseif obj isa Parallel{<:Any, <:NamedTuple} - _big_show(io, obj.connection, indent+2) - for k in Base.keys(obj) - _big_show(io, obj[k], indent+2, k) - end - else - for c in children - _big_show(io, c, indent+2) - end - end - if indent == 0 # i.e. this is the outermost container - print(io, ")") - _big_finale(io, obj) +function _big_show(io::IO, obj, indent::Int = 0, name = nothing) + children = trainable(obj) + if all(_show_leaflike, children) + _layer_show(io, obj, indent, name) else - println(io, " "^indent, "),") + println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(") + if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) + # then we insert names -- can this be done more generically? + for k in Base.keys(obj) + _big_show(io, obj[k], indent + 2, k) + end + elseif obj isa Parallel{<:Any,<:NamedTuple} + _big_show(io, obj.connection, indent + 2) + for k in Base.keys(obj) + _big_show(io, obj[k], indent + 2, k) + end + else + for c in children + _big_show(io, c, indent + 2) + end + end + if indent == 0 # i.e. 
this is the outermost container + print(io, ")") + _big_finale(io, obj) + else + println(io, " "^indent, "),") + end end - end end _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: @@ -49,69 +52,109 @@ _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LS _show_leaflike(::Diagonal) = true # appears inside LayerNorm for T in [ - :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, - :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if !get(io, :compact, false) - _layer_show(io, x) - else - show(io, x) + :Conv, + :ConvTranspose, + :CrossCor, + :DepthwiseConv, + :Dense, + :BatchNorm, + :LayerNorm, + :InstanceNorm, + :GroupNorm, +] + @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) + if !get(io, :compact, false) + _layer_show(io, x) + else + show(io, x) + end end - end end -function _layer_show(io::IO, layer, indent::Int=0, name=nothing) - _str = isnothing(name) ? "" : "$name = " - str = _str * sprint(show, layer, context=io) - print(io, " "^indent, str, indent==0 ? "" : ",") - if !isempty(params(layer)) - print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) - printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black) - nonparam = _childarray_sum(length, layer) - sum(length, params(layer)) - if nonparam > 0 - printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black) +function _layer_show(io::IO, layer, indent::Int = 0, name = nothing) + _str = isnothing(name) ? "" : "$name = " + str = _str * sprint(show, layer; context = io) + print(io, " "^indent, str, indent == 0 ? "" : ",") + if !isempty(params(layer)) + print(io, " "^max(2, (indent == 0 ? 20 : 39) - indent - length(str))) + printstyled( + io, + "# ", + underscorise(sum(length, params(layer))), + " parameters"; + color = :light_black, + ) + nonparam = _childarray_sum(length, layer) - sum(length, params(layer)) + if nonparam > 0 + printstyled( + io, + ", plus ", + underscorise(nonparam), + indent == 0 ? 
" non-trainable" : ""; + color = :light_black, + ) + end + _nan_show(io, params(layer)) end - _nan_show(io, params(layer)) - end - indent==0 || println(io) + return indent == 0 || println(io) end function _big_finale(io::IO, m) - ps = params(m) - if length(ps) > 2 - pars = underscorise(sum(length, ps)) - bytes = Base.format_bytes(Base.summarysize(m)) - noncnt = _childarray_sum(_->1, m) - length(ps) - if noncnt > 0 - nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) - printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) - println(io, pars, " parameters,") - printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black) - print(io, bytes, ".") - else - printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black) - print(io, pars, " parameters, ", bytes, ".") + ps = params(m) + if length(ps) > 2 + pars = underscorise(sum(length, ps)) + bytes = Base.format_bytes(Base.summarysize(m)) + noncnt = _childarray_sum(_ -> 1, m) - length(ps) + if noncnt > 0 + nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) + printstyled( + io, + " "^09, + "# Total: ", + length(ps), + " trainable arrays, "; + color = :light_black, + ) + println(io, pars, " parameters,") + printstyled( + io, + " "^10, + "# plus ", + noncnt, + " non-trainable, ", + nonparam, + " parameters, summarysize "; + color = :light_black, + ) + print(io, bytes, ".") + else + printstyled( + io, " "^19, "# Total: ", length(ps), " arrays, "; color = :light_black + ) + print(io, pars, " parameters, ", bytes, ".") + end end - end end _childarray_sum(f, x::AbstractArray) = f(x) -_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) +function _childarray_sum(f, x) + return isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) +end # utility functions -underscorise(n::Integer) = - join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') +function underscorise(n::Integer) + return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') +end function _nan_show(io::IO, x) - if !isempty(x) && _all(iszero, x) - printstyled(io, " (all zero)", color=:cyan) - elseif _any(isnan, x) - printstyled(io, " (some NaN)", color=:red) - elseif _any(isinf, x) - printstyled(io, " (some Inf)", color=:red) - end + if !isempty(x) && _all(iszero, x) + printstyled(io, " (all zero)"; color = :cyan) + elseif _any(isnan, x) + printstyled(io, " (some NaN)"; color = :red) + elseif _any(isinf, x) + printstyled(io, " (some Inf)"; color = :red) + end end _any(f, xs::AbstractArray{<:Number}) = any(f, xs) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 1a3a0df5ec..01485c470c 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -23,7 +23,7 @@ julia> xs |> m |> size ``` """ function flatten(x::AbstractArray) - return reshape(x, :, size(x)[end]) + return reshape(x, :, size(x)[end]) end """ @@ -33,9 +33,9 @@ Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given b Per default, `dims` is the last dimension. `ϵ` is a small additive factor added to the denominator for numerical stability. 
""" -function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) - μ = mean(x, dims=dims) +function normalise(x::AbstractArray; dims = ndims(x), ϵ = ofeltype(x, 1e-5)) + μ = mean(x; dims = dims) # σ = std(x, dims=dims, mean=μ, corrected=false) # use this when Zygote#478 gets merged - σ = std(x, dims=dims, corrected=false) - return (x .- μ) ./ (σ .+ ϵ) + σ = std(x; dims = dims, corrected = false) + return (x .- μ) ./ (σ .+ ϵ) end diff --git a/src/layers/upsample.jl b/src/layers/upsample.jl index b8388bb35f..2505d2cc7c 100644 --- a/src/layers/upsample.jl +++ b/src/layers/upsample.jl @@ -29,41 +29,38 @@ julia> m(ones(2, 2, 1, 1)) |> size (4, 5, 1, 1) ``` """ -struct Upsample{mode, S, T} - scale::S - size::T +struct Upsample{mode,S,T} + scale::S + size::T end function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing) - mode in [:nearest, :bilinear] || - throw(ArgumentError("mode=:$mode is not supported.")) - if !(isnothing(scale) ⊻ isnothing(size)) - throw(ArgumentError("Either scale or size should be specified (but not both).")) - end - return Upsample{mode,typeof(scale),typeof(size)}(scale, size) + mode in [:nearest, :bilinear] || throw(ArgumentError("mode=:$mode is not supported.")) + if !(isnothing(scale) ⊻ isnothing(size)) + throw(ArgumentError("Either scale or size should be specified (but not both).")) + end + return Upsample{mode,typeof(scale),typeof(size)}(scale, size) end Upsample(scale, mode::Symbol = :nearest) = Upsample(mode; scale) -(m::Upsample{:nearest})(x::AbstractArray) = - NNlib.upsample_nearest(x, m.scale) -function (m::Upsample{:nearest, Int})(x::AbstractArray{T, N}) where {T, N} - NNlib.upsample_nearest(x, ntuple(i -> m.scale, N-2)) +(m::Upsample{:nearest})(x::AbstractArray) = NNlib.upsample_nearest(x, m.scale) +function (m::Upsample{:nearest,Int})(x::AbstractArray{T,N}) where {T,N} + return NNlib.upsample_nearest(x, ntuple(i -> m.scale, N - 2)) end -(m::Upsample{:nearest, Nothing})(x::AbstractArray) = - NNlib.upsample_nearest(x; size=m.size) +(m::Upsample{:nearest,Nothing})(x::AbstractArray) = NNlib.upsample_nearest(x; size = m.size) -(m::Upsample{:bilinear})(x::AbstractArray) = - NNlib.upsample_bilinear(x, m.scale) -(m::Upsample{:bilinear, Nothing})(x::AbstractArray) = - NNlib.upsample_bilinear(x; size=m.size) +(m::Upsample{:bilinear})(x::AbstractArray) = NNlib.upsample_bilinear(x, m.scale) +function (m::Upsample{:bilinear,Nothing})(x::AbstractArray) + return NNlib.upsample_bilinear(x; size = m.size) +end function Base.show(io::IO, u::Upsample{mode}) where {mode} - print(io, "Upsample(") - print(io, ":", mode) - u.scale !== nothing && print(io, ", scale = $(u.scale)") - u.size !== nothing && print(io, ", size = $(u.size)") - print(io, ")") + print(io, "Upsample(") + print(io, ":", mode) + u.scale !== nothing && print(io, ", scale = $(u.scale)") + u.size !== nothing && print(io, ", size = $(u.size)") + return print(io, ")") end """ @@ -73,8 +70,8 @@ Pixel shuffling layer with upscale factor `r`. See [`NNlib.pixel_shuffle`](@ref). 
""" -struct PixelShuffle - r::Int +struct PixelShuffle + r::Int end (m::PixelShuffle)(x) = NNlib.pixel_shuffle(x, m.r) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 413c4ee034..e651871759 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -8,18 +8,24 @@ using CUDA using NNlib: logsoftmax, logσ import Base.Broadcast: broadcasted -export mse, mae, msle, +export mse, + mae, + msle, label_smoothing, - crossentropy, logitcrossentropy, - binarycrossentropy, logitbinarycrossentropy, + crossentropy, + logitcrossentropy, + binarycrossentropy, + logitbinarycrossentropy, kldivergence, huber_loss, tversky_loss, dice_coeff_loss, poisson_loss, - hinge_loss, squared_hinge_loss, + hinge_loss, + squared_hinge_loss, ctc_loss, - binary_focal_loss, focal_loss + binary_focal_loss, + focal_loss include("utils.jl") include("functions.jl") diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index e18a2ad88a..44d39d897f 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -3,7 +3,6 @@ using Zygote: @adjoint using Statistics using NNlib - ## CPU implementation """ @@ -11,18 +10,18 @@ using NNlib Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))` """ function logaddexp(a, b) - isinf(a) && return b - isinf(b) && return a - - # always want the greater number on the left in the exponentiation; - # the magnitude difference may end up making the number very positive - # which will cause exp() to return Inf - # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be - # Inf for Float32 values - if a < b - a, b = b, a - end - return a + log(1+exp(b-a)) + isinf(a) && return b + isinf(b) && return a + + # always want the greater number on the left in the exponentiation; + # the magnitude difference may end up making the number very positive + # which will cause exp() to return Inf + # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be + # Inf for Float32 values + if a < b + a, b = b, a + end + return a + log(1 + exp(b - a)) end """ @@ -31,82 +30,89 @@ end Adds blanks to the start and end of `z`, and between items in `z` """ function add_blanks(z, blank) - z′ = fill(blank, 2*length(z) + 1) - z′[2 .* eachindex(z)] = z - return z′ + z′ = fill(blank, 2 * length(z) + 1) + z′[2 .* eachindex(z)] = z + return z′ end function ctc_alpha(ŷ::AbstractArray, y) - typed_zero = zero(ŷ[1]) - ŷ = logsoftmax(ŷ) - blank = size(ŷ, 1) - z′ = add_blanks(y, blank) - T = size(ŷ, 2) - U′ = length(z′) - - α = fill(log(typed_zero), U′, T) - α[1,1] = ŷ[blank, 1] - α[2,1] = ŷ[z′[2], 1] - for t=2:T - bound = max(1, U′ - 2(T - t) - 1) - for u=bound:U′ - if u == 1 - α[u,t] = α[u, t-1] - else - α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) - - # array bounds check and f(u) function from Eq. 7.9 - if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u]) - α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) - end - end - α[u,t] += ŷ[z′[u], t] + typed_zero = zero(ŷ[1]) + ŷ = logsoftmax(ŷ) + blank = size(ŷ, 1) + z′ = add_blanks(y, blank) + T = size(ŷ, 2) + U′ = length(z′) + + α = fill(log(typed_zero), U′, T) + α[1, 1] = ŷ[blank, 1] + α[2, 1] = ŷ[z′[2], 1] + for t in 2:T + bound = max(1, U′ - 2(T - t) - 1) + for u in bound:U′ + if u == 1 + α[u, t] = α[u, t - 1] + else + α[u, t] = logaddexp(α[u, t - 1], α[u - 1, t - 1]) + + # array bounds check and f(u) function from Eq. 
7.9 + if u > 2 && !(z′[u] == blank || z′[u - 2] == z′[u]) + α[u, t] = logaddexp(α[u, t], α[u - 2, t - 1]) + end + end + α[u, t] += ŷ[z′[u], t] + end end - end - return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ) + return ( + loss = -1 * logaddexp(α[end, T], α[end - 1, T]), + alpha = α, + zprime = z′, + logsoftyhat = ŷ, + ) end - + function ∇ctc_loss(ŷ::AbstractArray, y, out) - loss, α, z′, ŷ = out - U′, T = size(α) - blank = size(ŷ, 1) - typed_zero = zero(first(α)) - - # Calculate beta coefficients, from the bottom-right, to the upper-left - β = fill(log(typed_zero), U′, T) - - # Fill bottom-right corner so bounding errors can be avoided - # by starting `u` at `U′-1` - β[U′, T] = typed_zero - β[U′-1, T] = typed_zero - - # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 - for t=(T-1):-1:1 - bound = min(U′, 2t) - for u=bound:-1:1 - if u == U′ - β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] - else - β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) - - # array bounds check and g(u) function from Eq. 7.16 - if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2] - β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) + loss, α, z′, ŷ = out + U′, T = size(α) + blank = size(ŷ, 1) + typed_zero = zero(first(α)) + + # Calculate beta coefficients, from the bottom-right, to the upper-left + β = fill(log(typed_zero), U′, T) + + # Fill bottom-right corner so bounding errors can be avoided + # by starting `u` at `U′-1` + β[U′, T] = typed_zero + β[U′ - 1, T] = typed_zero + + # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 + for t in (T - 1):-1:1 + bound = min(U′, 2t) + for u in bound:-1:1 + if u == U′ + β[u, t] = ŷ[z′[u], t + 1] + β[u, t + 1] + else + β[u, t] = logaddexp( + ŷ[z′[u], t + 1] + β[u, t + 1], ŷ[z′[u + 1], t + 1] + β[u + 1, t + 1] + ) + + # array bounds check and g(u) function from Eq. 7.16 + if u + 2 <= U′ && z′[u] != blank && z′[u] != z′[u + 2] + β[u, t] = logaddexp(β[u, t], ŷ[z′[u + 2], t + 1] + β[u + 2, t + 1]) + end + end end - end end - end - - # Accumulate alpha-beta products for each category, - # then calculate gradients - accum = fill(log(typed_zero), size(ŷ)) - for t=1:T - for u=1:U′ - accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t]) + + # Accumulate alpha-beta products for each category, + # then calculate gradients + accum = fill(log(typed_zero), size(ŷ)) + for t in 1:T + for u in 1:U′ + accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u, t] + β[u, t]) + end end - end - grads = exp.(ŷ) .- exp.(accum .+ loss) - return grads + grads = exp.(ŷ) .- exp.(accum .+ loss) + return grads end """ @@ -134,12 +140,11 @@ for mathematical details. 
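A minimal usage sketch, with shapes inferred from `ctc_alpha` above: `ŷ` is an `(alphabet size + 1) × timesteps` matrix of unnormalised scores whose last row is the blank, and `y` is a vector of integer labels:

```julia
using Flux
using Flux.Losses: ctc_loss

ŷ = rand(Float32, 5, 30)   # 4 labels + 1 blank (last row), 30 time steps
y = [1, 3, 2, 2, 4]        # target sequence, values in 1:4
ctc_loss(ŷ, y)             # scalar loss; the adjoint only propagates gradients to ŷ
```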
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss @adjoint function ctc_loss(ŷ, y) - out = ctc_alpha(ŷ, y) - ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) - return out.loss, ctc_loss_pullback + out = ctc_alpha(ŷ, y) + ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) + return out.loss, ctc_loss_pullback end - ## GPU implementation # a port of the GPU kernels from Baidu's C++ warp-ctc package, @@ -153,218 +158,260 @@ end const MAX_THREADS = 256 function log_plus_f(p1, p2) - isinf(p1) && return p2 - isinf(p2) && return p1 - if p1 < p2 - p1, p2 = p2, p1 - end - return p1 + log(1+exp(p2 - p1)) + isinf(p1) && return p2 + isinf(p2) && return p1 + if p1 < p2 + p1, p2 = p2, p1 + end + return p1 + log(1 + exp(p2 - p1)) end function count_repeats(A) - repeats = 0 - for (i,elem) in enumerate(A) - if i > 1 && A[i] == A[i-1] - repeats += 1 + repeats = 0 + for (i, elem) in enumerate(A) + if i > 1 && A[i] == A[i - 1] + repeats += 1 + end end - end - return repeats + return repeats end -function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel) - - tid = threadIdx().x - L = labelSize - T = uttLength - S = length(labelsWithBlanks) +function compute_alpha_kernel( + probs, + labelSize, + uttLength, + repeats, + labelsWithoutBlanks, + labelsWithBlanks, + alpha, + blankLabel, +) + tid = threadIdx().x + L = labelSize + T = uttLength + S = length(labelsWithBlanks) + + if L + repeats > T + return nothing + end + labels = labelsWithBlanks - if L + repeats > T - return nothing - end - labels = labelsWithBlanks - - # Corner-case checking - start = (L + repeats <= T) ? 0 : 1 - last = S > 1 ? 2 : 1 - - # Fill in first column (time step) - i = tid - while i <= last - start - alpha[start+i, 1] = probs[labels[start+i], 1] - i += blockDim().x - end - sync_threads() - - # Fill in coefficients for each time step - for t=2:T # Corner-case checking - if tid == 1 && !(1 < S - 2*(T-t) - 1) - if start == 0 - alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t] - elseif start == 1 - alpha[1, t] = alpha[1, t-1] - end + start = (L + repeats <= T) ? 0 : 1 + last = S > 1 ? 
2 : 1 + + # Fill in first column (time step) + i = tid + while i <= last - start + alpha[start + i, 1] = probs[labels[start + i], 1] + i += blockDim().x end sync_threads() - # Fill in coefficients for each label class in the target output sequence; - # each thread will process the calculations for one class - idx = tid+1 - while idx <= S - prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1]) - if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2] - prevSum = log_plus_f(prevSum, alpha[idx-2, t-1]) - end - if idx < S - 2*(T-t) - 1 - alpha[idx, t] = -Inf32 - else - alpha[idx, t] = prevSum + probs[labels[idx], t] - end - idx += blockDim().x + # Fill in coefficients for each time step + for t in 2:T + # Corner-case checking + if tid == 1 && !(1 < S - 2 * (T - t) - 1) + if start == 0 + alpha[1, t] = alpha[1, t - 1] + probs[blankLabel, t] + elseif start == 1 + alpha[1, t] = alpha[1, t - 1] + end + end + sync_threads() + + # Fill in coefficients for each label class in the target output sequence; + # each thread will process the calculations for one class + idx = tid + 1 + while idx <= S + prevSum = log_plus_f(alpha[idx, t - 1], alpha[idx - 1, t - 1]) + if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx - 2] + prevSum = log_plus_f(prevSum, alpha[idx - 2, t - 1]) + end + if idx < S - 2 * (T - t) - 1 + alpha[idx, t] = -Inf32 + else + alpha[idx, t] = prevSum + probs[labels[idx], t] + end + idx += blockDim().x + end + sync_threads() end - sync_threads() - end - return nothing + return nothing end -function compute_beta_and_grad_kernel(probs, labelSize, uttLength, - repeatsInLabel, labelsWithBlanks, - alphas, beta, output, accum, - grad, blankLabel, loss) - - tid = threadIdx().x - L = labelSize - T = uttLength - S = 2*L + 1 - repeats = repeatsInLabel - labels = labelsWithBlanks - - if (L+repeats) > T - return nothing - end - - # Corner-case checking - start = S > 1 ? S-2 : 0 - last = L + repeats < T ? 
S : S-1 - sync_threads() - i = tid - - # Calculate coefficients for last column (time step) - # then determine alpha and beta product - while i <= last - start - beta[i+start, T] = 0 - output[i+start, T] = beta[i+start, T] + alphas[i+start, T] - i += blockDim().x - end - sync_threads() - - # Fill in `accum` for last column (time step) - if tid == 1 - for i=1:S - labelIdx = labels[i] - accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T]) - end - end - sync_threads() - - # Fill in `grad` for last column (time step) - idx = tid - while idx <= size(grad, 1) - s = -Inf32 - for i=1:S - s = log_plus_f(s, output[i, T]) +function compute_beta_and_grad_kernel( + probs, + labelSize, + uttLength, + repeatsInLabel, + labelsWithBlanks, + alphas, + beta, + output, + accum, + grad, + blankLabel, + loss, +) + tid = threadIdx().x + L = labelSize + T = uttLength + S = 2 * L + 1 + repeats = repeatsInLabel + labels = labelsWithBlanks + + if (L + repeats) > T + return nothing end - # ∂L/∂a (where a is activation before logsoftmax) - grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s) - idx += blockDim().x - end - sync_threads() - - # Fill in the rest of the coefficients - t = T-1 - while t >= 1 - if t < T - idx = tid - while idx <= S - nextSum = probs[labels[idx], t+1] + beta[idx, t+1] - if idx < S - nextSum = log_plus_f(nextSum, - probs[labels[idx+1], t+1] + beta[idx+1, t+1]) - end - if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] - nextSum = log_plus_f(nextSum, - probs[labels[idx+2], t+1] + beta[idx + 2, t+1]) - end - if idx > 2*t - beta[idx, t] = -Inf32 - else - beta[idx, t] = nextSum - end - idx += blockDim().x - end - sync_threads() - idx = tid - while idx <= S - output[idx, t] = alphas[idx, t] + beta[idx, t] - idx += blockDim().x - end - sync_threads() + # Corner-case checking + start = S > 1 ? S - 2 : 0 + last = L + repeats < T ? 
S : S - 1 + sync_threads() + i = tid + + # Calculate coefficients for last column (time step) + # then determine alpha and beta product + while i <= last - start + beta[i + start, T] = 0 + output[i + start, T] = beta[i + start, T] + alphas[i + start, T] + i += blockDim().x end sync_threads() - # Calculate accumulated alpha-beta products for each label class for - # each time step; used in calculating gradients + # Fill in `accum` for last column (time step) if tid == 1 - for i=1:S - labelIdx = labels[i] - accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t]) - end + for i in 1:S + labelIdx = labels[i] + accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T]) + end end sync_threads() - idx = tid - # Calculate gradients + # Fill in `grad` for last column (time step) + idx = tid while idx <= size(grad, 1) + s = -Inf32 + for i in 1:S + s = log_plus_f(s, output[i, T]) + end - # ∂L/∂a (where a is activation before logsoftmax) - grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss) - idx += blockDim().x + # ∂L/∂a (where a is activation before logsoftmax) + grad[idx, T] = exp(probs[idx, T]) - exp(accum[idx, T] - s) + idx += blockDim().x end sync_threads() - t -= 1 - sync_threads() - end - return nothing + + # Fill in the rest of the coefficients + t = T - 1 + while t >= 1 + if t < T + idx = tid + while idx <= S + nextSum = probs[labels[idx], t + 1] + beta[idx, t + 1] + if idx < S + nextSum = log_plus_f( + nextSum, probs[labels[idx + 1], t + 1] + beta[idx + 1, t + 1] + ) + end + if labels[idx] != blankLabel && + idx != S - 1 && + labels[idx] != labels[idx + 2] + nextSum = log_plus_f( + nextSum, probs[labels[idx + 2], t + 1] + beta[idx + 2, t + 1] + ) + end + if idx > 2 * t + beta[idx, t] = -Inf32 + else + beta[idx, t] = nextSum + end + idx += blockDim().x + end + sync_threads() + idx = tid + while idx <= S + output[idx, t] = alphas[idx, t] + beta[idx, t] + idx += blockDim().x + end + sync_threads() + end + sync_threads() + + # Calculate accumulated alpha-beta products for each label class for + # each time step; used in calculating gradients + if tid == 1 + for i in 1:S + labelIdx = labels[i] + accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t]) + end + end + sync_threads() + idx = tid + + # Calculate gradients + while idx <= size(grad, 1) + + # ∂L/∂a (where a is activation before logsoftmax) + grad[idx, t] = exp(probs[idx, t]) - exp(accum[idx, t] + loss) + idx += blockDim().x + end + sync_threads() + t -= 1 + sync_threads() + end + return nothing end function ctc_alpha(ŷ::CuArray, y) - ŷ = logsoftmax(ŷ) - blank = size(ŷ, 1) - ycu = cu(y) - z′ = CUDA.fill(blank, 2 * length(y) + 1) - z′[eachindex(y) .* 2] .= ycu - T = size(ŷ, 2) - U′ = 2*length(y) + 1 - alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T) - nRepeats = count_repeats(cpu(y)) - nThreads = min(U′, MAX_THREADS) - @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank) - return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats) + ŷ = logsoftmax(ŷ) + blank = size(ŷ, 1) + ycu = cu(y) + z′ = CUDA.fill(blank, 2 * length(y) + 1) + z′[eachindex(y) .* 2] .= ycu + T = size(ŷ, 2) + U′ = 2 * length(y) + 1 + alphas = CUDA.fill(log(zero(eltype(ŷ))), U′, T) + nRepeats = count_repeats(cpu(y)) + nThreads = min(U′, MAX_THREADS) + @cuda blocks = 1 threads = nThreads compute_alpha_kernel( + ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank + ) + return ( + loss = -1 * logsumexp(alphas[(end - 1):end]), + alpha = alphas, + z′ = z′, + 
yhat = ŷ, + nRepeats = nRepeats, + ) end ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss function ∇ctc_loss(ŷ::CuArray, y, out) - loss, alphas, z′, ŷ, nRepeats = out - U′, T = size(alphas) - blank = size(ŷ, 1) - typed_zero = zero(eltype(ŷ)) - betas = CUDA.fill(log(typed_zero), U′, T) - output = CUDA.fill(log(typed_zero), U′, T) - nThreads = min(U′, MAX_THREADS) - grads = CUDA.fill(log(typed_zero), size(ŷ)) - accum = CUDA.fill(log(typed_zero), size(ŷ)) - @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss) - return grads + loss, alphas, z′, ŷ, nRepeats = out + U′, T = size(alphas) + blank = size(ŷ, 1) + typed_zero = zero(eltype(ŷ)) + betas = CUDA.fill(log(typed_zero), U′, T) + output = CUDA.fill(log(typed_zero), U′, T) + nThreads = min(U′, MAX_THREADS) + grads = CUDA.fill(log(typed_zero), size(ŷ)) + accum = CUDA.fill(log(typed_zero), size(ŷ)) + @cuda blocks = 1 threads = nThreads compute_beta_and_grad_kernel( + ŷ, + length(y), + T, + nRepeats, + CuArray(z′), + alphas, + betas, + output, + accum, + grads, + blank, + loss, + ) + return grads end diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 42a1ab27e7..4aefb2e268 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -19,8 +19,8 @@ julia> Flux.mae(y_model, 1:3) ``` """ function mae(ŷ, y; agg = mean) - _check_sizes(ŷ, y) - agg(abs.(ŷ .- y)) + _check_sizes(ŷ, y) + return agg(abs.(ŷ .- y)) end """ @@ -43,8 +43,8 @@ julia> Flux.mse(y_model, y_true) ``` """ function mse(ŷ, y; agg = mean) - _check_sizes(ŷ, y) - agg((ŷ .- y) .^ 2) + _check_sizes(ŷ, y) + return agg((ŷ .- y) .^ 2) end """ @@ -67,8 +67,8 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3) ``` """ function msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ)) - _check_sizes(ŷ, y) - agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 ) + _check_sizes(ŷ, y) + return agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^ 2) end """ @@ -82,12 +82,12 @@ given the prediction `ŷ` and true values `y`. 
| δ * (|ŷ - y| - 0.5 * δ), otherwise """ function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1)) - _check_sizes(ŷ, y) - abs_error = abs.(ŷ .- y) - #TODO: remove dropgrad when Zygote can handle this function with CuArrays - temp = Zygote.dropgrad(abs_error .< δ) - x = ofeltype(ŷ, 0.5) - agg(((abs_error .^ 2) .* temp) .* x .+ δ * (abs_error .- x * δ) .* (1 .- temp)) + _check_sizes(ŷ, y) + abs_error = abs.(ŷ .- y) + #TODO: remove dropgrad when Zygote can handle this function with CuArrays + temp = Zygote.dropgrad(abs_error .< δ) + x = ofeltype(ŷ, 0.5) + return agg(((abs_error .^ 2) .* temp) .* x .+ δ * (abs_error .- x * δ) .* (1 .- temp)) end """ @@ -146,9 +146,9 @@ function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int = throw(ArgumentError("α must be between 0 and 1")) end if dims == 0 - y_smoothed = y .* (1 - α) .+ α*1//2 + y_smoothed = y .* (1 - α) .+ α * 1//2 elseif dims == 1 - y_smoothed = y .* (1 - α) .+ α* 1 // size(y, 1) + y_smoothed = y .* (1 - α) .+ α * 1//size(y, 1) else throw(ArgumentError("`dims` should be either 0 or 1")) end @@ -212,8 +212,8 @@ julia> Flux.crossentropy(y_model, y_smooth) ``` """ function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ)) - _check_sizes(ŷ, y) - agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims)) + _check_sizes(ŷ, y) + return agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims)) end """ @@ -251,8 +251,8 @@ julia> Flux.crossentropy(softmax(y_model), y_label) ``` """ function logitcrossentropy(ŷ, y; dims = 1, agg = mean) - _check_sizes(ŷ, y) - agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims)) + _check_sizes(ŷ, y) + return agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims)) end """ @@ -300,8 +300,8 @@ julia> Flux.crossentropy(y_prob, y_hot) ``` """ function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ)) - _check_sizes(ŷ, y) - agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))) + _check_sizes(ŷ, y) + return agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))) end """ @@ -330,8 +330,8 @@ julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin) ``` """ function logitbinarycrossentropy(ŷ, y; agg = mean) - _check_sizes(ŷ, y) - agg(@.((1 - y) * ŷ - logσ(ŷ))) + _check_sizes(ŷ, y) + return agg(@.((1 - y) * ŷ - logσ(ŷ))) end """ @@ -370,10 +370,10 @@ Inf ``` """ function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ)) - _check_sizes(ŷ, y) - entropy = agg(sum(xlogx.(y), dims = dims)) - cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ) - return entropy + cross_entropy + _check_sizes(ŷ, y) + entropy = agg(sum(xlogx.(y); dims = dims)) + cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ) + return entropy + cross_entropy end """ @@ -385,8 +385,8 @@ end [More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). 
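A small usage sketch tying together the cross-entropy variants and `label_smoothing` reformatted above (not part of this patch; it assumes Flux is loaded and uses `onehotbatch` from `src/onehot.jl` later in this diff):

```julia
using Flux
using Flux.Losses: crossentropy, logitcrossentropy, label_smoothing

y = Flux.onehotbatch([1, 2, 3, 2], 1:3)   # hard targets: 3 classes × 4 samples
ŷ_logits = randn(Float32, 3, 4)           # raw model outputs

# logitcrossentropy fuses softmax and crossentropy for numerical stability:
l1 = logitcrossentropy(ŷ_logits, y)

# label_smoothing(y, α) mixes the one-hot targets toward the uniform distribution:
y_smooth = label_smoothing(y, 0.1f0)
l2 = crossentropy(softmax(ŷ_logits), y_smooth)
```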
""" function poisson_loss(ŷ, y; agg = mean) - _check_sizes(ŷ, y) - agg(ŷ .- xlogy.(y, ŷ)) + _check_sizes(ŷ, y) + return agg(ŷ .- xlogy.(y, ŷ)) end """ @@ -399,8 +399,8 @@ prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as See also: [`squared_hinge_loss`](@ref) """ function hinge_loss(ŷ, y; agg = mean) - _check_sizes(ŷ, y) - agg(max.(0, 1 .- ŷ .* y)) + _check_sizes(ŷ, y) + return agg(max.(0, 1 .- ŷ .* y)) end """ @@ -412,8 +412,8 @@ Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y See also: [`hinge_loss`](@ref) """ function squared_hinge_loss(ŷ, y; agg = mean) - _check_sizes(ŷ, y) - agg((max.(0, 1 .- ŷ .* y)) .^ 2) + _check_sizes(ŷ, y) + return agg((max.(0, 1 .- ŷ .* y)) .^ 2) end """ @@ -427,8 +427,8 @@ Similar to the F1_score. Calculated as: 1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth) """ function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0)) - _check_sizes(ŷ, y) - 1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg + _check_sizes(ŷ, y) + return 1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg end """ @@ -445,7 +445,7 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) #TODO add agg num = sum(y .* ŷ) + 1 den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1 - 1 - num / den + return 1 - num / den end """ @@ -477,14 +477,14 @@ true See also: [`Losses.focal_loss`](@ref) for multi-class setting """ -function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ)) +function binary_focal_loss(ŷ, y; agg = mean, γ = 2, ϵ = epseltype(ŷ)) _check_sizes(ŷ, y) ŷ = ŷ .+ ϵ - p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) + p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ) ce = -log.(p_t) weight = (1 .- p_t) .^ γ loss = weight .* ce - agg(loss) + return agg(loss) end """ @@ -521,10 +521,10 @@ true See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ -function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) +function focal_loss(ŷ, y; dims = 1, agg = mean, γ = 2, ϵ = epseltype(ŷ)) _check_sizes(ŷ, y) ŷ = ŷ .+ ϵ - agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims)) + return agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims = dims)) end ```@meta diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 386cd67166..947d72772f 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -4,8 +4,8 @@ Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero. """ function xlogx(x) - result = x * log(x) - ifelse(iszero(x), zero(result), result) + result = x * log(x) + return ifelse(iszero(x), zero(result), result) end """ @@ -14,22 +14,25 @@ end Return `x * log(y)` for `y > 0`, and zero when `x == 0`. """ function xlogy(x, y) - result = x * log(y) - ifelse(iszero(x), zero(result), result) + result = x * log(y) + return ifelse(iszero(x), zero(result), result) end @adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric) - res = xlogy.(x, y) - res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y)) + res = xlogy.(x, y) + return res, + Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y)) end # This can be made an error in Flux v0.13, for now just a warning function _check_sizes(ŷ::AbstractArray, y::AbstractArray) - for d in 1:max(ndims(ŷ), ndims(y)) - if size(ŷ,d) != size(y,d) - @warn "Size mismatch in loss function! In future this will be an error. 
In Flux <= 0.12 broadcasting accepts this, but may not give sensible results" summary(ŷ) summary(y) maxlog=3 _id=hash(size(y)) + for d in 1:max(ndims(ŷ), ndims(y)) + if size(ŷ, d) != size(y, d) + @warn "Size mismatch in loss function! In future this will be an error. In Flux <= 0.12 broadcasting accepts this, but may not give sensible results" summary( + ŷ + ) summary(y) maxlog = 3 _id = hash(size(y)) + end end - end end _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1 diff --git a/src/onehot.jl b/src/onehot.jl index 34e76fe050..c74e67ce3c 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -8,48 +8,54 @@ using LinearAlgebra, NNlib These are constructed by [`onehot`](@ref) and [`onehotbatch`](@ref). Parameter `I` is the type of the underlying storage, and `T` its eltype. """ -struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"} - indices::I +struct OneHotArray{T<:Integer,L,N,var"N+1",I<:Union{T,AbstractArray{T,N}}} <: + AbstractArray{Bool,var"N+1"} + indices::I end -OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) -OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices) -OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices) +OneHotArray{T,L,N,I}(indices) where {T,L,N,I} = OneHotArray{T,L,N,N + 1,I}(indices) +OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T,L,0,1,T}(indices) +OneHotArray(indices::I, L::Integer) where {T,N,I<:AbstractArray{T,N}} = + OneHotArray{T,L,N,N + 1,I}(indices) _indices(x::OneHotArray) = x.indices -_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) = - reshape(parent(x).indices, x.dims[2:end]) +_indices(x::Base.ReshapedArray{<:Any,<:Any,<:OneHotArray}) = + reshape(parent(x).indices, x.dims[2:end]) -const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T} -const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I} +const OneHotVector{T,L} = OneHotArray{T,L,0,1,T} +const OneHotMatrix{T,L,I} = OneHotArray{T,L,1,2,I} -@doc @doc(OneHotArray) -OneHotVector(idx, L) = OneHotArray(idx, L) -@doc @doc(OneHotArray) -OneHotMatrix(indices, L) = OneHotArray(indices, L) +@doc @doc(OneHotArray) OneHotVector(idx, L) = OneHotArray(idx, L) +@doc @doc(OneHotArray) OneHotMatrix(indices, L) = OneHotArray(indices, L) # use this type so reshaped arrays hit fast paths # e.g. argmax -const OneHotLike{T, L, N, var"N+1", I} = - Union{OneHotArray{T, L, N, var"N+1", I}, - Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}} +const OneHotLike{T,L,N,var"N+1",I} = Union{ + OneHotArray{T,L,N,var"N+1",I}, + Base.ReshapedArray{Bool,var"N+1",<:OneHotArray{T,L,<:Any,<:Any,I}}, +} _isonehot(x::OneHotArray) = true -_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L) +_isonehot(x::Base.ReshapedArray{<:Any,<:Any,<:OneHotArray{<:Any,L}}) where {L} = + (size(x, 1) == L) -Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...) +Base.size(x::OneHotArray{<:Any,L}) where {L} = (Int(L), size(x.indices)...) _onehotindex(x, i) = (x == i) Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i) -Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x +Base.getindex(x::OneHotVector{T,L}, ::Colon) where {T,L} = x Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i) -Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) 
where L = OneHotArray(x.indices[I...], L) -Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x -Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...] +Base.getindex(x::OneHotArray{<:Any,L}, ::Colon, I...) where {L} = + OneHotArray(x.indices[I...], L) +Base.getindex(x::OneHotArray{<:Any,<:Any,<:Any,N}, ::Vararg{Colon,N}) where {N} = x +Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where {N} = x[I[1], Tuple(I)[2:N]...] function Base.showarg(io::IO, x::OneHotArray, toplevel) - print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(") + print( + io, + ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(", + ) Base.showarg(io, x.indices, false) print(io, ')') toplevel && print(io, " with eltype Bool") @@ -57,48 +63,62 @@ function Base.showarg(io::IO, x::OneHotArray, toplevel) end # this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots: -function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString) - x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s +function Base.replace_in_print_matrix( + x::OneHotLike, + i::Integer, + j::Integer, + s::AbstractString, +) + x[i, j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s end # copy CuArray versions back before trying to print them: -Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} = - Base.print_array(io, cpu(X)) -Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} = - Base.print_array(io, cpu(X)) - -_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} -_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} - -function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L - if isone(dims) || any(x -> !_isonehot(x), (x, xs...)) - return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims) - else - return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L) - end +Base.print_array(io::IO, X::OneHotLike{T,L,N,var"N+1",<:CuArray}) where {T,L,N,var"N+1"} = + Base.print_array(io, cpu(X)) +Base.print_array( + io::IO, + X::LinearAlgebra.AdjOrTrans{Bool,<:OneHotLike{T,L,N,var"N+1",<:CuArray}}, +) where {T,L,N,var"N+1"} = Base.print_array(io, cpu(X)) + +_onehot_bool_type( + x::OneHotLike{<:Any,<:Any,<:Any,N,<:Union{Integer,AbstractArray}}, +) where {N} = Array{Bool,N} +_onehot_bool_type(x::OneHotLike{<:Any,<:Any,<:Any,N,<:CuArray}) where {N} = CuArray{Bool,N} + +function Base.cat(x::OneHotLike{<:Any,L}, xs::OneHotLike{<:Any,L}...; dims::Int) where {L} + if isone(dims) || any(x -> !_isonehot(x), (x, xs...)) + return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims) + else + return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L) + end end Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2) Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1) # optimized concatenation for matrices and vectors of same parameters -Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) -Base.hcat(x::T, xs::T...) 
where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) +Base.hcat(x::T, xs::T...) where {L,T<:OneHotLike{<:Any,L,<:Any,2}} = + OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) +Base.hcat(x::T, xs::T...) where {L,T<:OneHotLike{<:Any,L,<:Any,1}} = + OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) -batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L) +batch(xs::AbstractArray{<:OneHotVector{<:Any,L}}) where {L} = OneHotArray(_indices.(xs), L) -Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L) +Adapt.adapt_structure(T, x::OneHotArray{<:Any,L}) where {L} = + OneHotArray(adapt(T, _indices(x)), L) -Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}() +Base.BroadcastStyle(::Type{<:OneHotArray{<:Any,<:Any,<:Any,N,<:CuArray}}) where {N} = + CUDA.CuArrayStyle{N}() Base.map(f, x::OneHotLike) = Base.broadcast(f, x) Base.argmax(x::OneHotLike; dims = Colon()) = - (_isonehot(x) && dims == 1) ? - reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : - invoke(argmax, Tuple{AbstractArray}, x; dims = dims) + (_isonehot(x) && dims == 1) ? + reshape( + CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), + 1, + size(_indices(x))..., + ) : invoke(argmax, Tuple{AbstractArray}, x; dims = dims) """ onehot(x, labels, [default]) @@ -131,15 +151,15 @@ julia> hcat(αβγ...) # preserves sparsity ``` """ function onehot(x, labels) - i = something(findfirst(isequal(x), labels), 0) - i > 0 || error("Value $x is not in labels") - OneHotVector{UInt32, length(labels)}(i) + i = something(findfirst(isequal(x), labels), 0) + i > 0 || error("Value $x is not in labels") + OneHotVector{UInt32,length(labels)}(i) end function onehot(x, labels, default) - i = something(findfirst(isequal(x), labels), 0) - i > 0 || return onehot(default, labels) - OneHotVector{UInt32, length(labels)}(i) + i = something(findfirst(isequal(x), labels), 0) + i > 0 || return onehot(default, labels) + OneHotVector{UInt32,length(labels)}(i) end """ @@ -203,55 +223,79 @@ julia> Flux.onecold([ 1 0 0 1 0 1 0 1 0 0 1 """ onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)] function onecold(y::AbstractArray, labels = 1:size(y, 1)) - indices = _fast_argmax(y) - xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA + indices = _fast_argmax(y) + xs = isbits(labels) ? 
indices : collect(indices) # non-bit type cannot be handled by CUDA - return map(xi -> labels[xi[1]], xs) + return map(xi -> labels[xi[1]], xs) end _fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1) function _fast_argmax(x::OneHotLike) - if _isonehot(x) - return _indices(x) - else - return _fast_argmax(convert(_onehot_bool_type(x), x)) - end + if _isonehot(x) + return _indices(x) + else + return _fast_argmax(convert(_onehot_bool_type(x), x)) + end end @nograd OneHotArray, onecold, onehot, onehotbatch -function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L - _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) - size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - return A[:, onecold(B)] +function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any,L}) where {L} + _isonehot(B) || return invoke(*, Tuple{AbstractMatrix,AbstractMatrix}, A, B) + size(A, 2) == L || throw( + DimensionMismatch( + "Matrix column must correspond with OneHot size: $(size(A, 2)) != $L", + ), + ) + return A[:, onecold(B)] end -function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L - _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B) - size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - return NNlib.gather(A, _indices(B)) +function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any,L,1}) where {L} + _isonehot(B) || return invoke(*, Tuple{AbstractMatrix,AbstractMatrix}, A, B) + size(A, 2) == L || throw( + DimensionMismatch( + "Matrix column must correspond with OneHot size: $(size(A, 2)) != $L", + ), + ) + return NNlib.gather(A, _indices(B)) end -function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix}) - B_dim = length(_indices(parent(B))) - size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim")) - return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2))) +function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool,<:OneHotMatrix}) + B_dim = length(_indices(parent(B))) + size(A, 2) == B_dim || throw( + DimensionMismatch( + "Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim", + ), + ) + return NNlib.scatter(+, A, _indices(parent(B)), dstsize = (size(A, 1), size(B, 2))) end for wrapper in [:Adjoint, :Transpose] - @eval begin - function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T} - size(A, 2) == L || - throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - - return A[:, onecold(b)] - end - - function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T} - size(A, 2) == L || - throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L")) - - return A[onecold(b)] + @eval begin + function Base.:*( + A::$wrapper{<:Any,<:AbstractMatrix{T}}, + b::OneHotVector{<:Any,L}, + ) where {L,T} + size(A, 2) == L || throw( + DimensionMismatch( + "Matrix column must correspond with OneHot size: $(size(A, 2)) != $L", + ), + ) + + return A[:, onecold(b)] + end + + function Base.:*( + A::$wrapper{<:Number,<:AbstractVector{T}}, + b::OneHotVector{<:Any,L}, + ) where {L,T} + size(A, 2) == L || throw( + DimensionMismatch( + "Matrix column must correspond with OneHot size: $(size(A, 2)) != $L", + ), + ) + + return A[onecold(b)] + end end - end 
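As a usage sketch for the one-hot machinery above (not part of this patch; labels and sizes are made up):

```julia
using Flux

v = Flux.onehot(:b, [:a, :b, :c])                 # OneHotVector storing index 2
m = Flux.onehotbatch([:b, :a, :c], [:a, :b, :c])  # OneHotMatrix, one column per sample

Flux.onecold(m, [:a, :b, :c])                     # recovers [:b, :a, :c]

# The specialised Base.:* methods above turn dense × one-hot into indexing / NNlib.gather:
W = randn(Float32, 4, 3)
W * v == W[:, 2]                                  # true: a plain column lookup
```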
end diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 010cbfc9bb..4b1af069d0 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,13 +1,32 @@ module Optimise using LinearAlgebra -import ArrayInterface +using ArrayInterface: ArrayInterface -export train!, update!, - Descent, ADAM, Momentum, Nesterov, RMSProp, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, - InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, - ClipValue, ClipNorm +export train!, + update!, + Descent, + ADAM, + Momentum, + Nesterov, + RMSProp, + ADAGrad, + AdaMax, + ADADelta, + AMSGrad, + NADAM, + ADAMW, + RADAM, + OADAM, + AdaBelief, + InvDecay, + ExpDecay, + WeightDecay, + stop, + skip, + Optimiser, + ClipValue, + ClipNorm include("optimisers.jl") include("train.jl") diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index b36560a9b9..8ac91a2671 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -33,13 +33,13 @@ Flux.Optimise.update!(opt, ps, gs) ``` """ mutable struct Descent <: AbstractOptimiser - eta::Float64 + eta::Float64 end Descent() = Descent(0.1) function apply!(o::Descent, x, Δ) - Δ .*= o.eta + Δ .*= o.eta end """ @@ -61,18 +61,18 @@ opt = Momentum(0.01, 0.99) ``` """ mutable struct Momentum <: AbstractOptimiser - eta::Float64 - rho::Float64 - velocity::IdDict + eta::Float64 + rho::Float64 + velocity::IdDict end Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) function apply!(o::Momentum, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v + η, ρ = o.eta, o.rho + v = get!(() -> zero(x), o.velocity, x)::typeof(x) + @. v = ρ * v - η * Δ + @. Δ = -v end """ @@ -94,19 +94,19 @@ opt = Nesterov(0.003, 0.95) ``` """ mutable struct Nesterov <: AbstractOptimiser - eta::Float64 - rho::Float64 - velocity::IdDict + eta::Float64 + rho::Float64 + velocity::IdDict end Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) function apply!(o::Nesterov, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - d = @. ρ^2 * v - (1+ρ) * η * Δ - @. v = ρ*v - η*Δ - @. Δ = -d + η, ρ = o.eta, o.rho + v = get!(() -> zero(x), o.velocity, x)::typeof(x) + d = @. ρ^2 * v - (1 + ρ) * η * Δ + @. v = ρ * v - η * Δ + @. Δ = -d end """ @@ -131,18 +131,18 @@ opt = RMSProp(0.002, 0.95) ``` """ mutable struct RMSProp <: AbstractOptimiser - eta::Float64 - rho::Float64 - acc::IdDict + eta::Float64 + rho::Float64 + acc::IdDict end RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) function apply!(o::RMSProp, x, Δ) - η, ρ = o.eta, o.rho - acc = get!(() -> zero(x), o.acc, x)::typeof(x) - @. acc = ρ * acc + (1 - ρ) * Δ^2 - @. Δ *= η / (√acc + ϵ) + η, ρ = o.eta, o.rho + acc = get!(() -> zero(x), o.acc, x)::typeof(x) + @. acc = ρ * acc + (1 - ρ) * Δ^2 + @. Δ *= η / (√acc + ϵ) end """ @@ -164,26 +164,26 @@ opt = ADAM(0.001, (0.9, 0.8)) ``` """ mutable struct ADAM <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) function apply!(o::ADAM, x, Δ) - η, β = o.eta, o.beta + η, β = o.eta, o.beta - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} + mt, vt, βp = get!(o.state, x) do + (zero(x), zero(x), Float64[β[1], β[2]]) + end::Tuple{typeof(x),typeof(x),Vector{Float64}} - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. 
vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η - βp .= βp .* β + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η + βp .= βp .* β - return Δ + return Δ end """ @@ -205,34 +205,39 @@ opt = RADAM(0.001, (0.9, 0.8)) ``` """ mutable struct RADAM <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict()) function apply!(o::RADAM, x, Δ) - η, β = o.eta, o.beta - ρ∞ = 2/(1-β[2])-1 - - mt, vt, βp, t = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]], Ref(1)) - end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2]) - if ρ > 4 - r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ)) - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r - else - @. Δ = mt / (1 - βp[1]) * η - end - βp .= βp .* β - t[] += 1 + η, β = o.eta, o.beta + ρ∞ = 2 / (1 - β[2]) - 1 + + mt, vt, βp, t = get!(o.state, x) do + ( + zero(x), + zero(x), + Float64[β[1], β[2]], + Ref(1), + ) + end::Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}} + + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2]) + if ρ > 4 + r = sqrt((ρ - 4) * (ρ - 2) * ρ∞ / ((ρ∞ - 4) * (ρ∞ - 2) * ρ)) + @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r + else + @. Δ = mt / (1 - βp[1]) * η + end + βp .= βp .* β + t[] += 1 - return Δ + return Δ end """ @@ -254,26 +259,26 @@ opt = AdaMax(0.001, (0.9, 0.995)) ``` """ mutable struct AdaMax <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) function apply!(o::AdaMax, x, Δ) - η, β = o.eta, o.beta + η, β = o.eta, o.beta - mt, ut, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} + mt, ut, βp = get!(o.state, x) do + (zero(x), zero(x), Float64[β[1], β[2]]) + end::Tuple{typeof(x),typeof(x),Vector{Float64}} - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. ut = max(β[2] * ut, abs(Δ)) - @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ) - βp .= βp .* β + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. ut = max(β[2] * ut, abs(Δ)) + @. Δ = (η / (1 - βp[1])) * mt / (ut + ϵ) + βp .= βp .* β - return Δ + return Δ end """ @@ -296,28 +301,33 @@ opt = OADAM(0.001, (0.9, 0.995)) ``` """ mutable struct OADAM <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end OADAM(η = 0.001, β = (0.5, 0.9)) = OADAM(η, β, IdDict()) function apply!(o::OADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, Δ_, βp = get!(o.state, x) do - (zero(x), zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}} + η, β = o.eta, o.beta + + mt, vt, Δ_, βp = get!(o.state, x) do + ( + zero(x), + zero(x), + zero(x), + Float64[β[1], β[2]], + ) + end::Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}} + + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. Δ = -Δ_ + @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) + @. Δ += 2Δ_ + βp .= βp .* β - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = -Δ_ - @. 
Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) - @. Δ += 2Δ_ - βp .= βp .* β - - return Δ + return Δ end """ @@ -339,17 +349,17 @@ opt = ADAGrad(0.001) ``` """ mutable struct ADAGrad <: AbstractOptimiser - eta::Float64 - acc::IdDict + eta::Float64 + acc::IdDict end ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) function apply!(o::ADAGrad, x, Δ) - η = o.eta - acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x) - @. acc += Δ^2 - @. Δ *= η / (√acc + ϵ) + η = o.eta + acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x) + @. acc += Δ^2 + @. Δ *= η / (√acc + ϵ) end """ @@ -370,21 +380,21 @@ opt = ADADelta(0.89) ``` """ mutable struct ADADelta <: AbstractOptimiser - rho::Float64 - state::IdDict + rho::Float64 + state::IdDict end ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) function apply!(o::ADADelta, x, Δ) - ρ = o.rho - acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. acc = ρ * acc + (1 - ρ) * Δ^2 - # DON'T remove epsilon from numerator - # or even out of the square roots - @. Δ *= √(Δacc + ϵ) / √(acc + ϵ) - @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 - return Δ + ρ = o.rho + acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} + @. acc = ρ * acc + (1 - ρ) * Δ^2 + # DON'T remove epsilon from numerator + # or even out of the square roots + @. Δ *= √(Δacc + ϵ) / √(acc + ϵ) + @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 + return Δ end """ @@ -407,24 +417,24 @@ opt = AMSGrad(0.001, (0.89, 0.995)) ``` """ mutable struct AMSGrad <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64, Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) function apply!(o::AMSGrad, x, Δ) - η, β = o.eta, o.beta + η, β = o.eta, o.beta - mt, vt, v̂t = get!(o.state, x) do - (fill!(similar(x), ϵ), fill!(similar(x), ϵ), fill!(similar(x), ϵ)) - end :: NTuple{3,typeof(x)} + mt, vt, v̂t = get!(o.state, x) do + (fill!(similar(x), ϵ), fill!(similar(x), ϵ), fill!(similar(x), ϵ)) + end::NTuple{3,typeof(x)} - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 - @. v̂t = max(v̂t, vt) - @. Δ = η * mt / (√v̂t + ϵ) + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. v̂t = max(v̂t, vt) + @. Δ = η * mt / (√v̂t + ϵ) end """ @@ -447,27 +457,33 @@ opt = NADAM(0.002, (0.89, 0.995)) ``` """ mutable struct NADAM <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64, Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) function apply!(o::NADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[o.beta[1], o.beta[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - β1p, β2p = βp - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η - βp .= βp .* β + η, β = o.eta, o.beta + + mt, vt, βp = get!(o.state, x) do + ( + zero(x), + zero(x), + Float64[o.beta[1], o.beta[2]], + ) + end::Tuple{typeof(x),typeof(x),Vector{Float64}} + β1p, β2p = βp + + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. vt = β[2] * vt + (1 - β[2]) * Δ^2 + @. 
Δ = + (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / + (√(vt * β[2] / (1 - β2p)) + ϵ) * η + βp .= βp .* β - return Δ + return Δ end """ @@ -491,7 +507,7 @@ opt = ADAMW(0.001, (0.89, 0.995), 0.1) ``` """ ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) = - Optimiser(ADAM(1, β), WeightDecay(decay), Descent(η)) + Optimiser(ADAM(1, β), WeightDecay(decay), Descent(η)) """ AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999)) @@ -513,20 +529,20 @@ opt = AdaBelief(0.001, (0.9, 0.8)) ``` """ mutable struct AdaBelief - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict + eta::Float64 + beta::Tuple{Float64,Float64} + state::IdDict end AdaBelief(η = 0.001, β = (0.9, 0.999)) = AdaBelief(η, β, IdDict()) function apply!(o::AdaBelief, x, Δ) - η, β = o.eta, o.beta - mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2 - @. Δ = η * mt / (√(st) + ϵ) - return Δ + η, β = o.eta, o.beta + mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} + @. mt = β[1] * mt + (1 - β[1]) * Δ + @. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2 + @. Δ = η * mt / (√(st) + ϵ) + return Δ end @@ -540,21 +556,26 @@ that will be fed into the next, and this is finally applied to the parameter as usual. """ mutable struct Optimiser <: AbstractOptimiser - os::Vector{Any} + os::Vector{Any} end Optimiser(o...) = Optimiser(Any[o...]) -@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! +@forward Optimiser.os Base.getindex, +Base.first, +Base.last, +Base.lastindex, +Base.push!, +Base.setindex! @forward Optimiser.os Base.iterate Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) function apply!(o::Optimiser, x, Δ) - for opt in o.os - Δ = apply!(opt, x, Δ) - end - return Δ + for opt in o.os + Δ = apply!(opt, x, Δ) + end + return Δ end """ @@ -579,18 +600,18 @@ opt = Optimiser(Adam(1f-3), InvDecay(1f-2)) ``` """ mutable struct InvDecay <: AbstractOptimiser - gamma::Float64 - state::IdDict + gamma::Float64 + state::IdDict end InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) function apply!(o::InvDecay, x, Δ) - γ = o.gamma - n = get!(o.state, x, 1) - Δ .*= 1 / (1 + γ * n) - o.state[x] = n + 1 - return Δ + γ = o.gamma + n = get!(o.state, x, 1) + Δ .*= 1 / (1 + γ * n) + o.state[x] = n + 1 + return Δ end """ @@ -620,24 +641,24 @@ opt = Optimiser(ADAM(), ExpDecay()) ``` """ mutable struct ExpDecay <: AbstractOptimiser - eta::Float64 - decay::Float64 - step::Int64 - clip::Float64 - current::IdDict + eta::Float64 + decay::Float64 + step::Int64 + clip::Float64 + current::IdDict end -ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = - ExpDecay(opt, decay, decay_step, clip, IdDict()) +ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = + ExpDecay(opt, decay, decay_step, clip, IdDict()) function apply!(o::ExpDecay, x, Δ) - η, s, decay = o.eta, o.step, o.decay - n = o.current[x] = get(o.current, x, 0) + 1 - if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 - η = max(η * decay, o.clip) - o.eta = η - end - @. Δ *= η + η, s, decay = o.eta, o.step, o.decay + n = o.current[x] = get(o.current, x, 0) + 1 + if o.current[x] % s == 0 && count(x -> x % s == 0, values(o.current)) == 1 + η = max(η * decay, o.clip) + o.eta = η + end + @. 
Δ *= η end """ @@ -655,14 +676,14 @@ opt = Optimiser(WeightDecay(1f-4), ADAM()) ``` """ mutable struct WeightDecay <: AbstractOptimiser - wd::Real + wd::Real end WeightDecay() = WeightDecay(0) function apply!(o::WeightDecay, x, Δ) - wd = o.wd - @. Δ += wd * x + wd = o.wd + @. Δ += wd * x end """ diff --git a/src/optimise/train.jl b/src/optimise/train.jl index a224bf3389..4a1daaefc5 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -7,7 +7,7 @@ import Zygote: Params, gradient Update the array `x` according to `x .-= x̄`. """ function update!(x::AbstractArray, x̄) - x .-= x̄ + x .-= x̄ end """ @@ -21,16 +21,16 @@ As a result, the parameters are mutated and the optimizer's internal state may c The gradient could be mutated as well. """ function update!(opt, x, x̄) - x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's - # output are not mutable, see #1510 - x .-= apply!(opt, x, x̄r) + x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's + # output are not mutable, see #1510 + x .-= apply!(opt, x, x̄r) end function update!(opt, xs::Params, gs) - for x in xs - isnothing(gs[x]) && continue - update!(opt, x, gs[x]) - end + for x in xs + isnothing(gs[x]) && continue + update!(opt, x, gs[x]) + end end # Callback niceties @@ -54,7 +54,7 @@ end ``` """ function skip() - throw(SkipException()) + throw(SkipException()) end @@ -74,7 +74,7 @@ end ``` """ function stop() - throw(StopException()) + throw(StopException()) end batchmemaybe(x) = tuple(x) @@ -97,25 +97,25 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, ps, data, opt; cb = () -> ()) - ps = Params(ps) - cb = runall(cb) - @progress for d in data - try - gs = gradient(ps) do - loss(batchmemaybe(d)...) - end - update!(opt, ps, gs) - cb() - catch ex - if ex isa StopException - break - elseif ex isa SkipException - continue - else - rethrow(ex) - end + ps = Params(ps) + cb = runall(cb) + @progress for d in data + try + gs = gradient(ps) do + loss(batchmemaybe(d)...) 
+ end + update!(opt, ps, gs) + cb() + catch ex + if ex isa StopException + break + elseif ex isa SkipException + continue + else + rethrow(ex) + end + end end - end end """ @@ -134,8 +134,8 @@ hello ``` """ macro epochs(n, ex) - :(@progress for i = 1:$(esc(n)) - @info "Epoch $i" - $(esc(ex)) + :(@progress for i = 1:$(esc(n)) + @info "Epoch $i" + $(esc(ex)) end) end diff --git a/src/outputsize.jl b/src/outputsize.jl index 303e2c9c8b..4f7ff012eb 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -1,7 +1,7 @@ module NilNumber using NNlib -import Random +using Random: Random """ Nil <: Number @@ -13,21 +13,36 @@ struct Nil <: Number end const nil = Nil() -Nil(::T) where T<:Number = nil -(::Type{T})(::Nil) where T<:Number = nil +Nil(::T) where {T<:Number} = nil +(::Type{T})(::Nil) where {T<:Number} = nil Base.convert(::Type{Nil}, ::Number) = nil Base.float(::Type{Nil}) = Nil -for f in [:copy, :zero, :one, :oneunit, - :+, :-, :abs, :abs2, :inv, - :exp, :log, :log1p, :log2, :log10, - :sqrt, :tanh, :conj] - @eval Base.$f(::Nil) = nil +for f in [ + :copy, + :zero, + :one, + :oneunit, + :+, + :-, + :abs, + :abs2, + :inv, + :exp, + :log, + :log1p, + :log2, + :log10, + :sqrt, + :tanh, + :conj, +] + @eval Base.$f(::Nil) = nil end for f in [:+, :-, :*, :/, :^, :mod, :div, :rem] - @eval Base.$f(::Nil, ::Nil) = nil + @eval Base.$f(::Nil, ::Nil) = nil end Base.isless(::Nil, ::Nil) = true @@ -88,27 +103,27 @@ julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1)) # Vector of layers becom (2, 1) ``` """ -function outputsize(m, inputsizes::Tuple...; padbatch=false) - x = nil_input(padbatch, inputsizes...) - return size(m(x)) +function outputsize(m, inputsizes::Tuple...; padbatch = false) + x = nil_input(padbatch, inputsizes...) + return size(m(x)) end -nil_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? fill(nil, (s...,1)) : fill(nil, s) +nil_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? fill(nil, (s..., 1)) : fill(nil, s) nil_input(pad::Bool, multi::Tuple{Vararg{Integer}}...) = nil_input.(pad, multi) nil_input(pad::Bool, tup::Tuple{Vararg{Tuple}}) = nil_input(pad, tup...) -function outputsize(m::Chain, inputsizes::Tuple{Vararg{Integer}}...; padbatch=false) - x = nil_input(padbatch, inputsizes...) - for (i,lay) in enumerate(m.layers) - try - x = lay(x) - catch err - str = x isa AbstractArray ? "with input of size $(size(x))" : "" - @error "layer $lay, index $i in Chain, gave an error $str" - rethrow(err) +function outputsize(m::Chain, inputsizes::Tuple{Vararg{Integer}}...; padbatch = false) + x = nil_input(padbatch, inputsizes...) + for (i, lay) in enumerate(m.layers) + try + x = lay(x) + catch err + str = x isa AbstractArray ? 
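To show how the optimiser and training pieces reformatted above fit together, here is a sketch (not part of this patch; the model, loss, and data are made up for illustration):

```julia
using Flux
using Flux.Optimise: Optimiser, ADAM, WeightDecay, train!

model = Dense(10, 2)
loss(x, y) = Flux.Losses.mse(model(x), y)
data = [(randn(Float32, 10, 16), randn(Float32, 2, 16)) for _ in 1:8]

# Optimiser applies its rules left to right: the weight-decay term is added to the
# gradient, then ADAM rescales the step (cf. the WeightDecay docstring above).
opt = Optimiser(WeightDecay(1f-4), ADAM(1f-3))

# train! makes one pass over `data`; @epochs repeats it and logs "Epoch $i" each time.
Flux.@epochs 3 train!(loss, Flux.params(model), data, opt)
```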
"with input of size $(size(x))" : "" + @error "layer $lay, index $i in Chain, gave an error $str" + rethrow(err) + end end - end - return size(x) + return size(x) end """ @@ -142,29 +157,35 @@ outputsize ## make tuples and vectors be like Chains -outputsize(m::Tuple, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch) -outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch) +function outputsize(m::Tuple, input::Tuple...; padbatch = false) + return outputsize(Chain(m...), input...; padbatch = padbatch) +end +function outputsize(m::AbstractVector, input::Tuple...; padbatch = false) + return outputsize(Chain(m...), input...; padbatch = padbatch) +end ## bypass statistics in normalization layers for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm) - @eval (l::$layer)(x::AbstractArray{Nil}) = x + @eval (l::$layer)(x::AbstractArray{Nil}) = x end ## fixes for layers that don't work out of the box for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) - @eval begin - function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims) - fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end]) - end - - function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) - NNlib.$fn(fill(nil, size(a)), b, dims) - end - - function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) - NNlib.$fn(a, fill(nil, size(b)), dims) + @eval begin + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims) + return fill( + nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end] + ) + end + + function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) + return NNlib.$fn(fill(nil, size(a)), b, dims) + end + + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) + return NNlib.$fn(a, fill(nil, size(b)), dims) + end end - end end diff --git a/src/utils.jl b/src/utils.jl index c1888829d4..965d709aca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,7 +28,7 @@ nfan() = 1, 1 # fan_in, fan_out nfan(n) = 1, n # A vector is treated as a n×1 matrix nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices nfan(dims::Tuple) = nfan(dims...) -nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels +nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels ofeltype(x, y) = convert(float(eltype(x)), y) epseltype(x) = eps(float(eltype(x))) @@ -61,7 +61,9 @@ julia> Flux.glorot_uniform(2, 3) [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...))) +function glorot_uniform(rng::AbstractRNG, dims...) + return (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...))) +end glorot_uniform(dims...) = glorot_uniform(Random.GLOBAL_RNG, dims...) glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...) @@ -94,7 +96,9 @@ julia> Flux.glorot_normal(3, 2) [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." 
_Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...))) +function glorot_normal(rng::AbstractRNG, dims...) + return randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...))) +end glorot_normal(dims...) = glorot_normal(Random.GLOBAL_RNG, dims...) glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...) @@ -128,12 +132,14 @@ julia> Flux.kaiming_uniform(3, 2) [1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2) - bound = Float32(√3 * gain / sqrt(first(nfan(dims...)))) # fan_in - return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound + bound = Float32(√3 * gain / sqrt(first(nfan(dims...)))) # fan_in + return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound end kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...) -kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) +function kaiming_uniform(rng::AbstractRNG; init_kwargs...) + return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) +end """ kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2) @@ -164,13 +170,15 @@ julia> Flux.kaiming_normal(3, 2) [1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015. """ -function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0) - std = Float32(gain / sqrt(first(nfan(dims...)))) # fan_in - return randn(rng, Float32, dims...) .* std +function kaiming_normal(rng::AbstractRNG, dims...; gain = √2.0f0) + std = Float32(gain / sqrt(first(nfan(dims...)))) # fan_in + return randn(rng, Float32, dims...) .* std end kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...) -kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +function kaiming_normal(rng::AbstractRNG; init_kwargs...) + return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +end """ orthogonal([rng=GLOBAL_RNG], dims...; gain = 1) @@ -217,26 +225,29 @@ true """ function orthogonal(rng::AbstractRNG, rows::Integer, cols::Integer; gain = 1) - mat = rows > cols ? randn(rng, Float32, rows, cols) : randn(rng, Float32, cols, rows) + mat = rows > cols ? randn(rng, Float32, rows, cols) : randn(rng, Float32, cols, rows) - Q, R = LinearAlgebra.qr(mat) - Q = Array(Q) * sign.(LinearAlgebra.Diagonal(R)) - if rows < cols - Q = transpose(Q) - end + Q, R = LinearAlgebra.qr(mat) + Q = Array(Q) * sign.(LinearAlgebra.Diagonal(R)) + if rows < cols + Q = transpose(Q) + end - return gain * Q + return gain * Q end function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...) - dims = (d1, ds...) - rows = prod(dims[1:end-1]) - cols = dims[end] - return reshape(orthogonal(rng, rows, cols; kwargs...), dims) + dims = (d1, ds...) + rows = prod(dims[1:(end - 1)]) + cols = dims[end] + return reshape(orthogonal(rng, rows, cols; kwargs...), dims) end orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...) 
-orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...) +function orthogonal(rng::AbstractRNG; init_kwargs...) + return (dims::Integer...; kwargs...) -> + orthogonal(rng, dims...; init_kwargs..., kwargs...) +end """ sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01) @@ -268,19 +279,25 @@ julia> Flux.sparse_init(3, 2, sparsity=0.1) [1] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010. """ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01) - if length(dims) != 2 - throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) - end - rows, cols = dims - prop_zero = min(1.0, sparsity) - num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, Float32, dims...) .* Float32(std) - sparse_array[1:num_zeros, :] .= 0f0 - return mapslices(shuffle, sparse_array, dims=1) + if length(dims) != 2 + throw( + ArgumentError( + "Only 2-dimensional outputs are supported for sparse initialization." + ), + ) + end + rows, cols = dims + prop_zero = min(1.0, sparsity) + num_zeros = ceil(Integer, prop_zero * rows) + sparse_array = randn(rng, Float32, dims...) .* Float32(std) + sparse_array[1:num_zeros, :] .= 0.0f0 + return mapslices(shuffle, sparse_array; dims = 1) end sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...) -sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) +function sparse_init(rng::AbstractRNG; init_kwargs...) + return (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) +end """ identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0) @@ -346,25 +363,29 @@ julia> Flux.identity_init(3,3,2,2) ``` """ # Assume bias -identity_init(cols; gain=1, shift=0) = zeros32(cols) +identity_init(cols; gain = 1, shift = 0) = zeros32(cols) # Assume matrix multiplication -identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift) +function identity_init(rows, cols; gain = 1, shift = 0) + return circshift(Matrix{Float32}(I * gain, rows, cols), shift) +end # Assume convolution -function identity_init(dims...; gain=1, shift=0) - nin, nout = dims[end-1], dims[end] - centers = map(d -> cld(d, 2), dims[1:end-2]) - weights = zeros32(dims) - for i in 1:min(nin,nout) - weights[centers..., i, i] = gain - end - return circshift(weights, shift) +function identity_init(dims...; gain = 1, shift = 0) + nin, nout = dims[end - 1], dims[end] + centers = map(d -> cld(d, 2), dims[1:(end - 2)]) + weights = zeros32(dims) + for i in 1:min(nin, nout) + weights[centers..., i, i] = gain + end + return circshift(weights, shift) end identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...) identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...) -identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...) +function identity_init(rng::AbstractRNG; init_kwargs...) + return (args...; kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...) +end ones32(dims...) = Base.ones(Float32, dims...) zeros32(dims...) = Base.zeros(Float32, dims...) @@ -382,11 +403,12 @@ to the constructor's keyword `bias=bias`. 
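A usage sketch for the initialisers reformatted above (not part of this patch), calling them directly and via the RNG-capturing closure form:

```julia
using Flux, Random

Flux.glorot_uniform(3, 4)               # Uniform(±√(6 / (fan_in + fan_out)))
Flux.kaiming_normal(3, 4; gain = 1)     # Normal with std = gain / √fan_in
Flux.sparse_init(4, 3; sparsity = 0.5)  # a fraction of each column forced to zero
Flux.identity_init(3, 3; gain = 2)      # 2I, handy for residual-style layers

init = Flux.orthogonal(MersenneTwister(0))  # closure capturing an RNG
Q = init(4, 4)                              # Q' * Q ≈ I
```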
* `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted. """ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) - bias ? fill!(similar(weights, dims...), 0) : Zeros() + return bias ? fill!(similar(weights, dims...), 0) : Zeros() end function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) - size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) - bias + size(bias) == dims || + throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) + return bias end """ @@ -421,7 +443,13 @@ julia> Flux.unsqueeze(xs, 1) ``` """ function unsqueeze(xs::AbstractArray, dim::Integer) - sz = ntuple(i -> i < dim ? size(xs, i) : i == dim ? 1 : size(xs, i - 1), ndims(xs) + 1) + sz = ntuple(i -> if i < dim + size(xs, i) + elseif i == dim + 1 + else + size(xs, i - 1) + end, ndims(xs) + 1) return reshape(xs, sz) end @@ -443,7 +471,9 @@ julia> rand(Float32, 10, 10) |> m |> size """ unsqueeze(dim::Integer) = Base.Fix2(unsqueeze, dim) -Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) = print(io, "unsqueeze(", u.x, ")") +function Base.show_function(io::IO, u::Base.Fix2{typeof(unsqueeze)}, ::Bool) + return print(io, "unsqueeze(", u.x, ")") +end """ stack(xs, dim) @@ -472,7 +502,7 @@ julia> cat(xs, dims=1) [5, 6] ``` """ -stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim) +stack(xs, dim) = cat(unsqueeze.(xs, dim)...; dims = dim) """ unstack(xs, dim) @@ -511,7 +541,7 @@ julia> Flux.chunk(collect(1:10), 3) [9, 10] ``` """ -chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n))) +chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs) / n))) batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i) @@ -529,11 +559,11 @@ Dict{Char, Int64} with 2 entries: ``` """ function frequencies(xs) - fs = Dict{eltype(xs),Int}() - for x in xs - fs[x] = get(fs, x, 0) + 1 - end - return fs + fs = Dict{eltype(xs),Int}() + for x in xs + fs[x] = get(fs, x, 0) + 1 + end + return fs end head(x::Tuple) = reverse(Base.tail(reverse(x))) @@ -557,13 +587,15 @@ julia> Flux.batch([[1,2,3],[4,5,6]]) ``` """ function batch(xs) - data = first(xs) isa AbstractArray ? - similar(first(xs), size(first(xs))..., length(xs)) : - Vector{eltype(xs)}(undef, length(xs)) - for (i, x) in enumerate(xs) - data[batchindex(data, i)...] = x - end - return data + data = if first(xs) isa AbstractArray + similar(first(xs), size(first(xs))..., length(xs)) + else + Vector{eltype(xs)}(undef, length(xs)) + end + for (i, x) in enumerate(xs) + data[batchindex(data, i)...] 
= x + end + return data end """ @@ -625,32 +657,32 @@ julia> Flux.batchseq([[1, 2, 3], [4, 5]], 0) ``` """ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs)) - xs_ = [rpad(x, n, pad) for x in xs] - [batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n] + xs_ = [rpad(x, n, pad) for x in xs] + return [batch([xs_[j][i] for j in 1:length(xs_)]) for i in 1:n] end # Flattening models to weight vectors, and back function _restructure(m, xs) - i = 0 - m̄ = fmap(m) do x - x isa AbstractArray || return x - x = reshape(xs[i.+(1:length(x))], size(x)) - i += length(x) - return x - end - length(xs) == i || @warn "Expected $(i) params, got $(length(xs))" - return m̄ + i = 0 + m̄ = fmap(m) do x + x isa AbstractArray || return x + x = reshape(xs[i .+ (1:length(x))], size(x)) + i += length(x) + return x + end + length(xs) == i || @warn "Expected $(i) params, got $(length(xs))" + return m̄ end @adjoint function _restructure(m, xs) - m̄, numel = _restructure(m, xs), length(xs) - function _restructure_pullback(dm) - xs′ = destructure(dm)[1] - numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))" - return (nothing, xs′) - end - return m̄, _restructure_pullback + m̄, numel = _restructure(m, xs), length(xs) + function _restructure_pullback(dm) + xs′ = destructure(dm)[1] + numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))" + return (nothing, xs′) + end + return m̄, _restructure_pullback end """ @@ -675,12 +707,12 @@ modifications to the weight vector (for example, with a hypernetwork). Chain(Dense(10, 5, σ), Dense(5, 2), softmax) """ function destructure(m) - xs = Zygote.Buffer([]) - fmap(m) do x - x isa AbstractArray && push!(xs, x) - return x - end - return vcat(vec.(copy(xs))...), p -> _restructure(m, p) + xs = Zygote.Buffer([]) + fmap(m) do x + x isa AbstractArray && push!(xs, x) + return x + end + return vcat(vec.(copy(xs))...), p -> _restructure(m, p) end # Other @@ -696,39 +728,38 @@ going more than once per `wait` duration; but if you'd like to disable the execution on the leading edge, pass `leading=false`. To enable execution on the trailing edge, pass `trailing=true`. """ -function throttle(f, timeout; leading=true, trailing=false) - cooldown = true - later = nothing - result = nothing - - function throttled(args...; kwargs...) - yield() - - if cooldown - if leading - result = f(args...; kwargs...) - else - later = () -> f(args...; kwargs...) - end - - cooldown = false - @async try - while (sleep(timeout); later != nothing) - later() - later = nothing +function throttle(f, timeout; leading = true, trailing = false) + cooldown = true + later = nothing + result = nothing + + return function throttled(args...; kwargs...) + yield() + + if cooldown + if leading + result = f(args...; kwargs...) + else + later = () -> f(args...; kwargs...) + end + + cooldown = false + @async try + while (sleep(timeout); later != nothing) + later() + later = nothing + end + finally + cooldown = true + end + elseif trailing + later = () -> (result = f(args...; kwargs...)) end - finally - cooldown = true - end - elseif trailing - later = () -> (result = f(args...; kwargs...)) - end - return result - end + return result + end end - """ modules(m) @@ -799,13 +830,13 @@ julia> Flux.@epochs 10 begin ``` """ function patience(predicate, wait) - let count = 0 - function on_trigger(args...; kwargs...) - count = predicate(args...; kwargs...) ? count + 1 : 0 + let count = 0 + function on_trigger(args...; kwargs...) + count = predicate(args...; kwargs...) ? 
count + 1 : 0 - return count >= wait + return count >= wait + end end - end end """ @@ -836,17 +867,17 @@ julia> Flux.@epochs 10 begin ``` """ function early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0) - trigger = let best_score = init_score - (args...; kwargs...) -> begin - score = f(args...; kwargs...) - Δ = distance(best_score, score) - best_score = Δ < 0 ? best_score : score + trigger = let best_score = init_score + (args...; kwargs...) -> begin + score = f(args...; kwargs...) + Δ = distance(best_score, score) + best_score = Δ < 0 ? best_score : score - return Δ < min_dist + return Δ < min_dist + end end - end - return patience(trigger, delay) + return patience(trigger, delay) end """ @@ -878,15 +909,15 @@ julia> Flux.@epochs 10 begin ``` """ function plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6) - is_plateau = let last_score = init_score - (args...; kwargs...) -> begin - score = f(args...; kwargs...) - Δ = abs(distance(last_score, score)) - last_score = score + is_plateau = let last_score = init_score + (args...; kwargs...) -> begin + score = f(args...; kwargs...) + Δ = abs(distance(last_score, score)) + last_score = score - return Δ < min_dist + return Δ < min_dist + end end - end - return patience(is_plateau, width) + return patience(is_plateau, width) end diff --git a/src/zeros.jl b/src/zeros.jl index 1281f4c87a..9cc9419701 100644 --- a/src/zeros.jl +++ b/src/zeros.jl @@ -1,4 +1,4 @@ -import Base: +, -, *,/, reshape, broadcasted +import Base: +, -, *, /, reshape, broadcasted """ Zeros() @@ -44,9 +44,12 @@ broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = b broadcasted(::typeof(-), a::AbstractArray, b::Zeros) = a broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b # Need adjoints for these or else the gradient w.r.t to the non-Zeros arg will be nothing as well -@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing) -@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) -@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b)) +@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = + zero(a), _ -> (nothing, zero(a), nothing) +@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = + zero(b), _ -> (nothing, nothing, zero(b)) +@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = + zero(b), _ -> (nothing, nothing, zero(b)) # Pass-through for layer constructors create_bias(weights::AbstractArray, bias::Flux.Zeros, dims::Integer...) = bias
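The `Zeros` pass-through at the end is what makes `bias = false` work end to end: `create_bias` hands back a `Zeros()` placeholder, and the broadcast rules above turn the bias addition into a no-op. A minimal sketch, assuming the `Dense` field and keyword names of the Flux release this targets (~0.12):

```julia
using Flux

m = Dense(3, 2; bias = false)   # create_bias(weight, false, 2) returns Flux.Zeros()
m.bias isa Flux.Zeros           # true

x = rand(Float32, 3)
m(x) ≈ m.weight * x             # broadcasting `+` with Zeros is the identity, so no bias is added
```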
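`destructure` and `_restructure` form a round trip between a model and a flat parameter vector. A small sketch of that contract, reusing the `Chain` from the `destructure` docstring (the input below is arbitrary):

```julia
using Flux

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
θ, re = Flux.destructure(m)          # θ holds every array parameter, flattened
length(θ) == 10*5 + 5 + 5*2 + 2      # true: weights and biases, in traversal order

m0 = re(zero(θ))                     # same architecture, rebuilt from an edited vector
m0(rand(Float32, 10))                # 2-element output; softmax of zeros is uniform
```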
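`patience`, `early_stopping`, and `plateau` are plain closures, so they also work outside `@epochs`. A hedged sketch of the triggering behaviour, with a stand-in "loss" that only gets worse (mirroring the pseudo-loss in the docstring examples):

```julia
using Flux

trigger = Flux.patience(() -> true, 3)   # fires once the predicate holds 3 calls in a row
[trigger() for _ in 1:4]                 # false, false, true, true

loss = let l = 0
    () -> (l += 1)                       # stand-in: a score that never improves
end
es = Flux.early_stopping(loss, 3)        # stop after 3 non-improving evaluations
for epoch in 1:10
    es() && (println("stopping at epoch $epoch"); break)
end                                      # prints "stopping at epoch 3"
```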