From a30b6d93eaf94c8289c5d614086961a0c80a6889 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Mar 2024 18:46:44 +0100 Subject: [PATCH 1/9] start testing --- pytorch-test.ipynb | 61 +++++++++++++++++++++++++++++++++++++ test/ext_enzyme/enzyme.jl | 64 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 pytorch-test.ipynb create mode 100644 test/ext_enzyme/enzyme.jl diff --git a/pytorch-test.ipynb b/pytorch-test.ipynb new file mode 100644 index 0000000000..3205869182 --- /dev/null +++ b/pytorch-test.ipynb @@ -0,0 +1,61 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch as th\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class A(nn.Module):\n", + " def __init__():\n", + " super().__init__()\n", + " return\n", + " \n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(2, 3)\n", + " self.fc2 = nn.Linear(3, 1)\n", + " self.str = \"ciao\"\n", + " self.a = A()\n", + "\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deepl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl new file mode 100644 index 0000000000..d7b2b5b692 --- /dev/null +++ b/test/ext_enzyme/enzyme.jl @@ -0,0 +1,64 @@ +using Test +using Flux +using Enzyme +using EnzymeTestUtils +using Functors +# using EnzymeCore + +make_zero(x::AbstractArray) = zero(x) +make_zero(x::Number) = zero(x) +make_zero(x) = x + +make_differential(model) = fmap(make_zero, model) + +function grad(f, x...) + args = [] + for x in x + if x isa Number + push!(args, Active(x)) + else + push!(args, Duplicated(x, make_differential(x))) + end + end + @show x args + ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) + g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x)) + return g +end + +@testset "grad" begin + @testset "number and arrays" begin + f(x, y) = sum(x.^2) + y^3 + x = Float32[1, 2, 3] + y = 3f0 + g = grad(f, x, y) + @test g[1] isa Array{Float32} + @test g[2] isa Float32 + @test g[1] ≈ 2x + @test g[2] ≈ 3*y^2 + end + + @testset "struct" begin + struct SimpleDense{W, B, F} + weight::W + bias::B + σ::F + end + SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ) + (m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias) + @functor SimpleDense + + model = SimpleDense(2, 4) + x = randn(Float32, 2) + loss(model, x) = sum(model(x)) + + g = grad(loss, model, x) + @test g[1] isa SimpleDense + @test g[2] isa Array{Float32} + @test g[1].weight isa Array{Float32} + @test g[1].bias isa Array{Float32} + @test g[1].weight ≈ ones(Float32, 4, 1) .* x' + @test g[1].bias ≈ ones(Float32, 4) + end +end + From 6db08f0644886ab551993b70140179cd8f019782 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 8 Mar 2024 13:49:32 +0100 Subject: [PATCH 2/9] add tests for Enzyme --- pytorch-test.ipynb | 61 ---------------------------- test/ext_enzyme/enzyme.jl | 85 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 65 deletions(-) delete mode 100644 pytorch-test.ipynb diff --git a/pytorch-test.ipynb b/pytorch-test.ipynb deleted file mode 100644 index 3205869182..0000000000 --- a/pytorch-test.ipynb +++ /dev/null @@ -1,61 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch as th\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class A(nn.Module):\n", - " def __init__():\n", - " super().__init__()\n", - " return\n", - " \n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.fc1 = nn.Linear(2, 3)\n", - " self.fc2 = nn.Linear(3, 1)\n", - " self.str = \"ciao\"\n", - " self.a = A()\n", - "\n", - "\n", - "\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "deepl", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index d7b2b5b692..adbd947dff 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -5,11 +5,12 @@ using EnzymeTestUtils using Functors # using EnzymeCore -make_zero(x::AbstractArray) = zero(x) -make_zero(x::Number) = zero(x) -make_zero(x) = x +Enzyme.API.runtimeActivity!(true) # for Enzyme debugging +make_zero(x::Union{Number,AbstractArray}) = zero(x) +make_zero(x) = x make_differential(model) = fmap(make_zero, model) +# make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329 function grad(f, x...) args = [] @@ -20,12 +21,41 @@ function grad(f, x...) push!(args, Duplicated(x, make_differential(x))) end end - @show x args + # @show x args ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x)) return g end +function check_grad(g1, g2; broken=false) + fmap(g1, g2) do x, y + if x isa Union{Number, AbstractArray{<:Number}} + # @test y isa typeof(x) + # @show x y + @test x ≈ y rtol=1e-4 atol=1e-4 broken=broken + end + return x + end +end + +function test_enzyme_grad(model, x) + loss(model, x) = sum(model(x)) + + Flux.reset!(model) + l = loss(model, x) + Flux.reset!(model) + @test loss(model, x) == l # Check loss doesn't change with multiple runs + + + Flux.reset!(model) + grads_flux = Flux.gradient(loss, model, x) + + Flux.reset!(model) + grads_enzyme = grad(loss, model, x) + + check_grad(grads_flux, grads_enzyme) +end + @testset "grad" begin @testset "number and arrays" begin f(x, y) = sum(x.^2) + y^3 @@ -62,3 +92,50 @@ end end end +@testset "Models" begin + models_xs = [ + (Dense(2, 4), randn(Float32, 2), "Dense"), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2), "Chain(Dense, Dense)"), + (f64(Chain(Dense(2, 4), Dense(4, 2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"), + (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), + (Chain(Conv((3, 3), 2 => 3, relu), Conv((3, 3), 3 => 1, relu)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), + (Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())), rand(Float32, 100, 100, 3, 50), "Chain(Conv, MaxPool)"), + (Maxout(() -> Dense(5 => 7, tanh), 3), randn(Float32, 5, 1), "Maxout"), + # BROKEN, uncomment as tests below are fixed + # (RNN(3 => 5), randn(Float32, 3, 10), "RNN"), + # (Chain(RNN(3 => 5), RNN(5 => 3)), randn(Float32, 3, 10), "Chain(RNN, RNN)"), # uncomment when broken test below is fixed + # (LSTM(3 => 5), randn(Float32, 3, 10), "LSTM"), + # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 10), "Chain(LSTM, LSTM)"), + ] + + for (model, x, name) in models_xs + @testset "check grad $name" begin + test_enzyme_grad(model, x) + end + end +end + +@testset "Broken Models" begin + loss(model, x) = sum(model(x)) + + @testset "RNN" begin + model = RNN(3 => 5) + x = randn(Float32, 3, 10) + Flux.reset!(model) + grads_flux = Flux.gradient(loss, model, x) + Flux.reset!(model) + grads_enzyme = grad(loss, model, x) + check_grad(grads_flux[1].state, grads_enzyme[1].state, broken=true) + end + + @testset "LSTM" begin + model = LSTM(3 => 5) + x = randn(Float32, 3, 10) + Flux.reset!(model) + grads_flux = Flux.gradient(loss, model, x) + Flux.reset!(model) + grads_enzyme = grad(loss, model, x) + check_grad(grads_flux[1].state, grads_enzyme[1].state, broken=true) + end +end From fc3bb8c7de00c3d613f3103796d1c73344500840 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 8 Mar 2024 13:54:14 +0100 Subject: [PATCH 3/9] update runtests --- Project.toml | 6 +++++- test/ext_enzyme/enzyme.jl | 2 -- test/runtests.jl | 5 +++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index bc31cd5d3f..42b4c51773 100644 --- a/Project.toml +++ b/Project.toml @@ -40,6 +40,7 @@ Adapt = "3, 4" CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" +Enzyme = "0.11" Functors = "0.4" MLUtils = "0.4" MacroTools = "0.5" @@ -62,6 +63,7 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -71,4 +73,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", + "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", + "Enzyme"] diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index adbd947dff..5bc6600741 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -1,9 +1,7 @@ using Test using Flux using Enzyme -using EnzymeTestUtils using Functors -# using EnzymeCore Enzyme.API.runtimeActivity!(true) # for Enzyme debugging diff --git a/test/runtests.jl b/test/runtests.jl index 8dca6becdd..2ff26a80bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,4 +116,9 @@ Random.seed!(0) @info "Skipping Metal tests, set FLUX_TEST_METAL=true to run them." end + @testset "Enzyme" begin + import Enzyme + include("ext_enzyme/enzyme.jl") + end + end From 875cd63c4c487903537073e83069dc6458b20e98 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 11 Mar 2024 11:31:57 +0100 Subject: [PATCH 4/9] comparison with finitedifferences --- test/ext_enzyme/enzyme.jl | 117 ++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 43 deletions(-) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 5bc6600741..916feaef15 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -2,13 +2,23 @@ using Test using Flux using Enzyme using Functors +using FiniteDifferences Enzyme.API.runtimeActivity!(true) # for Enzyme debugging make_zero(x::Union{Number,AbstractArray}) = zero(x) make_zero(x) = x make_differential(model) = fmap(make_zero, model) -# make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329 +## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329 + +function ngrad(f, x...) + ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x] + ps = [f64(x[1]) for x in ps_and_res] + res = [x[2] for x in ps_and_res] + fdm = FiniteDifferences.central_fdm(5, 1) + gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p,re) in zip(ps, res))...), ps...) + return ((re(g) for (re, g) in zip(res, gs))...,) +end function grad(f, x...) args = [] @@ -19,39 +29,33 @@ function grad(f, x...) push!(args, Duplicated(x, make_differential(x))) end end - # @show x args ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x)) return g end function check_grad(g1, g2; broken=false) - fmap(g1, g2) do x, y - if x isa Union{Number, AbstractArray{<:Number}} - # @test y isa typeof(x) - # @show x y - @test x ≈ y rtol=1e-4 atol=1e-4 broken=broken + fmap_with_path(g1, g2) do kp, x, y + !isempty(kp) && :state ∈ kp && return # ignore RNN and LSTM state + if x isa AbstractArray{<:Number} + @show kp + @test x ≈ y rtol=1e-3 atol=1e-3 broken=broken end return x end end -function test_enzyme_grad(model, x) - loss(model, x) = sum(model(x)) - - Flux.reset!(model) +function test_enzyme_grad(loss, model, x) + Flux.trainmode!(model) l = loss(model, x) - Flux.reset!(model) @test loss(model, x) == l # Check loss doesn't change with multiple runs - - Flux.reset!(model) + grads_fd = ngrad(loss, model, x) grads_flux = Flux.gradient(loss, model, x) - - Flux.reset!(model) grads_enzyme = grad(loss, model, x) - check_grad(grads_flux, grads_enzyme) + # check_grad(grads_flux, grads_enzyme) + check_grad(grads_fd, grads_enzyme) end @testset "grad" begin @@ -91,6 +95,11 @@ end end @testset "Models" begin + function loss(model, x) + Flux.reset!(model) + sum(model(x)) + end + models_xs = [ (Dense(2, 4), randn(Float32, 2), "Dense"), (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2), "Chain(Dense, Dense)"), @@ -98,42 +107,64 @@ end (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), (Chain(Conv((3, 3), 2 => 3, relu), Conv((3, 3), 3 => 1, relu)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), - (Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())), rand(Float32, 100, 100, 3, 50), "Chain(Conv, MaxPool)"), - (Maxout(() -> Dense(5 => 7, tanh), 3), randn(Float32, 5, 1), "Maxout"), - # BROKEN, uncomment as tests below are fixed - # (RNN(3 => 5), randn(Float32, 3, 10), "RNN"), - # (Chain(RNN(3 => 5), RNN(5 => 3)), randn(Float32, 3, 10), "Chain(RNN, RNN)"), # uncomment when broken test below is fixed - # (LSTM(3 => 5), randn(Float32, 3, 10), "LSTM"), - # (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 10), "Chain(LSTM, LSTM)"), + (Chain(Conv((5, 5), 3 => 4, pad=SamePad()), MaxPool((5, 5), pad=SamePad())), rand(Float32, 6, 6, 3, 2), "Chain(Conv, MaxPool)"), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), + (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), + (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), # uncomment when broken test below is fixed + (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), + (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), ] for (model, x, name) in models_xs @testset "check grad $name" begin - test_enzyme_grad(model, x) + println("testing $name") + test_enzyme_grad(loss, model, x) end - end + end end -@testset "Broken Models" begin - loss(model, x) = sum(model(x)) - - @testset "RNN" begin - model = RNN(3 => 5) - x = randn(Float32, 3, 10) +@testset "Recurrence Tests" begin + function loss(model, x) Flux.reset!(model) - grads_flux = Flux.gradient(loss, model, x) - Flux.reset!(model) - grads_enzyme = grad(loss, model, x) - check_grad(grads_flux[1].state, grads_enzyme[1].state, broken=true) + for i in 1:3 + x = model(x) + end + return sum(x) end - @testset "LSTM" begin - model = LSTM(3 => 5) - x = randn(Float32, 3, 10) - Flux.reset!(model) - grads_flux = Flux.gradient(loss, model, x) + models_xs = [ + (RNN(3 => 3), randn(Float32, 3, 2), "RNN"), + (LSTM(3 => 3), randn(Float32, 3, 2), "LSTM"), + # TESTS BELOW ARE BROKEN FOR ZYGOTE BUT CORRECT FOR ENZYME! + (Chain(RNN(3 => 5), RNN(5 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), + (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), + ] + + for (model, x, name) in models_xs + @testset "check grad $name" begin + println("testing $name") + test_enzyme_grad(loss, model, x) + end + end +end + +@testset "Broken Models" begin + function loss(model, x) Flux.reset!(model) - grads_enzyme = grad(loss, model, x) - check_grad(grads_flux[1].state, grads_enzyme[1].state, broken=true) + sum(model(x)) end + end + +# function loss(model, x) +# Flux.reset!(model) +# sum(model(x)) +# end + +# model = RNN(2 => 2) +# x = randn(Float32, 2, 3) +# Flux.reset!(model) +# grads_flux = Flux.gradient(loss, model, x) +# grads_enzyme = grad(loss, model, x) +# grads_fd = ngrad(loss, model, x) +# check_grad(grads_enzyme, grads_fd) From f3690cc8d73b245dc1f1c789956fb1d87962c006 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 11 Mar 2024 15:24:07 +0100 Subject: [PATCH 5/9] cl/enzyme --- test/ext_enzyme/enzyme.jl | 77 ++++++++++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 9 deletions(-) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 916feaef15..9784c04977 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -12,6 +12,7 @@ make_differential(model) = fmap(make_zero, model) ## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329 function ngrad(f, x...) + x = [cpu(x) for x in x] ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x] ps = [f64(x[1]) for x in ps_and_res] res = [x[2] for x in ps_and_res] @@ -39,7 +40,7 @@ function check_grad(g1, g2; broken=false) !isempty(kp) && :state ∈ kp && return # ignore RNN and LSTM state if x isa AbstractArray{<:Number} @show kp - @test x ≈ y rtol=1e-3 atol=1e-3 broken=broken + @test x ≈ y rtol=1e-2 broken=broken end return x end @@ -50,9 +51,9 @@ function test_enzyme_grad(loss, model, x) l = loss(model, x) @test loss(model, x) == l # Check loss doesn't change with multiple runs - grads_fd = ngrad(loss, model, x) - grads_flux = Flux.gradient(loss, model, x) - grads_enzyme = grad(loss, model, x) + grads_fd = ngrad(loss, model, x) |> cpu + grads_flux = Flux.gradient(loss, model, x) |> cpu + grads_enzyme = grad(loss, model, x) |> cpu # check_grad(grads_flux, grads_enzyme) check_grad(grads_fd, grads_enzyme) @@ -113,6 +114,7 @@ end (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), # uncomment when broken test below is fixed (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), ] for (model, x, name) in models_xs @@ -153,18 +155,75 @@ end Flux.reset!(model) sum(model(x)) end - + + device = Flux.get_device() + + models_xs = [ + (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), + # (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # this is just producing a Warning + (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), + ] + + for (model, x, name) in models_xs + @testset "check grad $name" begin + println("testing $name") + broken = false + try + test_enzyme_grad(loss, model, x) + catch e + print(e) + broken = true + end + @test broken + end + end end +@testset "Broken GPU Models" begin + function loss(model, x) + Flux.reset!(model) + sum(model(x)) + end + + device = Flux.get_device() + + models_xs = [ + (Dense(2, 4), randn(Float32, 2,1), "Dense"), + (Chain(Dense(2, 4), Dense(4, 2)), randn(Float32, 2,1), "Chain(Dense, Dense)"), + ] + + for (model, x, name) in models_xs + @testset "check grad $name" begin + println("testing $name") + model = model |> device + x = x |> device + broken = false + try + test_enzyme_grad(loss, model, x) + catch e + print(e) + broken = true + end + @test broken + end + end +end + # function loss(model, x) # Flux.reset!(model) # sum(model(x)) # end -# model = RNN(2 => 2) -# x = randn(Float32, 2, 3) -# Flux.reset!(model) -# grads_flux = Flux.gradient(loss, model, x) +# device = Flux.get_device() +# model = Chain(Dense(2 => 4, relu), Dense(4 => 2)) |> device +# x = randn(Float32, 2,1) |> device +# model(x) +# grads_flux = Flux.gradient(loss, model, x) |> cpu +# grads_fd = ngrad(loss, model, x) +# grads_enzyme = grad(loss, model, x) |> cpu +# check_grad(grads_fd, grads_enzyme) + + # grads_enzyme = grad(loss, model, x) # grads_fd = ngrad(loss, model, x) # check_grad(grads_enzyme, grads_fd) From 29d08a1555459f8a49fd039810f8974d64c1f802 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 16 Mar 2024 19:15:03 +0100 Subject: [PATCH 6/9] tests passing --- out.txt | 102 ++++++++++++++++++++++++++++++++++++++ t.jl | 4 ++ test/ext_enzyme/enzyme.jl | 88 ++++++++------------------------ 3 files changed, 128 insertions(+), 66 deletions(-) create mode 100644 out.txt create mode 100644 t.jl diff --git a/out.txt b/out.txt new file mode 100644 index 0000000000..f0055d0d20 --- /dev/null +++ b/out.txt @@ -0,0 +1,102 @@ + +[1332974] signal (11.1): Segmentation fault +in expression starting at /home/lucibello/.julia/dev/Flux/t.jl:4 +typekeyvalue_hash at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jltypes.c:1622 [inlined] +lookup_typevalue at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jltypes.c:1059 +jl_inst_arg_tuple_type at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jltypes.c:2157 +jl_f_tuple at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/builtins.c:868 [inlined] +jl_f_tuple at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/builtins.c:863 +absint at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/absint.jl:116 +abs_typeof at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/absint.jl:302 +unknown function (ip: 0x7f98b76ac15e) +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +check_ir! at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:500 +check_ir! at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:208 +check_ir! at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:178 +check_ir at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:157 [inlined] +#codegen#488 at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:4376 +codegen at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:4340 [inlined] +_thunk at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5430 +_thunk at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5430 [inlined] +cached_compilation at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5464 [inlined] +#532 at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5530 +JuliaContext at /home/lucibello/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47 +unknown function (ip: 0x7f99113c1265) +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +#s1883#531 at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5482 [inlined] +#s1883#531 at ./none:0 +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +GeneratedFunctionStub at ./boot.jl:602 +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +jl_call_staged at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/method.c:540 +ijl_code_for_staged at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/method.c:593 +get_staged at ./compiler/utilities.jl:123 +retrieve_code_info at ./compiler/utilities.jl:135 [inlined] +InferenceState at ./compiler/inferencestate.jl:430 +typeinf_edge at ./compiler/typeinfer.jl:920 +abstract_call_method at ./compiler/abstractinterpretation.jl:629 +abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95 +abstract_call_known at ./compiler/abstractinterpretation.jl:2087 +abstract_call at ./compiler/abstractinterpretation.jl:2169 +abstract_call at ./compiler/abstractinterpretation.jl:2162 +abstract_call at ./compiler/abstractinterpretation.jl:2354 +abstract_eval_call at ./compiler/abstractinterpretation.jl:2370 +abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380 +abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624 +abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2889 +typeinf_local at ./compiler/abstractinterpretation.jl:3098 +typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186 +_typeinf at ./compiler/typeinfer.jl:247 +typeinf at ./compiler/typeinfer.jl:216 +typeinf_ext at ./compiler/typeinfer.jl:1051 +typeinf_ext_toplevel at ./compiler/typeinfer.jl:1082 +typeinf_ext_toplevel at ./compiler/typeinfer.jl:1078 +jfptr_typeinf_ext_toplevel_45276.1 at /home/lucibello/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line) +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined] +jl_type_infer at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:394 +jl_generate_fptr_impl at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jitlayers.cpp:502 +jl_compile_method_internal at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2480 [inlined] +jl_compile_method_internal at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2368 +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2886 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +autodiff at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/Enzyme.jl:224 +unknown function (ip: 0x7f9911360dca) +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined] +do_call at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:126 +eval_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:223 +eval_stmt_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined] +eval_body at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:617 +jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:775 +jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:934 +jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:877 +ijl_toplevel_eval_in at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:985 +eval at ./boot.jl:385 [inlined] +include_string at ./loading.jl:2076 +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +_include at ./loading.jl:2136 +include at ./Base.jl:495 +jfptr_include_46403.1 at /home/lucibello/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line) +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +exec_options at ./client.jl:318 +_start at ./client.jl:552 +jfptr__start_82738.1 at /home/lucibello/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line) +_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] +ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 +jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined] +true_main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:582 +jl_repl_entrypoint at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:731 +main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/cli/loader_exe.c:58 +unknown function (ip: 0x7f9928680d8f) +__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line) +unknown function (ip: 0x4010b8) +Allocations: 159273970 (Pool: 159095605; Big: 178365); GC: 99 diff --git a/t.jl b/t.jl new file mode 100644 index 0000000000..7a4dc139a6 --- /dev/null +++ b/t.jl @@ -0,0 +1,4 @@ +using CUDA, Enzyme +x = CUDA.ones(2) +dx = CUDA.zeros(2) +autodiff(Reverse, sum, Active, Duplicated(x, dx)) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 9784c04977..df9f258a3d 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -1,17 +1,21 @@ using Test using Flux + using Enzyme using Functors using FiniteDifferences +using CUDA +Enzyme.API.typeWarning!(false) # suppresses a warning with Bilinear https://github.com/EnzymeAD/Enzyme.jl/issues/1341 Enzyme.API.runtimeActivity!(true) # for Enzyme debugging +# Enzyme.Compiler.bitcode_replacement!(false) -make_zero(x::Union{Number,AbstractArray}) = zero(x) -make_zero(x) = x -make_differential(model) = fmap(make_zero, model) +_make_zero(x::Union{Number,AbstractArray}) = zero(x) +_make_zero(x) = x +make_zero(model) = fmap(_make_zero, model) ## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329 -function ngrad(f, x...) +function gradient_fd(f, x...) x = [cpu(x) for x in x] ps_and_res = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in x] ps = [f64(x[1]) for x in ps_and_res] @@ -21,13 +25,13 @@ function ngrad(f, x...) return ((re(g) for (re, g) in zip(res, gs))...,) end -function grad(f, x...) +function gradient_ez(f, x...) args = [] for x in x if x isa Number push!(args, Active(x)) else - push!(args, Duplicated(x, make_differential(x))) + push!(args, Duplicated(x, make_zero(x))) end end ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) @@ -37,10 +41,10 @@ end function check_grad(g1, g2; broken=false) fmap_with_path(g1, g2) do kp, x, y - !isempty(kp) && :state ∈ kp && return # ignore RNN and LSTM state + :state ∈ kp && return # ignore RNN and LSTM state if x isa AbstractArray{<:Number} - @show kp - @test x ≈ y rtol=1e-2 broken=broken + # @show kp + @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken end return x end @@ -51,20 +55,20 @@ function test_enzyme_grad(loss, model, x) l = loss(model, x) @test loss(model, x) == l # Check loss doesn't change with multiple runs - grads_fd = ngrad(loss, model, x) |> cpu + grads_fd = gradient_fd(loss, model, x) |> cpu grads_flux = Flux.gradient(loss, model, x) |> cpu - grads_enzyme = grad(loss, model, x) |> cpu + grads_enzyme = gradient_ez(loss, model, x) |> cpu # check_grad(grads_flux, grads_enzyme) check_grad(grads_fd, grads_enzyme) end -@testset "grad" begin +@testset "gradient_ez" begin @testset "number and arrays" begin f(x, y) = sum(x.^2) + y^3 x = Float32[1, 2, 3] y = 3f0 - g = grad(f, x, y) + g = gradient_ez(f, x, y) @test g[1] isa Array{Float32} @test g[2] isa Float32 @test g[1] ≈ 2x @@ -85,7 +89,7 @@ end x = randn(Float32, 2) loss(model, x) = sum(model(x)) - g = grad(loss, model, x) + g = gradient_ez(loss, model, x) @test g[1] isa SimpleDense @test g[2] isa Array{Float32} @test g[1].weight isa Array{Float32} @@ -108,13 +112,14 @@ end (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), (Chain(Conv((3, 3), 2 => 3, relu), Conv((3, 3), 3 => 1, relu)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), - (Chain(Conv((5, 5), 3 => 4, pad=SamePad()), MaxPool((5, 5), pad=SamePad())), rand(Float32, 6, 6, 3, 2), "Chain(Conv, MaxPool)"), + (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), - (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), # uncomment when broken test below is fixed + (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), + (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), ] for (model, x, name) in models_xs @@ -160,7 +165,6 @@ end models_xs = [ (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), - # (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # this is just producing a Warning (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), ] @@ -171,37 +175,7 @@ end try test_enzyme_grad(loss, model, x) catch e - print(e) - broken = true - end - @test broken - end - end -end - -@testset "Broken GPU Models" begin - function loss(model, x) - Flux.reset!(model) - sum(model(x)) - end - - device = Flux.get_device() - - models_xs = [ - (Dense(2, 4), randn(Float32, 2,1), "Dense"), - (Chain(Dense(2, 4), Dense(4, 2)), randn(Float32, 2,1), "Chain(Dense, Dense)"), - ] - - for (model, x, name) in models_xs - @testset "check grad $name" begin - println("testing $name") - model = model |> device - x = x |> device - broken = false - try - test_enzyme_grad(loss, model, x) - catch e - print(e) + println(e) broken = true end @test broken @@ -209,21 +183,3 @@ end end end -# function loss(model, x) -# Flux.reset!(model) -# sum(model(x)) -# end - -# device = Flux.get_device() -# model = Chain(Dense(2 => 4, relu), Dense(4 => 2)) |> device -# x = randn(Float32, 2,1) |> device -# model(x) -# grads_flux = Flux.gradient(loss, model, x) |> cpu -# grads_fd = ngrad(loss, model, x) -# grads_enzyme = grad(loss, model, x) |> cpu -# check_grad(grads_fd, grads_enzyme) - - -# grads_enzyme = grad(loss, model, x) -# grads_fd = ngrad(loss, model, x) -# check_grad(grads_enzyme, grads_fd) From a1fa4050f51bb17a83acb95330748705e502eb03 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 16 Mar 2024 19:15:22 +0100 Subject: [PATCH 7/9] cleanup --- out.txt | 102 -------------------------------------------------------- t.jl | 4 --- 2 files changed, 106 deletions(-) delete mode 100644 out.txt delete mode 100644 t.jl diff --git a/out.txt b/out.txt deleted file mode 100644 index f0055d0d20..0000000000 --- a/out.txt +++ /dev/null @@ -1,102 +0,0 @@ - -[1332974] signal (11.1): Segmentation fault -in expression starting at /home/lucibello/.julia/dev/Flux/t.jl:4 -typekeyvalue_hash at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jltypes.c:1622 [inlined] -lookup_typevalue at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jltypes.c:1059 -jl_inst_arg_tuple_type at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jltypes.c:2157 -jl_f_tuple at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/builtins.c:868 [inlined] -jl_f_tuple at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/builtins.c:863 -absint at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/absint.jl:116 -abs_typeof at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/absint.jl:302 -unknown function (ip: 0x7f98b76ac15e) -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -check_ir! at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:500 -check_ir! at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:208 -check_ir! at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:178 -check_ir at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler/validation.jl:157 [inlined] -#codegen#488 at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:4376 -codegen at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:4340 [inlined] -_thunk at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5430 -_thunk at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5430 [inlined] -cached_compilation at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5464 [inlined] -#532 at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5530 -JuliaContext at /home/lucibello/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47 -unknown function (ip: 0x7f99113c1265) -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -#s1883#531 at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5482 [inlined] -#s1883#531 at ./none:0 -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -GeneratedFunctionStub at ./boot.jl:602 -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -jl_call_staged at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/method.c:540 -ijl_code_for_staged at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/method.c:593 -get_staged at ./compiler/utilities.jl:123 -retrieve_code_info at ./compiler/utilities.jl:135 [inlined] -InferenceState at ./compiler/inferencestate.jl:430 -typeinf_edge at ./compiler/typeinfer.jl:920 -abstract_call_method at ./compiler/abstractinterpretation.jl:629 -abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:95 -abstract_call_known at ./compiler/abstractinterpretation.jl:2087 -abstract_call at ./compiler/abstractinterpretation.jl:2169 -abstract_call at ./compiler/abstractinterpretation.jl:2162 -abstract_call at ./compiler/abstractinterpretation.jl:2354 -abstract_eval_call at ./compiler/abstractinterpretation.jl:2370 -abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2380 -abstract_eval_statement at ./compiler/abstractinterpretation.jl:2624 -abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:2889 -typeinf_local at ./compiler/abstractinterpretation.jl:3098 -typeinf_nocycle at ./compiler/abstractinterpretation.jl:3186 -_typeinf at ./compiler/typeinfer.jl:247 -typeinf at ./compiler/typeinfer.jl:216 -typeinf_ext at ./compiler/typeinfer.jl:1051 -typeinf_ext_toplevel at ./compiler/typeinfer.jl:1082 -typeinf_ext_toplevel at ./compiler/typeinfer.jl:1078 -jfptr_typeinf_ext_toplevel_45276.1 at /home/lucibello/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line) -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined] -jl_type_infer at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:394 -jl_generate_fptr_impl at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jitlayers.cpp:502 -jl_compile_method_internal at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2480 [inlined] -jl_compile_method_internal at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2368 -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2886 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -autodiff at /home/lucibello/.julia/packages/Enzyme/wR2t7/src/Enzyme.jl:224 -unknown function (ip: 0x7f9911360dca) -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined] -do_call at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:126 -eval_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:223 -eval_stmt_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined] -eval_body at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:617 -jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:775 -jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:934 -jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:877 -ijl_toplevel_eval_in at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:985 -eval at ./boot.jl:385 [inlined] -include_string at ./loading.jl:2076 -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -_include at ./loading.jl:2136 -include at ./Base.jl:495 -jfptr_include_46403.1 at /home/lucibello/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line) -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -exec_options at ./client.jl:318 -_start at ./client.jl:552 -jfptr__start_82738.1 at /home/lucibello/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line) -_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined] -ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076 -jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined] -true_main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:582 -jl_repl_entrypoint at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:731 -main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/cli/loader_exe.c:58 -unknown function (ip: 0x7f9928680d8f) -__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line) -unknown function (ip: 0x4010b8) -Allocations: 159273970 (Pool: 159095605; Big: 178365); GC: 99 diff --git a/t.jl b/t.jl deleted file mode 100644 index 7a4dc139a6..0000000000 --- a/t.jl +++ /dev/null @@ -1,4 +0,0 @@ -using CUDA, Enzyme -x = CUDA.ones(2) -dx = CUDA.zeros(2) -autodiff(Reverse, sum, Active, Duplicated(x, dx)) From 60c0cd4809635f7bb63fe2d022aee7a35efa5e2e Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 16 Mar 2024 19:22:32 +0100 Subject: [PATCH 8/9] add FiniteDifferences to extra --- Project.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 42b4c51773..8e6844c75c 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ CUDA = "4, 5" ChainRulesCore = "1.12" Compat = "4.10.0" Enzyme = "0.11" +FiniteDifferences = "0.12" Functors = "0.4" MLUtils = "0.4" MacroTools = "0.5" @@ -65,6 +66,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" @@ -74,5 +76,5 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", - "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", - "Enzyme"] + "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", + "Enzyme", "FiniteDifferences"] From 767349ef5b83d27fb2d4f266d513529069e251b1 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 17 Mar 2024 06:44:56 +0100 Subject: [PATCH 9/9] check_grad -> test_grad --- test/ext_enzyme/enzyme.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index df9f258a3d..36212bb10f 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -39,7 +39,7 @@ function gradient_ez(f, x...) return g end -function check_grad(g1, g2; broken=false) +function test_grad(g1, g2; broken=false) fmap_with_path(g1, g2) do kp, x, y :state ∈ kp && return # ignore RNN and LSTM state if x isa AbstractArray{<:Number} @@ -59,8 +59,8 @@ function test_enzyme_grad(loss, model, x) grads_flux = Flux.gradient(loss, model, x) |> cpu grads_enzyme = gradient_ez(loss, model, x) |> cpu - # check_grad(grads_flux, grads_enzyme) - check_grad(grads_fd, grads_enzyme) + # test_grad(grads_flux, grads_enzyme) + test_grad(grads_fd, grads_enzyme) end @testset "gradient_ez" begin