Skip to content

Commit

Permalink
Basic Enzyme support
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Dec 1, 2023
1 parent 94f8519 commit 7d632fc
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 23 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ jobs:
strategy:
fail-fast: false
matrix:
julia_version: ["1.7", "1"]
julia_version: ["1.10.0-rc1"]
name: julia ${{ matrix.julia_version }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia_version }}
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.8"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
Expand Down
4 changes: 4 additions & 0 deletions enzyme_tests/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
48 changes: 48 additions & 0 deletions enzyme_tests/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Enzyme, EnzymeTestUtils, ParametricMCPs

parameter_dimension = 2
f(z, θ) = [2z[1:2] - z[3:4] - 2θ; z[1:2]]
lower_bounds = [-Inf, -Inf, 0, 0]
upper_bounds = [Inf, Inf, Inf, Inf]
problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)

@info "I can can trigger the rule manually just fine:"
dsol = Enzyme.autodiff(
Forward,
ParametricMCPs.solve,
Const(problem),
Duplicated([1.0, 2.0], [1.0, 0.0]),
)
@show dsol

@info """
Now testing with EnzymeTestUtils.
this fails because `FiniteDifferences.jl` cannot flatten the output struct with to_vec`
"""
try
test_forward(solve, Duplicated, (problem, Const), ([1.0, 2.0], Duplicated))
catch e
display(e)
end

@info """
To circumvent the issue above, now we unpack the relevant fields of the output struct for differentiation
this fails because Enzyme here thinks that the activities don't match:
```
Enzyme execution failed.
Mismatched activity for: store {} addrspace(10)* %.fca.0.0.0.0.extract, {} addrspace(10)** %.fca.0.0.0.0.gep, align 8, !tbaa !113, !alias.scope !117, !noalias !120 const val: %.fca.0.0.0.0.extract = extractvalue { [1 x [1 x {} addrspace(10)*]], { [1 x [1 x {} addrspace(10)*]], { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, [2 x i64], {} addrspace(10)* }, { [1 x [1 x {} addrspace(10)*]], { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, [2 x i64], {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, i64 } %0, 0, 0, 0
```
(although they should be exactly as our manual autodiff test above?)
"""
try
test_forward(Duplicated, (problem, Const), ([1.0, 2.0], Duplicated)) do problem, θ
@inline
solve(problem, θ).z
end
catch e
display(e)
end
134 changes: 134 additions & 0 deletions src/AutoDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module AutoDiff

using ..ParametricMCPs: ParametricMCPs, get_problem_size, get_result_buffer, get_parameter_dimension
using ChainRulesCore: ChainRulesCore
using EnzymeCore: EnzymeCore, EnzymeRules
using ForwardDiff: ForwardDiff
using SparseArrays: SparseArrays
using LinearAlgebra: LinearAlgebra
Expand Down Expand Up @@ -45,6 +46,139 @@ function _solve_jacobian_θ(problem, solution, θ; active_tolerance = 1e-3)
∂z∂θ
end

const EnzymeBatchedAnnotation = Union{EnzymeCore.BatchDuplicated,EnzymeCore.BatchDuplicatedNoNeed}
const EnzymeNoneedAnnotation = Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.BatchDuplicatedNoNeed}

function EnzymeRules.forward(

Check warning on line 52 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L52

Added line #L52 was not covered by tests
func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
::Type{ReturnType},
problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
θ::EnzymeCore.Annotation;
kwargs...,
) where {ReturnType<:EnzymeCore.Annotation}
# TODO: Enzyme sometimes passes us the problem as non-const (why?). For now, skip this check.
#if !(problem isa EnzymeCore.Const)
# throw(ArgumentError("""
# `problem` must be annotated `Enzyme.Const`.
# If you did not pass the non-const problem annotation yourself,
# consider filing an issue with ParametricMCPs.jl.
# """))
#end

if θ isa EnzymeCore.Const
throw(

Check warning on line 69 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
ArgumentError(
"""
`θ` was annotated `Enzyme.Const` which defeats the purpose of running AD.
If you did not pass the const θ annotation yourself,
consider filing an issue with ParametricMCPs.jl.
""",
),
)
end

# forward pass
solution_val = func.val(problem.val, θ.val; kwargs...)

Check warning on line 81 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L81

Added line #L81 was not covered by tests

if ReturnType <: EnzymeCore.Const
return solution_val

Check warning on line 84 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L83-L84

Added lines #L83 - L84 were not covered by tests
end

