Skip to content

Commit

Permalink
fix: jacobian caching
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent 4a74ae0 commit 0b1ce3a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
25 changes: 18 additions & 7 deletions lib/NonlinearSolveBase/src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,16 @@ end
## Numbers
function (cache::JacobianCache{<:Number})(::Number, u, p = cache.p)
cache.stats.njacs += 1
SciMLBase.has_jac(cache.f) && return cache.f.jac(u, p)
SciMLBase.has_vjp(cache.f) && return cache.f.vjp(one(u), u, p)
SciMLBase.has_jvp(cache.f) && return cache.f.jvp(one(u), u, p)
return DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
cache.J = if SciMLBase.has_jac(cache.f)
cache.f.jac(u, p)
elseif SciMLBase.has_vjp(cache.f)
cache.f.vjp(one(u), u, p)
elseif SciMLBase.has_jvp(cache.f)
cache.f.jvp(one(u), u, p)
else
DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
end
return cache.J
end

## Actually Compute the Jacobian
Expand All @@ -156,12 +162,17 @@ function (cache::JacobianCache)(J::Union{AbstractMatrix, Nothing}, u, p = cache.
cache.f.jac(J, u, p)
else
DI.jacobian!(
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p))
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)
)
end
return J
else
SciMLBase.has_jac(cache.f) && return cache.f.jac(u, p)
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
if SciMLBase.has_jac(cache.f)
cache.J = cache.f.jac(u, p)
else
cache.J = DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
end
return cache.J
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
dataOut = f([1, 2, 3], nothing) + 0.1 * randn(10, 1)

resid(x, p) = f(x, p) - dataOut
jac(x, p) = [dataIn .^ 2 dataIn ones(10, 1)]
jac(x, p) = [1:10 .^ 2 1:10 ones(10, 1)]
x0 = [1, 1, 1]

prob = NonlinearLeastSquaresProblem(resid, x0)
Expand Down

0 comments on commit 0b1ce3a

Please sign in to comment.