From 7d632fcf50f086a20959e93c72c71b2a6febe1a3 Mon Sep 17 00:00:00 2001
From: lassepe <lasse.peters@mailbox.org>
Date: Fri, 1 Dec 2023 23:51:55 +0100
Subject: [PATCH] Basic Enzyme support

---
 .github/workflows/ci.yml  |   4 +-
 Project.toml              |   1 +
 enzyme_tests/Project.toml |   4 ++
 enzyme_tests/main.jl      |  48 ++++++++++++++
 src/AutoDiff.jl           | 134 ++++++++++++++++++++++++++++++++++++++
 src/solver.jl             |   2 +-
 test/Project.toml         |   1 +
 test/runtests.jl          |  81 +++++++++++++++++------
 8 files changed, 252 insertions(+), 23 deletions(-)
 create mode 100644 enzyme_tests/Project.toml
 create mode 100644 enzyme_tests/main.jl

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 23a9465..47435fa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -6,11 +6,11 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        julia_version: ["1.7", "1"]
+        julia_version: ["1.10.0-rc1"]
     name: julia ${{ matrix.julia_version }}
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@v1
         with:
           version: ${{ matrix.julia_version }}
diff --git a/Project.toml b/Project.toml
index b741957..bf4ce1e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.1.8"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
diff --git a/enzyme_tests/Project.toml b/enzyme_tests/Project.toml
new file mode 100644
index 0000000..20eea8f
--- /dev/null
+++ b/enzyme_tests/Project.toml
@@ -0,0 +1,4 @@
+[deps]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
+ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
diff --git a/enzyme_tests/main.jl b/enzyme_tests/main.jl
new file mode 100644
index 0000000..7ee20f6
--- /dev/null
+++ b/enzyme_tests/main.jl
@@ -0,0 +1,48 @@
+using Enzyme, EnzymeTestUtils, ParametricMCPs
+
+parameter_dimension = 2
+f(z, θ) = [2z[1:2] - z[3:4] - 2θ; z[1:2]]
+lower_bounds = [-Inf, -Inf, 0, 0]
+upper_bounds = [Inf, Inf, Inf, Inf]
+problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)
+
+@info "I can trigger the rule manually just fine:"
+dsol = Enzyme.autodiff(
+    Forward,
+    ParametricMCPs.solve,
+    Const(problem),
+    Duplicated([1.0, 2.0], [1.0, 0.0]),
+)
+@show dsol
+
+@info """
+Now testing with EnzymeTestUtils.
+
+This fails because `FiniteDifferences.jl` cannot flatten the output struct with `to_vec`.
+"""
+try
+    test_forward(solve, Duplicated, (problem, Const), ([1.0, 2.0], Duplicated))
+catch e
+    display(e)
+end
+
+@info """
+To circumvent the issue above, we now unpack the relevant fields of the output struct for differentiation.
+This fails because Enzyme here thinks that the activities don't match:
+
+```
+Enzyme execution failed.
+Mismatched activity for:   store {} addrspace(10)* %.fca.0.0.0.0.extract, {} addrspace(10)** %.fca.0.0.0.0.gep, align 8, !tbaa !113, !alias.scope !117, !noalias !120 const val:   %.fca.0.0.0.0.extract = extractvalue { [1 x [1 x {} addrspace(10)*]], { [1 x [1 x {} addrspace(10)*]], { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, [2 x i64], {} addrspace(10)* }, { [1 x [1 x {} addrspace(10)*]], { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, [2 x i64], {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, i64 } %0, 0, 0, 0
+```
+
+
+(although they should be exactly the same as in our manual autodiff test above?)
+"""
+try
+    test_forward(Duplicated, (problem, Const), ([1.0, 2.0], Duplicated)) do problem, θ
+        @inline
+        solve(problem, θ).z
+    end
+catch e
+    display(e)
+end
diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl
index b024c44..562c202 100644
--- a/src/AutoDiff.jl
+++ b/src/AutoDiff.jl
@@ -2,6 +2,7 @@ module AutoDiff
 
 using ..ParametricMCPs: ParametricMCPs, get_problem_size, get_result_buffer, get_parameter_dimension
 using ChainRulesCore: ChainRulesCore
+using EnzymeCore: EnzymeCore, EnzymeRules
 using ForwardDiff: ForwardDiff
 using SparseArrays: SparseArrays
 using LinearAlgebra: LinearAlgebra
@@ -45,6 +46,139 @@ function _solve_jacobian_θ(problem, solution, θ; active_tolerance = 1e-3)
     ∂z∂θ
 end
 
