diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 23a9465..47435fa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -6,11 +6,11 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        julia_version: ["1.7", "1"]
+        julia_version: ["1.10.0-rc1"]
     name: julia ${{ matrix.julia_version }}
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@v1
         with:
           version: ${{ matrix.julia_version }}
diff --git a/Project.toml b/Project.toml
index b741957..bf4ce1e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.1.8"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
diff --git a/enzyme_tests/Project.toml b/enzyme_tests/Project.toml
new file mode 100644
index 0000000..20eea8f
--- /dev/null
+++ b/enzyme_tests/Project.toml
@@ -0,0 +1,4 @@
+[deps]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
+ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
diff --git a/enzyme_tests/main.jl b/enzyme_tests/main.jl
new file mode 100644
index 0000000..7ee20f6
--- /dev/null
+++ b/enzyme_tests/main.jl
@@ -0,0 +1,48 @@
+using Enzyme, EnzymeTestUtils, ParametricMCPs
+
+parameter_dimension = 2
+f(z, θ) = [2z[1:2] - z[3:4] - 2θ; z[1:2]]
+lower_bounds = [-Inf, -Inf, 0, 0]
+upper_bounds = [Inf, Inf, Inf, Inf]
+problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)
+
+@info "I can trigger the rule manually just fine:"
+dsol = Enzyme.autodiff(
+    Forward,
+    ParametricMCPs.solve,
+    Const(problem),
+    Duplicated([1.0, 2.0], [1.0, 0.0]),
+)
+@show dsol
+
+@info """
+Now testing with EnzymeTestUtils.
+
+This fails because `FiniteDifferences.jl` cannot flatten the output struct with `to_vec`.
+"""
+try
+    test_forward(solve, Duplicated, (problem, Const), ([1.0, 2.0], Duplicated))
+catch e
+    display(e)
+end
+
+@info """
+To circumvent the issue above, we now unpack the relevant fields of the output struct for differentiation.
+This fails because Enzyme thinks that the activities don't match here:
+
+```
+Enzyme execution failed.
+Mismatched activity for: store {} addrspace(10)* %.fca.0.0.0.0.extract, {} addrspace(10)** %.fca.0.0.0.0.gep, align 8, !tbaa !113, !alias.scope !117, !noalias !120
+const val: %.fca.0.0.0.0.extract = extractvalue { [1 x [1 x {} addrspace(10)*]], { [1 x [1 x {} addrspace(10)*]], { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, [2 x i64], {} addrspace(10)* }, { [1 x [1 x {} addrspace(10)*]], { i64, i64, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, [2 x i64], {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, i64 } %0, 0, 0, 0
+```
+
+(although they should be exactly the same as in our manual autodiff test above?)
+""" +try + test_forward(Duplicated, (problem, Const), ([1.0, 2.0], Duplicated)) do problem, θ + @inline + solve(problem, θ).z + end +catch e + display(e) +end diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl index b024c44..562c202 100644 --- a/src/AutoDiff.jl +++ b/src/AutoDiff.jl @@ -2,6 +2,7 @@ module AutoDiff using ..ParametricMCPs: ParametricMCPs, get_problem_size, get_result_buffer, get_parameter_dimension using ChainRulesCore: ChainRulesCore +using EnzymeCore: EnzymeCore, EnzymeRules using ForwardDiff: ForwardDiff using SparseArrays: SparseArrays using LinearAlgebra: LinearAlgebra @@ -45,6 +46,139 @@ function _solve_jacobian_θ(problem, solution, θ; active_tolerance = 1e-3) ∂z∂θ end +const EnzymeBatchedAnnotation = Union{EnzymeCore.BatchDuplicated,EnzymeCore.BatchDuplicatedNoNeed} +const EnzymeNoneedAnnotation = Union{EnzymeCore.DuplicatedNoNeed,EnzymeCore.BatchDuplicatedNoNeed} + +function EnzymeRules.forward( + func::EnzymeCore.Const{typeof(ParametricMCPs.solve)}, + ::Type{ReturnType}, + problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP}, + θ::EnzymeCore.Annotation; + kwargs..., +) where {ReturnType<:EnzymeCore.Annotation} + # TODO: Enzyme sometimes passes us the problem as non-const (why?). For now, skip this check. + #if !(problem isa EnzymeCore.Const) + # throw(ArgumentError(""" + # `problem` must be annotated `Enzyme.Const`. + # If you did not pass the non-const problem annotation yourself, + # consider filing an issue with ParametricMCPs.jl. + # """)) + #end + + if θ isa EnzymeCore.Const + throw( + ArgumentError( + """ + `θ` was annotated `Enzyme.Const` which defeats the purpose of running AD. + If you did not pass the const θ annotation yourself, + consider filing an issue with ParametricMCPs.jl. + """, + ), + ) + end + + # forward pass + solution_val = func.val(problem.val, θ.val; kwargs...) + + if ReturnType <: EnzymeCore.Const + return solution_val + end + + # backward pass + ∂z∂θ = _solve_jacobian_θ(problem.val, solution_val, θ.val) + + if ReturnType <: EnzymeBatchedAnnotation + solution_dval = map(θ.dval) do θdval + _dval = deepcopy(solution_val) + _dval.z .= ∂z∂θ * θdval + _dval + end + else + # downstream gradient + dz = ∂z∂θ * θ.dval + solution_dval = deepcopy(solution_val) + solution_dval.z .= dz + end + + if ReturnType <: EnzymeNoneedAnnotation + return solution_dval + end + + if ReturnType <: EnzymeCore.Duplicated + return EnzymeCore.Duplicated(solution_val, solution_dval) + end + + if ReturnType <: EnzymeCore.BatchDuplicated + return EnzymeCore.BatchDuplicated(solution_val, solution_dval) + end + + throw(ArgumentError(""" + Forward rule for ReturnType with annotation $(ReturnType) not implemented. + Please file an issue with ParametricMCPs.jl. + """)) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + func::EnzymeCore.Const{typeof(ParametricMCPs.solve)}, + ::Type{<:EnzymeRules.Annotation}, + problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP}, + θ::EnzymeCore.Annotation; + kwargs..., +) + function copy_or_reuse(val, idx) + if EnzymeRules.overwritten(config)[idx] && ismutable(val) + return deepcopy(val) + end + val + end + + θval = copy_or_reuse(θ.val, 3) + res = func.val(problem.val, θval; kwargs...) 
+
+    # backward pass
+    ∂z∂θ_thunk = () -> _solve_jacobian_θ(problem.val, res, θval)
+
+    dres = deepcopy(res)
+    dres.z .= 0.0
+
+    tape = (; ∂z∂θ_thunk, dres)
+
+    EnzymeRules.AugmentedReturn(res, dres, tape)
+end
+
+function EnzymeRules.reverse(
+    config,
+    func::EnzymeCore.Const{typeof(ParametricMCPs.solve)},
+    rt::Type{ReturnType}, # TODO: tighter type constraint
+    tape,
+    problem::EnzymeCore.Annotation{<:ParametricMCPs.ParametricMCP},
+    θ::EnzymeCore.Annotation;
+    kwargs...,
+) where {ReturnType}
+    println("reverse: $rt")
+
+    if θ isa EnzymeCore.Duplicated
+        ∂z∂θ = tape.∂z∂θ_thunk()
+        ∂l∂z = tape.dres.z
+        θ.dval .+= ∂z∂θ' * ∂l∂z
+    elseif !(θ isa EnzymeCore.Const)
+        throw(ArgumentError("""
+            Reverse rule for θ with annotation $(typeof(θ)) not implemented.
+            Please file an issue with ParametricMCPs.jl.
+            """))
+    end
+
+    if !(problem isa EnzymeCore.Const)
+        throw(ArgumentError("""
+            `problem` must be annotated `Enzyme.Const`.
+            If you did not pass the non-const problem annotation yourself,
+            consider filing an issue with ParametricMCPs.jl.
+            """))
+    end
+
+    (nothing, nothing)
+end
+
 function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...)
     solution = ParametricMCPs.solve(problem, θ; kwargs...)
     project_to_θ = ChainRulesCore.ProjectTo(θ)
diff --git a/src/solver.jl b/src/solver.jl
index 60c0e53..6362b38 100644
--- a/src/solver.jl
+++ b/src/solver.jl
@@ -28,7 +28,7 @@ function solve(
     jacobian_linear_elements =
         enable_presolve ? jacobian_z!.constant_entries : empty(jacobian_z!.constant_entries)
 
-    status, z, info = PATHSolver.solve_mcp(
+    status, z::Vector, info::PATHSolver.Information = PATHSolver.solve_mcp(
         F,
         J,
         lower_bounds,
diff --git a/test/Project.toml b/test/Project.toml
index 47b24a1..96ade20 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/test/runtests.jl b/test/runtests.jl
index e232491..65c7da0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,6 +4,7 @@ using Random: Random
 using LinearAlgebra: norm
 using Zygote: Zygote
 using FiniteDiff: FiniteDiff
+using Enzyme: Enzyme
 
 @testset "ParametricMCPs.jl" begin
     rng = Random.MersenneTwister(1)
@@ -15,9 +16,15 @@ using FiniteDiff: FiniteDiff
     lower_bounds = [-Inf, -Inf, 0, 0]
     upper_bounds = [Inf, Inf, Inf, Inf]
     problem = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension)
-    problem_no_jacobian = ParametricMCPs.ParametricMCP(f, lower_bounds, upper_bounds, parameter_dimension; compute_sensitivities = false)
+    problem_no_jacobian = ParametricMCPs.ParametricMCP(
+        f,
+        lower_bounds,
+        upper_bounds,
+        parameter_dimension;
+        compute_sensitivities = false,
+    )
 
-    feasible_parameters = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [rand(rng, 2) for _ in 1:10]...]
+    feasible_parameters = [[0.0, 0.0], [rand(rng, 2) for _ in 1:4]...]
     infeasible_parameters = -feasible_parameters
 
     @testset "forward pass" begin
@@ -32,27 +39,61 @@ using FiniteDiff: FiniteDiff
         end
     end
 
-    @testset "backward pass" begin
-        function dummy_pipeline(θ)
-            solution = ParametricMCPs.solve(problem, θ)
-            sum(solution.z .^ 2)
-        end
+    function dummy_pipeline(problem, θ)
+        solution = ParametricMCPs.solve(problem, θ)
+        sum(solution.z .^ 2)
+    end
 
+    @testset "backward pass" begin
         for θ in [feasible_parameters; infeasible_parameters]
-            ∇_autodiff_reverse = only(Zygote.gradient(dummy_pipeline, θ))
-            ∇_autodiff_forward = only(Zygote.gradient(θ -> Zygote.forwarddiff(dummy_pipeline, θ), θ))
-            ∇_finitediff = FiniteDiff.finite_difference_gradient(dummy_pipeline, θ)
-            @test isapprox(∇_autodiff_reverse, ∇_finitediff; atol = 1e-4)
-            @test isapprox(∇_autodiff_reverse, ∇_autodiff_forward; atol = 1e-4)
+            #Enzyme.jacobian(Enzyme.Reverse, dummy_pipeline, Enzyme.Duplicated([1.0, 1.0], [0.0, 0.0]))
+            ∇_finitediff = FiniteDiff.finite_difference_gradient(θ -> dummy_pipeline(problem, θ), θ)
+
+            @testset "Zygote Reverse" begin
+                ∇_zygote_reverse = Zygote.gradient(θ) do θ
+                    dummy_pipeline(problem, θ)
+                end |> only
+                @test isapprox(∇_zygote_reverse, ∇_finitediff; atol = 1e-4)
+            end
+
+            @testset "Zygote Forward" begin
+                ∇_zygote_forward = Zygote.gradient(θ) do θ
+                    Zygote.forwarddiff(θ) do θ
+                        dummy_pipeline(problem, θ)
+                    end
+                end |> only
+                @test isapprox(∇_zygote_forward, ∇_finitediff; atol = 1e-4)
+            end
+
+            @testset "Enzyme Forward" begin
+                ∇_enzyme_forward =
+                    Enzyme.autodiff(
+                        Enzyme.Forward,
+                        dummy_pipeline,
+                        problem,
+                        Enzyme.BatchDuplicated(θ, Enzyme.onehot(θ)),
+                    ) |>
+                    only |>
+                    collect
+                @test isapprox(∇_enzyme_forward, ∇_finitediff; atol = 1e-4)
+            end
+
+            @testset "Enzyme Reverse" begin
+                ∇_enzyme_reverse = zero(θ)
+                Enzyme.autodiff(
+                    Enzyme.Reverse,
+                    dummy_pipeline,
+                    problem,
+                    Enzyme.Duplicated(θ, ∇_enzyme_reverse),
+                )
+                @test isapprox(∇_enzyme_reverse, ∇_finitediff; atol = 1e-4)
+            end
         end
-    end
-
+    end
+
     @testset "missing jacobian" begin
-        function dummy_pipeline(θ, problem)
-            solution = ParametricMCPs.solve(problem, θ)
-            sum(solution.z .^ 2)
+        @test_throws ArgumentError Zygote.gradient(feasible_parameters[1]) do θ
+            dummy_pipeline(problem_no_jacobian, θ)
         end
-
-        @test_throws ArgumentError Zygote.gradient(θ -> dummy_pipeline(θ, problem_no_jacobian), feasible_parameters[1])
-    end
+    end
 end
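
For reference, a minimal end-to-end sketch of how the new Enzyme rules are meant to be exercised outside the test suite. It mirrors `dummy_pipeline` and the problem setup from `test/runtests.jl`; the name `pipeline` is only illustrative and not part of the package API, and the sketch assumes PATHSolver can solve the MCP on the machine where it runs.

```julia
# Sketch only: mirrors the Enzyme test cases above.
using Enzyme, ParametricMCPs

parameter_dimension = 2
f(z, θ) = [2z[1:2] - z[3:4] - 2θ; z[1:2]]
problem = ParametricMCPs.ParametricMCP(
    f, [-Inf, -Inf, 0, 0], [Inf, Inf, Inf, Inf], parameter_dimension
)

# Scalar loss of the MCP solution, differentiated with respect to the parameter θ.
pipeline(problem, θ) = sum(ParametricMCPs.solve(problem, θ).z .^ 2)

θ = [1.0, 1.0]

# Forward mode: seed all parameter directions at once via BatchDuplicated + onehot;
# this dispatches to the new `EnzymeRules.forward` rule for `ParametricMCPs.solve`.
∇_forward =
    Enzyme.autodiff(
        Enzyme.Forward,
        pipeline,
        problem,
        Enzyme.BatchDuplicated(θ, Enzyme.onehot(θ)),
    ) |> only |> collect

# Reverse mode: the gradient is accumulated into the shadow of the Duplicated input,
# exercising `EnzymeRules.augmented_primal` and `EnzymeRules.reverse`.
∇_reverse = zero(θ)
Enzyme.autodiff(Enzyme.Reverse, pipeline, problem, Enzyme.Duplicated(θ, ∇_reverse))
```

As in the tests, derivatives are only propagated through θ; the problem itself carries no derivative information through these rules.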