cl/enzyme
CarloLucibello committed Mar 11, 2024
1 parent 875cd63 commit f3690cc
Showing 1 changed file with 68 additions and 9 deletions.
77 changes: 68 additions & 9 deletions test/ext_enzyme/enzyme.jl
@@ -12,6 +12,7 @@ make_differential(model) = fmap(make_zero, model)
## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329

function ngrad(f, x...)
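    # move all inputs to the CPU before computing the numerical reference gradient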
    x = [cpu(x) for x in x]
    ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x]
    ps = [f64(x[1]) for x in ps_and_res]
    res = [x[2] for x in ps_and_res]
@@ -39,7 +40,7 @@ function check_grad(g1, g2; broken=false)
        !isempty(kp) && :state ∈ kp && return # ignore RNN and LSTM state
        if x isa AbstractArray{<:Number}
            @show kp
            @test x ≈ y rtol=1e-3 atol=1e-3 broken=broken
            @test x ≈ y rtol=1e-2 broken=broken
        end
        return x
    end
@@ -50,9 +51,9 @@ function test_enzyme_grad(loss, model, x)
    l = loss(model, x)
    @test loss(model, x) == l # Check loss doesn't change with multiple runs
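    # compute gradients with finite differences (ngrad), Flux.gradient, and Enzyme's grad, moving each to the CPU before comparison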

    grads_fd = ngrad(loss, model, x)
    grads_flux = Flux.gradient(loss, model, x)
    grads_enzyme = grad(loss, model, x)
    grads_fd = ngrad(loss, model, x) |> cpu
    grads_flux = Flux.gradient(loss, model, x) |> cpu
    grads_enzyme = grad(loss, model, x) |> cpu

    # check_grad(grads_flux, grads_enzyme)
    check_grad(grads_fd, grads_enzyme)
@@ -113,6 +114,7 @@ end
        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), # uncomment when broken test below is fixed
        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
        (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
        (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
    ]

    for (model, x, name) in models_xs
@@ -153,18 +155,75 @@ end
        Flux.reset!(model)
        sum(model(x))
    end

    device = Flux.get_device()

    models_xs = [
        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
        # (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # this is just producing a Warning
        (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
    ]

    for (model, x, name) in models_xs
        @testset "check grad $name" begin
            println("testing $name")
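            # these models are expected to fail under Enzyme for now: an error thrown by the gradient test is recorded and asserted as broken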
            broken = false
            try
                test_enzyme_grad(loss, model, x)
            catch e
                print(e)
                broken = true
            end
            @test broken
        end
    end
end

@testset "Broken GPU Models" begin
    function loss(model, x)
        Flux.reset!(model)
        sum(model(x))
    end

    device = Flux.get_device()

    models_xs = [
        (Dense(2, 4), randn(Float32, 2,1), "Dense"),
        (Chain(Dense(2, 4), Dense(4, 2)), randn(Float32, 2,1), "Chain(Dense, Dense)"),
    ]

    for (model, x, name) in models_xs
        @testset "check grad $name" begin
            println("testing $name")
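            # move the model and input to the selected device; these GPU runs are currently expected to throw, which is asserted below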
            model = model |> device
            x = x |> device
            broken = false
            try
                test_enzyme_grad(loss, model, x)
            catch e
                print(e)
                broken = true
            end
            @test broken
        end
    end
end

# function loss(model, x)
# Flux.reset!(model)
# sum(model(x))
# end

# model = RNN(2 => 2)
# x = randn(Float32, 2, 3)
# Flux.reset!(model)
# grads_flux = Flux.gradient(loss, model, x)
# device = Flux.get_device()
# model = Chain(Dense(2 => 4, relu), Dense(4 => 2)) |> device
# x = randn(Float32, 2,1) |> device
# model(x)
# grads_flux = Flux.gradient(loss, model, x) |> cpu
# grads_fd = ngrad(loss, model, x)
# grads_enzyme = grad(loss, model, x) |> cpu
# check_grad(grads_fd, grads_enzyme)


# grads_enzyme = grad(loss, model, x)
# grads_fd = ngrad(loss, model, x)
# check_grad(grads_enzyme, grads_fd)