+const EnzymeBatchedAnnotation = Union{EnzymeCore.BatchDuplicated,EnzymeCore.BatchDuplicatedNoNeed}
+const EnzymeNoneedAnnotation = Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.BatchDuplicatedNoNeed}
+
+function EnzymeRules.forward(
+    func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
+    ::Type{ReturnType},
+    problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
+    θ::EnzymeCore.Annotation;
+    kwargs...,
+) where {ReturnType<:EnzymeCore.Annotation}
+    # TODO: Enzyme sometimes passes us the problem as non-const (why?). For now, skip this check.
+    #if !(problem isa EnzymeCore.Const)
+    #    throw(ArgumentError("""
+    #                        `problem` must be annotated `Enzyme.Const`.
+    #                        If you did not pass the non-const problem annotation yourself,
+    #                        consider filing an issue with ParametricMCPs.jl.
+    #                        """))
+    #end
+
+    if θ isa EnzymeCore.Const
+        throw(
+            ArgumentError(
+                """
+                `θ` was annotated `Enzyme.Const` which defeats the purpose of running AD.
+                If you did not pass the const θ annotation yourself,
+                consider filing an issue with ParametricMCPs.jl.
+                """,
+            ),
+        )
+    end
+
+    # forward pass
+    solution_val = func.val(problem.val, θ.val; kwargs...)
+
+    if ReturnType <: EnzymeCore.Const
+        return solution_val
+    end
+
+    # backward pass
+    ∂z∂θ = _solve_jacobian_θ(problem.val, solution_val, θ.val)
+
+    if ReturnType <: EnzymeBatchedAnnotation
+        solution_dval = map(θ.dval) do θdval
+            _dval = deepcopy(solution_val)
+            _dval.z .= ∂z∂θ * θdval
+            _dval
+        end
+    else
+        # downstream gradient
+        dz = ∂z∂θ * θ.dval
+        solution_dval = deepcopy(solution_val)
+        solution_dval.z .= dz
+    end
+
+    if ReturnType <: EnzymeNoneedAnnotation
+        return solution_dval
+    end
+
+    if ReturnType <: EnzymeCore.Duplicated
+        return EnzymeCore.Duplicated(solution_val, solution_dval)
+    end
+
+    if ReturnType <: EnzymeCore.BatchDuplicated
+        return EnzymeCore.BatchDuplicated(solution_val, solution_dval)
+    end
+
+    throw(ArgumentError("""
+                        Forward rule for ReturnType with annotation $(ReturnType) not implemented.
+                        Please file an issue with ParametricMCPs.jl.
+                        """))
+end
+
+function EnzymeRules.augmented_primal(
+    config::EnzymeRules.ConfigWidth{1},
+    func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
+    ::Type{<:EnzymeRules.Annotation},
+    problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
+    θ::EnzymeCore.Annotation;
+    kwargs...,
+)
+    function copy_or_reuse(val, idx)
+        if EnzymeRules.overwritten(config)[idx] && ismutable(val)
+            return deepcopy(val)
+        end
+        val
+    end
+
+    θval = copy_or_reuse(θ.val, 3)
+    res = func.val(problem.val, θval; kwargs...)
+    # backward pass
+    ∂z∂θ_thunk = () -> _solve_jacobian_θ(problem.val, res, θval)
+
+    dres = deepcopy(res)
+    dres.z .= 0.0
+
+    tape = (; ∂z∂θ_thunk, dres)
+
+    EnzymeRules.AugmentedReturn(res, dres, tape)
+end
+
+function EnzymeRules.reverse(
+    config,
+    func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
+    rt::Type{ReturnType}, # TODO: tighter type constraint
+    tape,
+    problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
+    θ::EnzymeCore.Annotation;
+    kwargs...,
+) where {ReturnType}
+    println("reverse: $rt")
+
+    if θ isa EnzymeCore.Duplicated
+        ∂z∂θ = tape.∂z∂θ_thunk()
+        ∂l∂z = tape.dres.z
+        θ.dval .+= ∂z∂θ' * ∂l∂z
+    elseif !(θ isa EnzymeCore.Const)
+        throw(ArgumentError("""
+                            Reverse rule for θ with annotation $(typeof(θ)) not implemented.
+                            Please file an issue with ParametricMCPs.jl.
+                            """))
+    end
+
+    if !(problem isa EnzymeCore.Const)
+        throw(ArgumentError("""
+                            `problem` must be annotated `Enzyme.Const`.
+                            If you did not pass the non-const problem annotation yourself,
+                            consider filing an issue with ParametricMCPs.jl.
+                            """))
+    end
+
+    (nothing, nothing)
+end
+
 function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...)
     solution = ParametricMCPs.solve(problem, θ; kwargs...)
     project_to_θ = ChainRulesCore.ProjectTo(θ)
