Skip to content

Commit

Permalink
More tests and some safety
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent 7c1f1b2 commit 6432716
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ StaticArrays = "1.5"
StaticArraysCore = "1.4.2"
Test = "1"
UnPack = "1"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Expand Down Expand Up @@ -137,6 +138,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
2 changes: 1 addition & 1 deletion ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions), typeof(sensealg))
typeof(__issquare(assumptions)), typeof(sensealg)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters, verbose, assumptions, sensealg)
return cache
Expand Down
6 changes: 1 addition & 5 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
A_ = alias_A ? deepcopy(A) : A
end
else
if alg isa DefaultLinearSolver
A_ = deepcopy(A)
else
A_ = alias_A ? deepcopy(A) : A
end
A_ = deepcopy(A)
end

sol = solve!(cache)
Expand Down
31 changes: 25 additions & 6 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,34 @@ end

dA, db1, db2 = Zygote.gradient(f3, A, b1, b1)

#= Needs ForwardDiff rules
dA2 = ForwardDiff.gradient(x -> f3(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f3(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f3(eltype(x).(A), eltype(x).(b1), x), copy(b1))
dA2 = FiniteDiff.finite_difference_gradient(
x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = FiniteDiff.finite_difference_gradient(
x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = FiniteDiff.finite_difference_gradient(
x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test dAdA2 atol=5e-5
@test db1 db12
@test db2 db22

function f4(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))
prob = LinearProblem(A, b2)
sol2 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_GMRES()))
norm(sol1.u .+ sol2.u)
end

dA, db1, db2 = Zygote.gradient(f4, A, b1, b1)

dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b1)), copy(b1))
db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b1))

@test dAdA2 atol=5e-5
@test dAdA2 atol=5e-5
@test db1 db12
@test db2 db22
=#

A = rand(n, n);
b1 = rand(n);
Expand Down

0 comments on commit 6432716

Please sign in to comment.