Skip to content

Commit

Permalink
restore and test some support for Tracker.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Mar 21, 2024
1 parent eb6492c commit 125d6e2
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 3 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ ProgressLogging = "0.1"
Reexport = "1.0"
SpecialFunctions = "2.1.2"
Statistics = "1"
Tracker = "0.2.32"
Zygote = "0.6.67"
cuDNN = "1"
julia = "1.9"
Expand All @@ -68,7 +69,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", "Tracker"]
3 changes: 2 additions & 1 deletion src/layers/stateless.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ end
_match_eltype(layer, ::Type, x::AbstractArray) = x

# 2-arg method, for common layers with layer.weight
_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x)
# NB using _eltype gets Float64 from Tracker.TrackedArray{Float64}, not TrackedReal
_match_eltype(layer, x) = _match_eltype(layer, _eltype(layer.weight), x)

# Trivial rule:
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}
Expand Down
5 changes: 4 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,12 @@ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
end
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
convert(AbstractArray{eltype(weights)}, bias)
convert(AbstractArray{_eltype(weights)}, bias)
end

# This avoids the issue that Tracker.TrackedArray{Float64} declares eltype() = TrackedReal
_eltype(::AbstractArray{T}) where T = T


# Other

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Random.seed!(0)
@testset "Optimise / Train" begin
include("optimise.jl")
include("train.jl")
include("tracker.jl")
end

@testset "Data" begin
Expand Down
38 changes: 38 additions & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using Tracker: withgradient
using Zygote: gradient
using Functors: fmapstructure
using Flux

@testset "Tracker.jl" begin
@testset "some simple models" begin
m1 = Dense(ones32(2,3), fill(0.1f0,2), abs2)
x1 = Float32[1,2,3]
(_, v1), g1 = withgradient(m1, x1) do m, x
y1 = m(x)
sum(abs2, y1 .- [4, 5]), y1
end
@test v1 m1(x1)
g1z = gradient(m1, x1) do m, x
sum(abs2, m(x) .- [4, 5])
end
@test g1[1].weight g1z[1].weight
@test g1[1].bias g1z[1].bias

m2 = Chain(Conv((2,2), 3 => 1, relu), Flux.flatten, Dense(20 => 1, tanh), only)
x2 = randn32(5,6,3,1)
v2, g2 = withgradient(m -> m(x2), m2)
g2z = gradient(m -> m(x2), m2)
@test g2[1].layers[1].weight g2z[1].layers[1].weight
@test g2[1].layers[1].bias g2z[1].layers[1].bias
@test g2[1].layers[3].weight g2z[1].layers[3].weight
end

@testset "Dropout" begin
g1z = gradient(sumDropout(0.5), ones(1000))
v1, g1 = withgradient(sumDropout(0.5), ones(1000))
@test 800<v1<1200
@test sum(g1[1]) v1
@test 400 < count(iszero, g1[1]) < 600
end
end

0 comments on commit 125d6e2

Please sign in to comment.