Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tests for Enzyme frontend for nonlinear problem #1066

Merged
merged 21 commits into from
Jul 21, 2024
Merged
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions test/steady_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,14 +427,39 @@ end
test_loss(p, prob4, alg = DynamicSS(Rodas5()))
test_loss(p, prob2, alg = SimpleNewtonRaphson())

function enzyme_gradient(p, prob; alg = NewtonRaphson())
dp = Enzyme.make_zero(p)
dprob = Enzyme.make_zero(prob)
Enzyme.autodiff(Reverse, (p, prob) -> test_loss(p, prob, alg=alg), Active, Duplicated(p,dp), Duplicated(prob,dprob))
m-bossart marked this conversation as resolved.
Show resolved Hide resolved
return dp
end
dp1 = Zygote.gradient(p -> test_loss(p, prob), p)[1]
dp1_enzyme = enzyme_gradient(p, prob)
@test dp1≈dp1_enzyme rtol=1e-10
dp2 = Zygote.gradient(p -> test_loss(p, prob2), p)[1]
dp2_enzyme = enzyme_gradient(p, prob2)
@test dp2≈dp2_enzyme rtol=1e-10
dp3 = Zygote.gradient(p -> test_loss(p, prob3, alg = DynamicSS(Rodas5())), p)[1]
#dp3_enzyme = enzyme_gradient(p, prob3, alg = DynamicSS(Rodas5()))
#@test dp3≈dp3_enzyme rtol=1e-10
dp4 = Zygote.gradient(p -> test_loss(p, prob4, alg = DynamicSS(Rodas5())), p)[1]
dp5 = Zygote.gradient(p -> test_loss(p, prob2, alg = SimpleNewtonRaphson()), p)[1]
#dp4_enzyme = enzyme_gradient(p, prob4, alg = DynamicSS(Rodas5()))
#@test dp4≈dp4_enzyme rtol=1e-10
dp5 = Zygote.gradient(p -> test_loss(p, prob2, alg = SimpleNewtonRaphson()), p)[1] #need new version? Doesn't hit Enzyme rule... ?
#dp5_enzyme = enzyme_gradient(p, prob2, alg = SimpleNewtonRaphson())
#@test dp5≈dp5_enzyme rtol=1e-10
dp6 = Zygote.gradient(p -> test_loss(p, prob2, alg = Klement()), p)[1]
dp7 = Zygote.gradient(p -> test_loss(p, prob2, alg = SimpleTrustRegion()), p)[1]
dp6_enzyme = enzyme_gradient(p, prob2, alg = Klement())
@test dp6≈dp6_enzyme rtol=1e-10
dp7 = Zygote.gradient(p -> test_loss(p, prob2, alg = SimpleTrustRegion()), p)[1] #need new version? Doesn't hit Enzyme rule... ?
#dp7_enzyme = enzyme_gradient(p, prob2, alg = SimpleTrustRegion())
#@test dp7≈dp7_enzyme rtol=1e-10
dp8 = Zygote.gradient(p -> test_loss(p, prob2, alg = NLsolveJL()), p)[1]
#dp8_enzyme = enzyme_gradient(p, prob2, alg = NLsolveJL())
#@test dp8≈dp8_enzyme rtol=1e-10
dp9 = Zygote.gradient(p -> test_loss(p, prob, alg = TrustRegion()), p)[1]
dp9_enzyme = enzyme_gradient(p, prob, alg = TrustRegion())
@test dp9≈dp9_enzyme rtol=1e-10

@test dp1≈dp2 rtol=1e-10
@test dp1≈dp3 rtol=1e-10
Expand All @@ -443,6 +468,7 @@ end
@test dp1≈dp6 rtol=1e-10
@test dp1≈dp7 rtol=1e-10
@test dp1≈dp8 rtol=1e-10
@test dp1≈dp9 rtol=1e-10

# Larger Batched Problem: For testing the Iterative Solvers Path
u0 = zeros(128)
Expand All @@ -461,9 +487,11 @@ end
test_loss(p, prob)

dp1 = Zygote.gradient(p -> test_loss(p, prob), p)[1]

dp1_enzyme = enzyme_gradient(p, prob)
@test dp1[1] ≈ 128
@test dp1_enzyme[1] ≈ 128
@test dp1[2] ≈ -128
@test dp1_enzyme[2] ≈ -128
end

@testset "Continuous sensitivity tools" begin
Expand Down