diff --git a/src/solver.jl b/src/solver.jl
index 60c0e53..6362b38 100644
--- a/src/solver.jl
+++ b/src/solver.jl
@@ -28,7 +28,7 @@ function solve(
 
     jacobian_linear_elements =
         enable_presolve ? jacobian_z!.constant_entries : empty(jacobian_z!.constant_entries)
-    status, z, info = PATHSolver.solve_mcp(
+    status, z::Vector, info::PATHSolver.Information = PATHSolver.solve_mcp(
         F,
         J,
         lower_bounds,
diff --git a/test/Project.toml b/test/Project.toml
index 47b24a1..96ade20 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/test/runtests.jl b/test/runtests.jl
index e232491..65c7da0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,6 +4,7 @@ using Random: Random
 using LinearAlgebra: norm
 using Zygote: Zygote
 using FiniteDiff: FiniteDiff
+using Enzyme: Enzyme
 
 @testset "ParametricMCPs.jl" begin
     rng = Random.MersenneTwister(1)
@@ -15,9 +16,15 @@ using FiniteDiff: FiniteDiff
     lower_bounds = [-Inf, -Inf, 0, 0]
     upper_bounds = [Inf, Inf, Inf, Inf]
     problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)
-    problem_no_jacobian = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension; compute_sensitivities = false)
+    problem_no_jacobian = ParametricMCPs.ParametricMCP(
+        f,
+        lower_bounds,
+        upper_bounds,
+        parameter_dimension;
+        compute_sensitivities = false,
+    )
 
-    feasible_parameters = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [rand(rng, 2) for _ in 1:10]...]
+    feasible_parameters = [[0.0, 0.0], [rand(rng, 2) for _ in 1:4]...]
     infeasible_parameters = -feasible_parameters
 
     @testset "forward pass" begin
@@ -32,27 +39,61 @@ using FiniteDiff: FiniteDiff
         end
     end
 
-    @testset "backward pass" begin
-        function dummy_pipeline(θ)
-            solution = ParametricMCPs.solve(problem, θ)
-            sum(solution.z .^ 2)
-        end
+    function dummy_pipeline(problem, θ)
+        solution = ParametricMCPs.solve(problem, θ)
+        sum(solution.z .^ 2)
+    end
 
+    @testset "backward pass" begin
         for θ in [feasible_parameters; infeasible_parameters]
-            ∇_autodiff_reverse = only(Zygote.gradient(dummy_pipeline, θ))
-            ∇_autodiff_forward = only(Zygote.gradient(θ -> Zygote.forwarddiff(dummy_pipeline, θ), θ))
-            ∇_finitediff = FiniteDiff.finite_difference_gradient(dummy_pipeline, θ)
-            @test isapprox(∇_autodiff_reverse, ∇_finitediff; atol = 1e-4)
-            @test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-4)
+            #Enzyme.jacobian(Enzyme.Reverse, dummy_pipeline, Enzyme.Duplicated([1.0, 1.0], [0.0, 0.0]))
+            ∇_finitediff = FiniteDiff.finite_difference_gradient(θ -> dummy_pipeline(problem, θ), θ)
+
+            @testset "Zygote Reverse" begin
+                ∇_zygote_reverse = Zygote.gradient(θ) do θ
+                    dummy_pipeline(problem, θ)
+                end |> only
+                @test isapprox(∇_zygote_reverse, ∇_finitediff; atol = 1e-4)
+            end
+
+            @testset "Zygote Forward" begin
+                ∇_zygote_forward = Zygote.gradient(θ) do θ
+                    Zygote.forwarddiff(θ) do θ
+                        dummy_pipeline(problem, θ)
+                    end
+                end |> only
+                @test isapprox(∇_zygote_forward, ∇_finitediff; atol = 1e-4)
+            end
+
+            @testset "Enzyme Forward" begin
+                ∇_enzyme_forward =
+                    Enzyme.autodiff(
+                        Enzyme.Forward,
+                        dummy_pipeline,
+                        problem,
+                        Enzyme.BatchDuplicated(θ, Enzyme.onehot(θ)),
+                    ) |>
+                    only |>
+                    collect
+                @test isapprox(∇_enzyme_forward, ∇_finitediff; atol = 1e-4)
+            end
+
+            @testset "Enzyme Reverse" begin
+                ∇_enzyme_reverse = zero(θ)
+                Enzyme.autodiff(
+                    Enzyme.Reverse,
+                    dummy_pipeline,
+                    problem,
+                    Enzyme.Duplicated(θ, ∇_enzyme_reverse),
+                )
+                @test isapprox(∇_enzyme_reverse, ∇_finitediff; atol = 1e-4)
+            end
         end
-    end   
-    
+    end
+
     @testset "missing jacobian" begin
-        function dummy_pipeline(θ, problem)
-            solution = ParametricMCPs.solve(problem, θ)
-            sum(solution.z .^ 2)
+        @test_throws ArgumentError Zygote.gradient(feasible_parameters[1]) do θ
+            dummy_pipeline(problem_no_jacobian, θ)
         end
-
-        @test_throws ArgumentError Zygote.gradient(θ -> dummy_pipeline(θ, problem_no_jacobian), feasible_parameters[1])
-    end 
+    end
 end