diff --git a/Project.toml b/Project.toml index 5ec702c6b1..b6ebd50df1 100644 --- a/Project.toml +++ b/Project.toml @@ -52,6 +52,7 @@ ProgressLogging = "0.1" Reexport = "1.0" SpecialFunctions = "2.1.2" Statistics = "1" +Tracker = "0.2.32" Zygote = "0.6.67" cuDNN = "1" julia = "1.9" @@ -68,7 +69,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", "Tracker"] diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 884af4053b..2565ea2e84 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -81,7 +81,8 @@ end _match_eltype(layer, ::Type, x::AbstractArray) = x # 2-arg method, for common layers with layer.weight -_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x) +# NB using _eltype gets Float64 from Tracker.TrackedArray{Float64}, not TrackedReal +_match_eltype(layer, x) = _match_eltype(layer, _eltype(layer.weight), x) # Trivial rule: function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T} diff --git a/src/utils.jl b/src/utils.jl index 082d9dcb1c..1f8230c522 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -501,9 +501,12 @@ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) end function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) - convert(AbstractArray{eltype(weights)}, bias) + convert(AbstractArray{_eltype(weights)}, bias) end +# This avoids the issue that Tracker.TrackedArray{Float64} declares eltype() = TrackedReal +_eltype(::AbstractArray{T}) where T = T + # Other diff --git a/test/runtests.jl b/test/runtests.jl index 94e0c466e6..cf0d508a99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ Random.seed!(0) @testset "Optimise / Train" begin include("optimise.jl") include("train.jl") + include("tracker.jl") end @testset "Data" begin diff --git a/test/tracker.jl b/test/tracker.jl new file mode 100644 index 0000000000..598e20ae2c --- /dev/null +++ b/test/tracker.jl @@ -0,0 +1,38 @@ +using Tracker: withgradient +using Zygote: gradient +using Functors: fmapstructure +using Flux + +@testset "Tracker.jl" begin + @testset "some simple models" begin + m1 = Dense(ones32(2,3), fill(0.1f0,2), abs2) + x1 = Float32[1,2,3] + (_, v1), g1 = withgradient(m1, x1) do m, x + y1 = m(x) + sum(abs2, y1 .- [4, 5]), y1 + end + @test v1 ≈ m1(x1) + g1z = gradient(m1, x1) do m, x + sum(abs2, m(x) .- [4, 5]) + end + @test g1[1].weight ≈ g1z[1].weight + @test g1[1].bias ≈ g1z[1].bias + + m2 = Chain(Conv((2,2), 3 => 1, relu), Flux.flatten, Dense(20 => 1, tanh), only) + x2 = randn32(5,6,3,1) + v2, g2 = withgradient(m -> m(x2), m2) + g2z = gradient(m -> m(x2), m2) + @test g2[1].layers[1].weight ≈ g2z[1].layers[1].weight + @test g2[1].layers[1].bias ≈ g2z[1].layers[1].bias + @test g2[1].layers[3].weight ≈ g2z[1].layers[3].weight + end + + @testset "Dropout" begin + g1z = gradient(sum∘Dropout(0.5), ones(1000)) + v1, g1 = withgradient(sum∘Dropout(0.5), ones(1000)) + @test 800