Add spaces around = for kwargs #1787

Closed · wants to merge 4 commits
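In short, this PR turns on JuliaFormatter's `whitespace_in_kwargs = true`, which puts spaces around `=` in keyword arguments. A minimal illustration of the effect (hypothetical call, not from the diff):

```julia
# Before formatting: no spaces around `=` in keyword arguments.
DataLoader(data; batchsize=1, shuffle=false)

# After formatting with whitespace_in_kwargs = true:
DataLoader(data; batchsize = 1, shuffle = false)
```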
5 changes: 5 additions & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1,5 @@
indent = 4
whitespace_in_kwargs = true
normalize_line_endings = "unix"
style = "blue"
import_to_using = false
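A sketch of how this configuration gets applied, assuming JuliaFormatter.jl is installed; `format` reads `.JuliaFormatter.toml` from the project root automatically:

```julia
using JuliaFormatter

# Reformat the whole repository in place; options such as
# whitespace_in_kwargs = true are picked up from .JuliaFormatter.toml.
format(".")
```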
69 changes: 56 additions & 13 deletions src/Flux.jl
@@ -10,24 +10,67 @@ using MacroTools: @forward
using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
Upsample, PixelShuffle,
params, fmap, cpu, gpu, f32, f64,
testmode!, trainmode!
export Chain,
Member: We don't need to change this

Member: I believe this is because the original line(s) were too long. Grouping into multiple export statements per https://github.com/invenia/BlueStyle#function-exports would get us back to, if not under, the original line count. (See the sketch after this export list.)

Dense,
Maxout,
SkipConnection,
Parallel,
flatten,
RNN,
LSTM,
GRU,
GRUv3,
SamePad,
Conv,
CrossCor,
ConvTranspose,
DepthwiseConv,
AdaptiveMaxPool,
AdaptiveMeanPool,
GlobalMaxPool,
GlobalMeanPool,
MaxPool,
MeanPool,
Dropout,
AlphaDropout,
LayerNorm,
BatchNorm,
InstanceNorm,
GroupNorm,
Upsample,
PixelShuffle,
params,
fmap,
cpu,
gpu,
f32,
f64,
testmode!,
trainmode!
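A sketch of the BlueStyle alternative suggested in the review comment above (illustrative regrouping, not part of this diff):

```julia
# One export statement per group of related names keeps lines short
# without putting every exported name on its own line.
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten
export RNN, LSTM, GRU, GRUv3
export SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv
export AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool
export Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm
export Upsample, PixelShuffle
export params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
```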

include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
using .Optimise: skip
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM,
ADAMW, RADAM, AdaBelief, InvDecay, ExpDecay,
WeightDecay, ClipValue, ClipNorm

export Descent,
ADAM,
Momentum,
Nesterov,
RMSProp,
ADAGrad,
AdaMax,
ADADelta,
AMSGrad,
NADAM,
OADAM,
ADAMW,
RADAM,
AdaBelief,
InvDecay,
ExpDecay,
WeightDecay,
ClipValue,
ClipNorm

using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
3 changes: 2 additions & 1 deletion src/cuda/cuda.jl
@@ -5,7 +5,8 @@ using ..CUDA
import ..Flux: Flux
import Zygote
using Zygote: @adjoint
import NNlib, NNlibCUDA
using NNlib: NNlib
using NNlibCUDA: NNlibCUDA

include("cudnn.jl")

42 changes: 28 additions & 14 deletions src/cuda/cudnn.jl
@@ -1,20 +1,34 @@
import NNlibCUDA: batchnorm, ∇batchnorm

function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
cache=nothing) where T<:Union{Float32, Float64}

@assert BN.affine "BatchNorm: only affine=true supported on gpu"
@assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
cache=cache, alpha=1, beta=0, eps=BN.ϵ,
training=Flux._isactive(BN)))
function (BN::Flux.BatchNorm)(
x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing
) where {T<:Union{Float32,Float64}}
@assert BN.affine "BatchNorm: only affine=true supported on gpu"
@assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
@assert length(BN.β) == size(x, ndims(x) - 1) "BatchNorm: input has wrong number of channels"
return BN.λ.(
batchnorm(
BN.γ,
BN.β,
x,
BN.μ,
BN.σ²,
BN.momentum;
cache = cache,
alpha = 1,
beta = 0,
eps = BN.ϵ,
training = Flux._isactive(BN),
),
)
end

@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
function batchnorm_pullback(Δ)
∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
end
y, batchnorm_pullback
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
function batchnorm_pullback(Δ)
return ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)...,
nothing, nothing,
nothing
end
return y, batchnorm_pullback
end
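For context, the `Zygote.@adjoint` pattern used in this hunk returns the forward value paired with a pullback closure. A minimal standalone sketch with a hypothetical function (not from the diff):

```julia
using Zygote: @adjoint, gradient

double(x) = 2x

# The adjoint returns (forward value, pullback); the pullback maps the
# output cotangent Δ to a tuple of input cotangents, here (2Δ,).
@adjoint double(x) = double(x), Δ -> (2Δ,)

gradient(double, 3.0)  # (2.0,)
```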
27 changes: 15 additions & 12 deletions src/data/Data.jl
@@ -10,17 +10,19 @@ export DataLoader
## Also remove the following deps:
## AbstractTrees, ZipFiles, CodecZLib

import ..Flux
import SHA
using ..Flux: Flux
using SHA: SHA

deprecation_message() = @warn("Flux's datasets are deprecated, please use the package MLDatasets.jl")
function deprecation_message()
@warn("Flux's datasets are deprecated, please use the package MLDatasets.jl")
end

function deps(path...)
if isnothing(@__DIR__) # sysimages
joinpath("deps", path...)
else
joinpath(@__DIR__, "..", "..", "deps", path...)
end
if isnothing(@__DIR__) # sysimages
joinpath("deps", path...)
else
joinpath(@__DIR__, "..", "..", "deps", path...)
end
end

function download_and_verify(url, path, hash)
@@ -30,16 +32,16 @@ function download_and_verify(url, path, hash)
bytes2hex(SHA.sha256(f))
end
if hash_download !== hash
msg = "Hash Mismatch!\n"
msg = "Hash Mismatch!\n"
msg *= " Expected sha256: $hash\n"
msg *= " Calculated sha256: $hash_download"
error(msg)
end
mv(tmppath, path; force=true)
return mv(tmppath, path; force = true)
end

function __init__()
mkpath(deps())
return mkpath(deps())
end

include("mnist.jl")
@@ -50,7 +52,8 @@ export FashionMNIST

include("cmudict.jl")
export CMUDict
using .CMUDict; export cmudict
using .CMUDict
export cmudict

include("tree.jl")
include("sentiment.jl")
70 changes: 43 additions & 27 deletions src/data/cmudict.jl
@@ -8,31 +8,44 @@ const version = "0.7b"
const cache_prefix = "https://cache.julialang.org"

function load()
suffixes_and_hashes = [("" , "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"),
(".phones" , "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"),
(".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027")]
if isdir(deps("cmudict"))
if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes)
return
suffixes_and_hashes = [
("", "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4"),
(".phones", "ffb588a5e55684723582c7256e1d2f9fadb130011392d9e59237c76e34c2cfd6"),
(".symbols", "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027"),
]
if isdir(deps("cmudict"))
if all(isfile(deps("cmudict", "cmudict$x")) for (x, _) in suffixes_and_hashes)
return nothing
end
end
@info "Downloading CMUDict dataset"
mkpath(deps("cmudict"))
for (x, hash) in suffixes_and_hashes
download_and_verify(
"$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x"),
hash,
)
end
end
@info "Downloading CMUDict dataset"
mkpath(deps("cmudict"))
for (x, hash) in suffixes_and_hashes
download_and_verify("$cache_prefix/https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x"), hash)
end
end

"""
phones()
Return a `Vector` containing the phones used in the CMU Pronouncing Dictionary.
"""
function phones()
deprecation_message()
load()
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
"\n", keepempty = false), "\t")))
deprecation_message()
load()
return Symbol.(
first.(
split.(
split(
read(deps("cmudict", "cmudict.phones"), String), "\n"; keepempty = false
),
"\t",
),
),
)
end

"""
@@ -42,21 +55,24 @@ A symbol is a phone with optional auxiliary symbols, indicating for example the
amount of stress on the phone.
"""
function symbols()
deprecation_message()
load()
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
"\n", keepempty = false))
deprecation_message()
load()
return Symbol.(
split(read(deps("cmudict", "cmudict.symbols"), String), "\n"; keepempty = false)
)
end

"""
rawdict()
Return the unfiltered CMU Pronouncing Dictionary.
"""
function rawdict()
deprecation_message()
load()
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
deprecation_message()
load()
return Dict(
String(xs[1]) => Symbol.(xs[2:end]) for xs in
filter(!isempty, split.(split(read(deps("cmudict", "cmudict"), String), "\n")))
)
end

validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
@@ -68,8 +84,8 @@ It is filtered so each word contains only ASCII characters and a combination of
word characters (as determined by the regex engine using `\\w`), '-' and '.'.
"""
function cmudict()
deprecation_message()
filter(p -> validword(p.first), rawdict())
deprecation_message()
return filter(p -> validword(p.first), rawdict())
end

alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
24 changes: 14 additions & 10 deletions src/data/dataloader.jl
@@ -74,7 +74,7 @@ julia> foreach(println∘summary, Flux.DataLoader(rand(Int8, 10, 64), batchsize=
10×4 Matrix{Int8}
```
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
function DataLoader(data; batchsize = 1, shuffle = false, partial = true, rng = GLOBAL_RNG)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))

n = _nobs(data)
@@ -83,39 +83,43 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_R
batchsize = n
end
imax = partial ? n : n - batchsize + 1
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng)
return DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng)
end

@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
@propagate_inbounds function Base.iterate(d::DataLoader, i = 0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing
if d.shuffle && i == 0
shuffle!(d.rng, d.indices)
end
nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti]
ids = d.indices[(i + 1):nexti]
batch = _getobs(d.data, ids)
return (batch, nexti)
end

function Base.length(d::DataLoader)
n = d.nobs / d.batchsize
d.partial ? ceil(Int,n) : floor(Int,n)
return d.partial ? ceil(Int, n) : floor(Int, n)
end

_nobs(data::AbstractArray) = size(data)[end]

function _nobs(data::Union{Tuple, NamedTuple})
function _nobs(data::Union{Tuple,NamedTuple})
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
for i in keys(data)
ni = _nobs(data[i])
n == ni || throw(DimensionMismatch("All data inputs should have the same number of observations, i.e. size in the last dimension. " *
"But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni."))
n == ni || throw(
DimensionMismatch(
"All data inputs should have the same number of observations, i.e. size in the last dimension. " *
"But data[$(repr(first(keys(data))))] ($(summary(data[1]))) has $n, while data[$(repr(i))] ($(summary(data[i]))) has $ni.",
),
)
end
return n
end

_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
_getobs(data::Union{Tuple,NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)

Base.eltype(::DataLoader{D}) where D = D
Base.eltype(::DataLoader{D}) where {D} = D