# backward pass
∂z∂θ = _solve_jacobian_θ(problem.val, solution_val, θ.val)

Check warning on line 88 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L88

Added line #L88 was not covered by tests

if ReturnType <: EnzymeBatchedAnnotation
solution_dval = map.dval) do θdval
_dval = deepcopy(solution_val)
_dval.z .= ∂z∂θ * θdval
_dval

Check warning on line 94 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L90-L94

Added lines #L90 - L94 were not covered by tests
end
else
# downstream gradient
dz = ∂z∂θ * θ.dval
solution_dval = deepcopy(solution_val)
solution_dval.z .= dz

Check warning on line 100 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L98-L100

Added lines #L98 - L100 were not covered by tests
end

if ReturnType <: EnzymeNoneedAnnotation
return solution_dval

Check warning on line 104 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end

if ReturnType <: EnzymeCore.Duplicated
return EnzymeCore.Duplicated(solution_val, solution_dval)

Check warning on line 108 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end

if ReturnType <: EnzymeCore.BatchDuplicated
return EnzymeCore.BatchDuplicated(solution_val, solution_dval)

Check warning on line 112 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L111-L112

Added lines #L111 - L112 were not covered by tests
end

throw(ArgumentError("""

Check warning on line 115 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L115

Added line #L115 was not covered by tests
Forward rule for ReturnType with annotation $(ReturnType) not implemented.
Please file an issue with ParametricMCPs.jl.
"""))
end

function EnzymeRules.augmented_primal(

Check warning on line 121 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L121

Added line #L121 was not covered by tests
config::EnzymeRules.ConfigWidth{1},
func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
::Type{<:EnzymeRules.Annotation},
problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
θ::EnzymeCore.Annotation;
kwargs...,
)
function copy_or_reuse(val, idx)
if EnzymeRules.overwritten(config)[idx] && ismutable(val)
return deepcopy(val)

Check warning on line 131 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L129-L131

Added lines #L129 - L131 were not covered by tests
end
val

Check warning on line 133 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L133

Added line #L133 was not covered by tests
end

θval = copy_or_reuse.val, 3)
res = func.val(problem.val, θval; kwargs...)

Check warning on line 137 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L136-L137

Added lines #L136 - L137 were not covered by tests
# backward pass
∂z∂θ_thunk = () -> _solve_jacobian_θ(problem.val, res, θval)

Check warning on line 139 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L139

Added line #L139 was not covered by tests

dres = deepcopy(res)
dres.z .= 0.0

Check warning on line 142 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L141-L142

Added lines #L141 - L142 were not covered by tests

tape = (; ∂z∂θ_thunk, dres)

Check warning on line 144 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L144

Added line #L144 was not covered by tests

EnzymeRules.AugmentedReturn(res, dres, tape)

Check warning on line 146 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L146

Added line #L146 was not covered by tests
end

function EnzymeRules.reverse(

Check warning on line 149 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L149

Added line #L149 was not covered by tests
config,
func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
rt::Type{ReturnType}, # TODO: tighter type constraint
tape,
problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
θ::EnzymeCore.Annotation;
kwargs...,
) where {ReturnType}
println("reverse: $rt")

Check warning on line 158 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L158

Added line #L158 was not covered by tests

if θ isa EnzymeCore.Duplicated
∂z∂θ = tape.∂z∂θ_thunk()
∂l∂z = tape.dres.z
θ.dval .+= ∂z∂θ' * ∂l∂z
elseif !isa EnzymeCore.Const)
throw(ArgumentError("""

Check warning on line 165 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L160-L165

Added lines #L160 - L165 were not covered by tests
Reverse rule for θ with annotation $(typeof(θ)) not implemented.
Please file an issue with ParametricMCPs.jl.
"""))
end

if !(problem isa EnzymeCore.Const)
throw(ArgumentError("""

Check warning on line 172 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L171-L172

Added lines #L171 - L172 were not covered by tests
`problem` must be annotated `Enzyme.Const`.
If you did not pass the non-const problem annotation yourself,
consider filing an issue with ParametricMCPs.jl.
"""))
end

(nothing, nothing)

Check warning on line 179 in src/AutoDiff.jl

View check run for this annotation

Codecov / codecov/patch

src/AutoDiff.jl#L179

Added line #L179 was not covered by tests
end

