From 4198c867b9e66bce02135ef4b66df41172d1ae18 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 24 Feb 2024 16:28:14 -0500
Subject: [PATCH] More tests and some safety

src/adjoint.jl: in the `else` branch of the copy decision inside the
`CRC.rrule` for `SciMLBase.solve`, always keep `A_ = deepcopy(A)` instead
of special-casing `DefaultLinearSolver` and otherwise consulting `alias_A`.

test/adjoint.jl: re-enable the gradient checks that were commented out
("Needs ForwardDiff rules"):
  * the Zygote gradient of `f3` is now compared against a FiniteDiff
    gradient of the same function `f3`;
  * a new `f4` (LUFactorization forward solve, with KrylovJL_LSMR and
    KrylovJL_GMRES passed as `LinearSolveAdjoint(; linsolve = ...)`) is
    added, and its Zygote gradient is compared against ForwardDiff.

---
Review note: the FiniteDiff closures reference `f3`, not `f4`. `f4` is only
defined further down the file, and the values being checked (`dA`, `db1`,
`db2`) were produced by `Zygote.gradient(f3, ...)`. Only the content of three
added lines changed, so hunk sizes and the diffstat are unaffected; the
post-image blob id on the test/adjoint.jl `index` line is the one from the
original patch.

 src/adjoint.jl  |  6 +-----
 test/adjoint.jl | 31 +++++++++++++++++++++++++------
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/src/adjoint.jl b/src/adjoint.jl
index 3d46d8048..550bb2bd6 100644
--- a/src/adjoint.jl
+++ b/src/adjoint.jl
@@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
             A_ = alias_A ? deepcopy(A) : A
         end
     else
-        if alg isa DefaultLinearSolver
-            A_ = deepcopy(A)
-        else
-            A_ = alias_A ? deepcopy(A) : A
-        end
+        A_ = deepcopy(A)
     end
 
     sol = solve!(cache)
diff --git a/test/adjoint.jl b/test/adjoint.jl
index ecc9714eb..26a72016f 100644
--- a/test/adjoint.jl
+++ b/test/adjoint.jl
@@ -51,15 +51,34 @@ end
 
 dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)
 
-#= Needs ForwardDiff rules
-dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
-db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
-db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
+dA2 = FiniteDiff.finite_difference_gradient(
+    x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
+db12 = FiniteDiff.finite_difference_gradient(
+    x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
+db22 = FiniteDiff.finite_difference_gradient(
+    x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
+
+@test dA≈dA2 atol=5e-5
+@test db1 ≈ db12
+@test db2 ≈ db22
+
+function f4(A, b1, b2; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+    sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
+    prob = LinearProblem(A, b2)
+    sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES()))
+    norm(sol1.u .+ sol2.u)
+end
+
+dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)
+
+dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
+db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))
 
-@test dA ≈ dA2 atol=5e-5
+@test dA≈dA2 atol=5e-5
 @test db1 ≈ db12
 @test db2 ≈ db22
-=#
 
 A = rand(n, n);
 b1 = rand(n);