From 125d6e2b9ba2afd022d05887aee9732ad7d23dec Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 2 Mar 2024 09:22:15 -0500 Subject: [PATCH 1/3] restore and test some support for Tracker.jl --- Project.toml | 4 +++- src/layers/stateless.jl | 3 ++- src/utils.jl | 5 ++++- test/runtests.jl | 1 + test/tracker.jl | 38 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 test/tracker.jl diff --git a/Project.toml b/Project.toml index 660ca26296..29bc4f54a5 100644 --- a/Project.toml +++ b/Project.toml @@ -52,6 +52,7 @@ ProgressLogging = "0.1" Reexport = "1.0" SpecialFunctions = "2.1.2" Statistics = "1" +Tracker = "0.2.32" Zygote = "0.6.67" cuDNN = "1" julia = "1.9" @@ -68,7 +69,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", "Tracker"] diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 884af4053b..2565ea2e84 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -81,7 +81,8 @@ end _match_eltype(layer, ::Type, x::AbstractArray) = x # 2-arg method, for common layers with layer.weight -_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x) +# NB using _eltype gets Float64 from Tracker.TrackedArray{Float64}, not TrackedReal +_match_eltype(layer, x) = _match_eltype(layer, _eltype(layer.weight), x) # Trivial rule: function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T} diff --git a/src/utils.jl b/src/utils.jl index 082d9dcb1c..1f8230c522 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -501,9 +501,12 @@ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) end function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) - convert(AbstractArray{eltype(weights)}, bias) + convert(AbstractArray{_eltype(weights)}, bias) end +# This avoids the issue that Tracker.TrackedArray{Float64} declares eltype() = TrackedReal +_eltype(::AbstractArray{T}) where T = T + # Other diff --git a/test/runtests.jl b/test/runtests.jl index 8dca6becdd..01a830f815 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ Random.seed!(0) @testset "Optimise / Train" begin include("optimise.jl") include("train.jl") + include("tracker.jl") end @testset "Data" begin diff --git a/test/tracker.jl b/test/tracker.jl new file mode 100644 index 0000000000..598e20ae2c --- /dev/null +++ b/test/tracker.jl @@ -0,0 +1,38 @@ +using Tracker: withgradient +using Zygote: gradient +using Functors: fmapstructure +using Flux + +@testset "Tracker.jl" begin + @testset "some simple models" begin + m1 = Dense(ones32(2,3), fill(0.1f0,2), abs2) + x1 = Float32[1,2,3] + (_, v1), g1 = withgradient(m1, x1) do m, x + y1 = m(x) + sum(abs2, y1 .- [4, 5]), y1 + end + @test v1 ≈ m1(x1) + g1z = gradient(m1, x1) do m, x + sum(abs2, m(x) .- [4, 5]) + end + @test g1[1].weight ≈ g1z[1].weight + @test g1[1].bias ≈ g1z[1].bias + + m2 = Chain(Conv((2,2), 3 => 1, relu), Flux.flatten, Dense(20 => 1, tanh), only) + x2 = randn32(5,6,3,1) + v2, g2 = withgradient(m -> m(x2), m2) + g2z = gradient(m -> m(x2), m2) + @test g2[1].layers[1].weight ≈ g2z[1].layers[1].weight + @test g2[1].layers[1].bias ≈ g2z[1].layers[1].bias + @test g2[1].layers[3].weight ≈ g2z[1].layers[3].weight + end + + @testset "Dropout" begin + g1z = gradient(sum∘Dropout(0.5), ones(1000)) + v1, g1 = withgradient(sum∘Dropout(0.5), ones(1000)) + @test 800 Date: Wed, 20 Mar 2024 23:00:24 -0400 Subject: [PATCH 2/3] bump Tracker compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 29bc4f54a5..2d7371b1b2 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ ProgressLogging = "0.1" Reexport = "1.0" SpecialFunctions = "2.1.2" Statistics = "1" -Tracker = "0.2.32" +Tracker = "0.2.33" Zygote = "0.6.67" cuDNN = "1" julia = "1.9" From 1bfa737080946ed5426079d69bfbd3762d8e2636 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 20 Mar 2024 23:34:52 -0400 Subject: [PATCH 3/3] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1334837665..7cc923c531 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.14" +version = "0.14.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"