Fix VecJac #245

Merged
merged 9 commits into from
Jun 2, 2023

Conversation

vpuri3
Contributor

@vpuri3 vpuri3 commented May 30, 2023

merge #244 first

This PR fixes several bugs in the implementation of vector-Jacobian products. In doing so, I have separated the VecJac implementations for AutoFiniteDiff and AutoZygote; the latter now lives in the Zygote extension. A concern raised in SciML/SciMLSensitivity.jl#808 about the Zygote VJP implementation was that the underlying functions, auto_vecjac(!), recomputed the pullback function every time the FunctionOperator was applied via * or mul!. With this change, the pullback is recomputed only when update_coefficients(!) is called with the keyword argument VJP_input, or when the operator itself is called with that keyword.

L = VecJac(f, x1, p, t; autodiff = AutoZygote()) # pullback is computed at x1

# df/dx1' * v. pullback is not recomputed in the calls below
L * v
L(v, p, t)
mul!(w, L, v)
L(w, v, p, t)

L = update_coefficients(L, v, p, t) # pullback is not recomputed
update_coefficients!(L, v, p, t)    # pullback is not recomputed

update_coefficients!(L, w, p, t; VJP_input = x2) # pullback is recomputed at x2: Zygote.pullback(L.f, x2)

# df/dx2' * v. pullback is not recomputed in the calls below
L * v
L(v, p, t)
mul!(w, L, v)
L(w, v, p, t)

# df/dx2' * v. pullback is recomputed at x2: Zygote.pullback(L.f, x2)
L(v, p, t; VJP_input = x2)
L(w, v, p, t; VJP_input = x2)

Tests are added in test/test_vecjac_products.jl to demonstrate this behaviour.
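
To make the caching behaviour concrete without depending on this package, here is a minimal, dependency-free sketch. It is not the SparseDiffTools implementation: `CachedVJP`, `relinearize!`, `make_pullback`, `square_pullback`, and `nbuilds` are hypothetical names invented for illustration. The only assumption carried over from the PR is that a pullback builder returns `(y, back)`, the same shape as `Zygote.pullback`, so `Zygote.pullback` could presumably be passed in place of the hand-written one.

```julia
# Hypothetical sketch of "build the pullback once, reuse it for every product".
# None of these names come from SparseDiffTools.
mutable struct CachedVJP
    f               # function being differentiated
    make_pullback   # (f, x) -> (y, back); same shape as Zygote.pullback
    back            # cached pullback closure, built at the last linearization point
    nbuilds::Int    # how many times the pullback has been (re)built
end

# Build the pullback once, at the initial point x.
function CachedVJP(f, make_pullback, x)
    y, back = make_pullback(f, x)
    return CachedVJP(f, make_pullback, back, 1)
end

# Applying the operator only calls the cached closure; nothing is rebuilt.
Base.:*(L::CachedVJP, v::AbstractVector) = first(L.back(v))

# Explicitly move the linearization point (the role VJP_input plays in the PR).
function relinearize!(L::CachedVJP, x)
    y, back = L.make_pullback(L.f, x)
    L.back = back
    L.nbuilds += 1
    return L
end

# Demo function with a known Jacobian: f(x) = x.^2 has J = Diagonal(2x),
# so J' * v = 2 .* x .* v. This hand-written pullback stands in for an AD backend.
f(x) = x .^ 2
square_pullback(f, x) = (f(x), v -> (2 .* x .* v,))

x1 = [1.0, 2.0, 3.0]
x2 = [4.0, 5.0, 6.0]
v  = [1.0, 1.0, 1.0]

L = CachedVJP(f, square_pullback, x1)
w1 = L * v           # [2.0, 4.0, 6.0], pullback built at x1
w1_again = L * v     # same result, pullback not rebuilt

relinearize!(L, x2)  # pullback rebuilt at x2
w2 = L * v           # [8.0, 10.0, 12.0]

println((w1, w2, L.nbuilds))
```

After the two products at `x1`, one relinearization, and one product at `x2`, `L.nbuilds` is 2: one build in the constructor and one in `relinearize!`.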

@codecov

codecov bot commented May 30, 2023

Codecov Report

Patch coverage: 87.30% and project coverage change: +0.62 🎉

Comparison is base (e4b7122) 84.95% compared to head (6f1e267) 85.58%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #245      +/-   ##
==========================================
+ Coverage   84.95%   85.58%   +0.62%     
==========================================
  Files          14       14              
  Lines         964      992      +28     
==========================================
+ Hits          819      849      +30     
+ Misses        145      143       -2     
Impacted Files                            Coverage Δ
src/SparseDiffTools.jl                    75.00% <ø> (ø)
src/differentiation/jaches_products.jl    95.53% <0.00%> (ø)
src/differentiation/vecjac_products.jl    93.75% <87.87%> (+3.75%) ⬆️
ext/SparseDiffToolsZygote.jl              96.66% <92.85%> (-3.34%) ⬇️

... and 1 file with indirect coverage changes


@vpuri3
Contributor Author

vpuri3 commented May 30, 2023

To avoid confusion between u and v in df/du * v, it would be better to rename u --> VJP_in and v --> VJP_mult, and to pass VJP_in as a kwarg to update_coefficients. That way L(v, p, t) will be the same as L * v.
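
For reference, the two roles can be written out explicitly. The notation is an assumption for illustration: $u$ is the point at which $f$ is linearized (the proposed VJP_in), $v$ is the vector being multiplied (the proposed VJP_mult), and the transpose follows the df/dx' * v convention used in the PR description.

$$
L_u \, v = \left( \frac{\partial f}{\partial u}(u, p, t) \right)^{\top} v
$$

Under this reading, $u$ only selects which operator $L_u$ is in play, while $v$ is the operator's input, which is why L(v, p, t) and L * v can agree once $u$ is supplied separately as a keyword argument.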

@vpuri3 vpuri3 changed the title from "[WIP] Fix VecJac" to "Fix VecJac" May 31, 2023
@vpuri3
Contributor Author

vpuri3 commented May 31, 2023

This is good to go. @ChrisRackauckas please take a look

@vpuri3
Contributor Author

vpuri3 commented May 31, 2023

rerunning CI because of a flaky ODE.jl "matrix contains infs/nans" error

@vpuri3 vpuri3 closed this May 31, 2023
@vpuri3 vpuri3 reopened this May 31, 2023
@ChrisRackauckas
Member

@avik-pal can you take a look? I know you looked at the potential vecjac usage in SciMLSensitivity, and I think we should be trying to get it in there.

@vpuri3
Contributor Author

vpuri3 commented Jun 2, 2023

@avik-pal ping

Contributor

@avik-pal avik-pal left a comment

This API looks great. We should be able to use it directly in SciMLSensitivity.jl after this is merged!

3 participants