function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...)
solution = ParametricMCPs.solve(problem, θ; kwargs...)
project_to_θ = ChainRulesCore.ProjectTo(θ)
Expand Down
2 changes: 1 addition & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function solve(

jacobian_linear_elements =
enable_presolve ? jacobian_z!.constant_entries : empty(jacobian_z!.constant_entries)
status, z, info = PATHSolver.solve_mcp(
status, z::Vector, info::PATHSolver.Information = PATHSolver.solve_mcp(
F,
J,
lower_bounds,
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
81 changes: 61 additions & 20 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Random: Random
using LinearAlgebra: norm
using Zygote: Zygote
using FiniteDiff: FiniteDiff
using Enzyme: Enzyme

@testset "ParametricMCPs.jl" begin
rng = Random.MersenneTwister(1)
Expand All @@ -15,9 +16,15 @@ using FiniteDiff: FiniteDiff
lower_bounds = [-Inf, -Inf, 0, 0]
upper_bounds = [Inf, Inf, Inf, Inf]
problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)
problem_no_jacobian = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension; compute_sensitivities = false)
problem_no_jacobian = ParametricMCPs.ParametricMCP(
f,
lower_bounds,
upper_bounds,
parameter_dimension;
compute_sensitivities = false,
)

feasible_parameters = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [rand(rng, 2) for _ in 1:10]...]
feasible_parameters = [[0.0, 0.0], [rand(rng, 2) for _ in 1:4]...]
infeasible_parameters = -feasible_parameters

@testset "forward pass" begin
Expand All @@ -32,27 +39,61 @@ using FiniteDiff: FiniteDiff
end
end

@testset "backward pass" begin
function dummy_pipeline(θ)
solution = ParametricMCPs.solve(problem, θ)
sum(solution.z .^ 2)
end
function dummy_pipeline(problem, θ)
solution = ParametricMCPs.solve(problem, θ)
sum(solution.z .^ 2)
end

@testset "backward pass" begin
for θ in [feasible_parameters; infeasible_parameters]
∇_autodiff_reverse = only(Zygote.gradient(dummy_pipeline, θ))
∇_autodiff_forward = only(Zygote.gradient-> Zygote.forwarddiff(dummy_pipeline, θ), θ))
∇_finitediff = FiniteDiff.finite_difference_gradient(dummy_pipeline, θ)
@test isapprox(∇_autodiff_reverse, ∇_finitediff; atol = 1e-4)
@test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-4)
#Enzyme.jacobian(Enzyme.Reverse, dummy_pipeline, Enzyme.Duplicated([1.0, 1.0], [0.0, 0.0]))
∇_finitediff = FiniteDiff.finite_difference_gradient-> dummy_pipeline(problem, θ), θ)

@testset "Zygote Reverse" begin
∇_zygote_reverse = Zygote.gradient(θ) do θ
dummy_pipeline(problem, θ)
end |> only
@test isapprox(∇_zygote_reverse, ∇_finitediff; atol = 1e-4)
end

@testset "Zygote Forward" begin
∇_zygote_forward = Zygote.gradient(θ) do θ
Zygote.forwarddiff(θ) do θ
dummy_pipeline(problem, θ)
end
end |> only
@test isapprox(∇_zygote_forward, ∇_finitediff; atol = 1e-4)
end

@testset "Enzyme Forward" begin
∇_enzyme_forward =
Enzyme.autodiff(
Enzyme.Forward,
dummy_pipeline,
problem,
Enzyme.BatchDuplicated(θ, Enzyme.onehot(θ)),
) |>
only |>
collect
@test isapprox(∇_enzyme_forward, ∇_finitediff; atol = 1e-4)
end

@testset "Enzyme Reverse" begin
∇_enzyme_reverse = zero(θ)
Enzyme.autodiff(
Enzyme.Reverse,
dummy_pipeline,
problem,
Enzyme.Duplicated(θ, ∇_enzyme_reverse),
)
@test isapprox(∇_enzyme_reverse, ∇_finitediff; atol = 1e-4)
end
end
end
end

@testset "missing jacobian" begin
function dummy_pipeline(θ, problem)
solution = ParametricMCPs.solve(problem, θ)
sum(solution.z .^ 2)
@test_throws ArgumentError Zygote.gradient(feasible_parameters[1]) do θ
dummy_pipeline(problem_no_jacobian, θ)
end

@test_throws ArgumentError Zygote.gradient-> dummy_pipeline(θ, problem_no_jacobian), feasible_parameters[1])
end
end
end

0 comments on commit 7d632fc

Please sign in to comment.