Skip to content

Commit

Permalink
More tests and some safety
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent 7c1f1b2 commit 4198c86
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
6 changes: 1 addition & 5 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
A_ = alias_A ? deepcopy(A) : A
end
else
if alg isa DefaultLinearSolver
A_ = deepcopy(A)
else
A_ = alias_A ? deepcopy(A) : A
end
A_ = deepcopy(A)
end

sol = solve!(cache)
Expand Down
31 changes: 25 additions & 6 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,34 @@ end

dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)

#= Needs ForwardDiff rules
dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
dA2 = FiniteDiff.finite_difference_gradient(
x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = FiniteDiff.finite_difference_gradient(
x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = FiniteDiff.finite_difference_gradient(
x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test dAdA2 atol=5e-5
@test db1 db12
@test db2 db22

function f4(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
prob = LinearProblem(A, b2)
sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES()))
norm(sol1.u .+ sol2.u)
end

dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)

dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test dAdA2 atol=5e-5
@test dAdA2 atol=5e-5
@test db1 db12
@test db2 db22
=#

A = rand(n, n);
b1 = rand(n);
Expand Down

0 comments on commit 4198c86

Please sign in to comment.