From 6b45e21d644995381fc2aa477c906161b9d023d1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Oct 2024 04:49:17 +0200 Subject: [PATCH] remove v0.13 deprecations remove v0.13 deprecations reinsert optimisers add 1.9 CI drop julia v1.9 --- .github/workflows/ci.yml | 2 +- Project.toml | 2 +- ext/FluxEnzymeExt/FluxEnzymeExt.jl | 4 -- src/deprecations.jl | 95 ------------------------------ src/layers/normalise.jl | 22 +++---- src/layers/stateless.jl | 5 +- src/losses/Losses.jl | 2 +- src/losses/functions.jl | 52 +++++++--------- test/layers/normalisation.jl | 2 +- test/layers/show.jl | 2 +- test/losses.jl | 12 ++-- test/outputsize.jl | 1 - test/runtests.jl | 16 ++--- test/utils.jl | 10 ++-- 14 files changed, 51 insertions(+), 176 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aa4c13d212..51d5407c25 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: version: - # - '1.9' # Uncomment when 1.10 is out. Replace this with the minimum Julia version that your package supports. + - '1.10' # Replace this with the minimum Julia version that your package supports. - '1' os: [ubuntu-latest] arch: [x64] diff --git a/Project.toml b/Project.toml index d805332f20..a3219459cf 100644 --- a/Project.toml +++ b/Project.toml @@ -67,4 +67,4 @@ SpecialFunctions = "2.1.2" Statistics = "1" Zygote = "0.6.67" cuDNN = "1" -julia = "1.9" +julia = "1.10" diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl index e6ce51297f..5ac7c2e577 100644 --- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl +++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl @@ -16,10 +16,6 @@ _applyloss(loss, model, d...) = loss(model, d...) EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true -using Flux: _old_to_new # from src/deprecations.jl -train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) = - train!(loss, model, data, _old_to_new(opt); cb) - function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end diff --git a/src/deprecations.jl b/src/deprecations.jl index 8dadadfd6d..6148894dbe 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,47 +1,6 @@ # v0.13 deprecations -function Broadcast.broadcasted(f::Recur, args...) - # This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12 - Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order. - Re-writing this as a comprehension would be better.""", :broadcasted) - map(f, args...) # map isn't really safe either, but -end - -@deprecate frequencies(xs) group_counts(xs) - -struct Zeros - function Zeros() - Base.depwarn("Flux.Zeros is no more, has ceased to be, is bereft of life, is an ex-boondoggle... please use bias=false instead", :Zeros) - false - end -end -Zeros(args...) = Zeros() # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros()) - -function Optimise.update!(x::AbstractArray, x̄) - Base.depwarn("`Flux.Optimise.update!(x, x̄)` was not used internally and has been removed. Please write `x .-= x̄` instead.", :update!) - x .-= x̄ -end - -function Diagonal(size::Integer...; kw...) - Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal) - Scale(size...; kw...) -end -function Diagonal(size::Tuple; kw...) 
- Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal) - Scale(size...; kw...) -end - -# Deprecate this eventually once saving models w/o structure is no more -function loadparams!(m, xs) - Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end -end - # Channel notation: Changed to match Conv, but very softly deprecated! # Perhaps change to @deprecate for v0.15, but there is no plan to remove these. Dense(in::Integer, out::Integer, σ = identity; kw...) = @@ -56,32 +15,6 @@ LSTMCell(in::Integer, out::Integer; kw...) = LSTMCell(in => out; kw...) GRUCell(in::Integer, out::Integer; kw...) = GRUCell(in => out; kw...) GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...) -# Optimisers with old naming convention -Base.@deprecate_binding ADAM Adam -Base.@deprecate_binding NADAM NAdam -Base.@deprecate_binding ADAMW AdamW -Base.@deprecate_binding RADAM RAdam -Base.@deprecate_binding OADAM OAdam -Base.@deprecate_binding ADAGrad AdaGrad -Base.@deprecate_binding ADADelta AdaDelta - -# Remove sub-module Data, while making sure Flux.Data.DataLoader keeps working -Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. The only thing it contained may be accessed as Flux.DataLoader" - -@deprecate paramtype(T,m) _paramtype(T,m) false # internal method, renamed to make this clear - -@deprecate rng_from_array() Random.default_rng() - -function istraining() - Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining) - false -end -ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),) - -function _isactive(m) - Base.depwarn("_isactive(m) is deprecated, use _isactive(m,x)", :_isactive, force=true) - _isactive(m, 1:0) -end #= # Valid method in Optimise, old implicit style, is: @@ -110,7 +43,6 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb) - # Next, to use the new `setup` with the still-exported old-style `Adam` etc: import .Train: setup setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) @@ -179,33 +111,6 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, """) end -""" - trainmode!(m, active) - -!!! warning - This two-argument method is deprecated. - -Possible values of `active` are: -- `true` for training, or -- `false` for testing, same as [`testmode!`](@ref)`(m)` -- `:auto` or `nothing` for Flux to detect training automatically. -""" -function trainmode!(m, active::Bool) - Base.depwarn("trainmode!(m, active::Bool) is deprecated", :trainmode) - testmode!(m, !active) -end - -# Greek-letter keywords deprecated in Flux 0.13 -# Arguments (old => new, :function, "β" => "beta") -function _greek_ascii_depwarn(βbeta::Pair, func = :loss, names = "" => "") - Base.depwarn(LazyString("function ", func, " no longer accepts greek-letter keyword ", names.first, """ - please use ascii """, names.second, " instead"), func) - βbeta.first -end -_greek_ascii_depwarn(βbeta::Pair{Nothing}, _...) = βbeta.second - -ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...) 
- # v0.14 deprecations @deprecate default_rng_value() Random.default_rng() diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c6663cca88..99092f9756 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -191,10 +191,9 @@ struct LayerNorm{F,D,T,N} affine::Bool end -function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5, ϵ=nothing) - ε = _greek_ascii_depwarn(ϵ => eps, :LayerNorm, "ϵ" => "eps") +function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5) diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity - return LayerNorm(λ, diag, ε, size, affine) + return LayerNorm(λ, diag, eps, size, affine) end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) @@ -328,9 +327,8 @@ end function BatchNorm(chs::Int, λ=identity; initβ=zeros32, initγ=ones32, affine::Bool=true, track_stats::Bool=true, active::Union{Bool,Nothing}=nothing, - eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) + eps::Real=1f-5, momentum::Real=0.1f0) - ε = _greek_ascii_depwarn(ϵ => eps, :BatchNorm, "ϵ" => "eps") β = affine ? initβ(chs) : nothing γ = affine ? initγ(chs) : nothing @@ -338,7 +336,7 @@ function BatchNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return BatchNorm(λ, β, γ, - μ, σ², ε, momentum, + μ, σ², eps, momentum, affine, track_stats, active, chs) end @@ -421,9 +419,7 @@ end function InstanceNorm(chs::Int, λ=identity; initβ=zeros32, initγ=ones32, affine::Bool=false, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing, - eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) - - ε = _greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps") + eps::Real=1f-5, momentum::Real=0.1f0) β = affine ? initβ(chs) : nothing γ = affine ? initγ(chs) : nothing @@ -431,7 +427,7 @@ function InstanceNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return InstanceNorm(λ, β, γ, - μ, σ², ε, momentum, + μ, σ², eps, momentum, affine, track_stats, active, chs) end @@ -520,9 +516,7 @@ end function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, affine::Bool=true, active::Union{Bool,Nothing}=nothing, - eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) - - ε = _greek_ascii_depwarn(ϵ => eps, :GroupNorm, "ϵ" => "eps") + eps::Real=1f-5, momentum::Real=0.1f0) chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)") @@ -535,7 +529,7 @@ function GroupNorm(chs::Int, G::Int, λ=identity; return GroupNorm(G, λ, β, γ, μ, σ², - ε, momentum, + eps, momentum, affine, track_stats, active, chs) end diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 2565ea2e84..4fb739e0c3 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -34,11 +34,10 @@ julia> isapprox(std(y; dims=1, corrected=false), ones(1, 10), atol=1e-5) true ``` """ -@inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5), ϵ=nothing) - ε = _greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps") +@inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5)) μ = mean(x, dims=dims) σ = std(x, dims=dims, mean=μ, corrected=false) - return @. (x - μ) / (σ + ε) + return @. 
(x - μ) / (σ + eps) end """ diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 34315baadd..5b4a1d697b 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -4,7 +4,7 @@ using Statistics using Zygote using Zygote: @adjoint using ChainRulesCore -using ..Flux: ofeltype, epseltype, _greek_ascii_depwarn +using ..Flux: ofeltype, epseltype using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss import Base.Broadcast: broadcasted diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 7897cc5754..5f9778e93c 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -66,10 +66,9 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3) 0.011100831f0 ``` """ -function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :msle, "ϵ" => "eps") +function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 ) + agg((log.((ŷ .+ eps) ./ (y .+ eps))) .^2 ) end function _huber_metric(abs_error, δ) @@ -101,9 +100,8 @@ julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| > 0.003750000000000005 ``` """ -function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing) - delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta") - δ = ofeltype(ŷ, delta_tmp) +function huber_loss(ŷ, y; agg = mean, delta::Real = 1) + δ = ofeltype(ŷ, delta) _check_sizes(ŷ, y) abs_error = abs.(ŷ .- y) @@ -230,10 +228,9 @@ julia> Flux.crossentropy(y_model, y_smooth) 1.5776052f0 ``` """ -function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :crossentropy, "ϵ" => "eps") +function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims)) + agg(.-sum(xlogy.(y, ŷ .+ eps); dims = dims)) end """ @@ -319,10 +316,9 @@ julia> Flux.crossentropy(y_prob, y_hot) 0.43989f0 ``` """ -function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :binarycrossentropy, "ϵ" => "eps") +function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))) + agg(@.(-xlogy(y, ŷ + eps) - xlogy(1 - y, 1 - ŷ + eps))) end """ @@ -390,11 +386,10 @@ julia> Flux.kldivergence(p1, p2; eps = 0) # about 17.3 with the regulator Inf ``` """ -function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :kldivergence, "ϵ" => "eps") +function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) entropy = agg(sum(xlogx.(y); dims = dims)) - cross_entropy = crossentropy(ŷ, y; dims, agg, eps=ϵ) + cross_entropy = crossentropy(ŷ, y; dims, agg, eps) return entropy + cross_entropy end @@ -531,13 +526,12 @@ Calculated as: """ function tversky_loss(ŷ, y; beta::Real = 0.7, β = nothing) - beta_temp = _greek_ascii_depwarn(β => beta, :tversky_loss, "β" => "beta") - β = ofeltype(ŷ, beta_temp) - _check_sizes(ŷ, y) - #TODO add agg - num = sum(y .* ŷ) + 1 - den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1 - 1 - num / den + β = ofeltype(ŷ, beta) + _check_sizes(ŷ, y) + #TODO add agg + num = sum(y .* ŷ) + 1 + den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1 + 1 - num / den end """ @@ -568,12 +562,10 @@ julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 true ``` """ -function binary_focal_loss(ŷ, y; 
agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ = nothing, γ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :binary_focal_loss, "ϵ" => "eps") - gamma_temp = _greek_ascii_depwarn(γ => gamma, :binary_focal_loss, "γ" => "gamma") - γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp) +function binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps::Real=epseltype(ŷ)) + γ = gamma isa Integer ? gamma : ofeltype(ŷ, gamma) _check_sizes(ŷ, y) - ŷϵ = ŷ .+ ϵ + ŷϵ = ŷ .+ eps p_t = y .* ŷϵ + (1 .- y) .* (1 .- ŷϵ) ce = .-log.(p_t) weight = (1 .- p_t) .^ γ @@ -616,11 +608,9 @@ See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :focal_loss, "ϵ" => "eps") - gamma_temp = _greek_ascii_depwarn(γ => gamma, :focal_loss, "γ" => "gamma") - γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp) + γ = gamma isa Integer ? gamma : ofeltype(ŷ, gamma) _check_sizes(ŷ, y) - ŷϵ = ŷ .+ ϵ + ŷϵ = ŷ .+ eps agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims)) end diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 6c1b78919f..be7c5dec92 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -446,7 +446,7 @@ end @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3) m2 = Chain(BatchNorm(3), sum) - @test_broken Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) + @test Zygote.hessian_reverse(m2, Float32[1 2; 3 4; 5 6]) == zeros(Float32, 6, 6) broken = VERSION >= v"1.11" end @testset "ForwardDiff" begin diff --git a/test/layers/show.jl b/test/layers/show.jl index 6910e5fa08..95ddca0571 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -76,7 +76,7 @@ end # Bug when no children, https://github.com/FluxML/Flux.jl/issues/2208 struct NoFields end -Flux.@functor NoFields +Flux.@layer NoFields @testset "show with no fields" begin str = repr("text/plain", Chain(Dense(1=>1), Dense(1=>1), NoFields())) diff --git a/test/losses.jl b/test/losses.jl index a5ce1139df..1285f30148 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -76,7 +76,7 @@ y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:] @test crossentropy(ŷ, y_smoothed) ≈ lossvalue_smoothed @test crossentropy(ylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*log.(ylp)) @test crossentropy(ylp, yl) ≈ -sum(yl.*log.(ylp)) - @test iszero(crossentropy(y_same, ya, ϵ=0)) # ε is deprecated + @test iszero(crossentropy(y_same, ya, eps=0)) @test iszero(crossentropy(ya, ya, eps=0)) @test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed) @test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed) @@ -92,15 +92,15 @@ logŷ, y = randn(3), rand(3) yls = y.*(1-2sf).+sf @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); eps=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ)) @test binarycrossentropy(σ.(logŷ), y; eps=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))) @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))) @test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9]) # constant label end @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), 
label_smoothing(y, 0.2); ϵ=0) - @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0) + @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); eps=0) + @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; eps=0) end y = onehotbatch([1], 0:1) @@ -152,7 +152,7 @@ end @testset "tversky_loss" begin @test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383 - @test Flux.tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744 + @test Flux.tversky_loss(ŷ, y, beta=0.8) ≈ -0.09490740740740744 @test Flux.tversky_loss(y, y) ≈ -0.5576923076923075 end @@ -180,7 +180,7 @@ end 0.4 0.7] @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222 - @test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y) + @test Flux.binary_focal_loss(ŷ, y; gamma=0.0) ≈ Flux.binarycrossentropy(ŷ, y) end @testset "focal_loss" begin diff --git a/test/outputsize.jl b/test/outputsize.jl index 0eab572eb7..55cb823c5c 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -10,7 +10,6 @@ @test outputsize(m, (10,); padbatch=true) == (2, 1) @test outputsize(m, (10, 30)) == (2, 30) - @info "Don't mind the following error, it's for testing purpose." m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) @test_throws DimensionMismatch outputsize(m, (10,)) diff --git a/test/runtests.jl b/test/runtests.jl index c48b281c92..ef3d67f4d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,14 +65,6 @@ Random.seed!(0) @testset "functors" begin include("functors.jl") end - - @static if VERSION == v"1.9" - using Documenter - @testset "Docs" begin - DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - doctest(Flux) - end - end else @info "Skipping CPU tests." end @@ -143,10 +135,10 @@ Random.seed!(0) end if get(ENV, "FLUX_TEST_ENZYME", "true") == "true" - @testset "Enzyme" begin - import Enzyme - include("ext_enzyme/enzyme.jl") - end + @testset "Enzyme" begin + import Enzyme + include("ext_enzyme/enzyme.jl") + end else @info "Skipping Enzyme tests, set FLUX_TEST_ENZYME=true to run them." end diff --git a/test/utils.jl b/test/utils.jl index e05d5f4562..b526b63286 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -273,14 +273,14 @@ end @testset "params gradient" begin m = (x=[1,2.0], y=[3.0]); - @test_broken begin - # Explicit -- was broken by #2054 / then fixed / now broken again on julia v0.11 + @test begin + # Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11 gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] @test gnew.y ≈ [1.0] true - end - + end broken = VERSION >= v"1.11" + # Implicit gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m)) @test gold[m.x] ≈ [0.4472135954999579, 0.8944271909999159] @@ -689,6 +689,6 @@ end # make sure rng_from_array is non_differentiable @testset "rng_from_array" begin - m(x) = (rand(rng_from_array(x)) * x)[1] + m(x) = (rand(Flux.rng_from_array(x)) * x)[1] gradient(m, ones(2)) end
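Migration note: a minimal sketch of what this removal implies for downstream code, assuming the ascii keyword names (`eps`, `gamma`, `beta`, `delta`) and the renamed optimisers (`Adam`, `AdamW`, ...) that the patch leaves in place; the concrete values below are only illustrative.

```julia
using Flux

# Greek-letter keyword arguments no longer exist on the losses and norm layers;
# pass the ascii names that remain in the signatures shown in this patch.
ŷ, y = Float32[0.1, 0.8, 0.1], Float32[0, 1, 0]
Flux.crossentropy(ŷ, y; eps = 1f-7)   # was ϵ = 1f-7
Flux.huber_loss(ŷ, y; delta = 0.5)    # was δ = 0.5
BatchNorm(3; eps = 1f-5)              # was ϵ = 1f-5

# The ADAM/NADAM/ADAMW/... aliases are gone; use the current names with the explicit API.
model = Dense(3 => 2)
opt_state = Flux.setup(Adam(1f-3), model)

# The Flux.Data sub-module is gone; DataLoader is reached from the top level.
loader = Flux.DataLoader((rand(Float32, 3, 10), rand(Float32, 2, 10)); batchsize = 5)
```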