Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjoints for Linear Solve #449

Merged
merged 8 commits into from
Feb 25, 2024
Merged

Adjoints for Linear Solve #449

merged 8 commits into from
Feb 25, 2024

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Dec 20, 2023

Fixes #198, Fixes #322

TODOs:

  • Preconditioning. These need to be added to the Adjoint Sensitivity Struct. Can we use a left preconditioner for the forward problem as the transpose right preconditioner? Need someone with a better grasp of linear algebra to chime in here
  • Support the linsolve !== nothing case. This is useful if we know that $A^T$ has a structure exploitable by a different solver
  • Tests -- Take from Enzyme and ChainRules
  • Fix literal_getproperty for LinearSolution SciMLBase.jl#566

Example:

using LinearSolve, Zygote

A = rand(4, 4)
b = rand(4)

test_func_1(A, b) = sum(abs2, A \ b)

test_func_1(A, b)

∂A_1, ∂b_1 = @btime Zygote.gradient(test_func_1, copy(A), copy(b))
display(∂A_1)
display(∂b_1)

function test_func_2(A, b)
    prob = LinearProblem(A, b)
    sol = solve(prob)
    return sum(abs2, sol.u)
end

test_func_2(A, b)

∂A_2, ∂b_2 = @btime Zygote.gradient(test_func_2, copy(A), copy(b))
display(∂A_2)
display(∂b_2)

In the following case the cache stores the correct gradients but they are not propagated to A and b. @ChrisRackauckas any idea how to fix this?

cache = init(LinearProblem(copy(A), copy(b)), nothing);
function test_func_3(cache, A, b)
    cache.A = A
    cache.b = b
    sol = solve!(cache)
    return sum(abs2, sol.u)
end

test_func_3(cache, copy(A), copy(b))

∂cache, ∂A_3, ∂b_3 = @btime Zygote.gradient(test_func_3, cache, copy(A), copy(b))
∂cache.A
∂cache.b
display(∂A_3)  # nothing
display(∂b_3)  # nothing

Copy link

codecov bot commented Dec 20, 2023

Codecov Report

Attention: Patch coverage is 6.66667% with 42 lines in your changes are missing coverage. Please review.

Project coverage is 22.96%. Comparing base (a206054) to head (06c09a3).

❗ Current head 06c09a3 differs from pull request most recent head 7671369. Consider uploading reports for the commit 7671369 to get more accurate results

Files Patch % Lines
src/adjoint.jl 2.32% 42 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #449       +/-   ##
===========================================
- Coverage   66.12%   22.96%   -43.17%     
===========================================
  Files          27       28        +1     
  Lines        2146     2147        +1     
===========================================
- Hits         1419      493      -926     
- Misses        727     1654      +927     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

src/adjoint.jl Outdated
# Forward Solve
sol = solve!(cache, alg, args...; kwargs...)

function ∇solve!(∂sol)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we technically have to deepcopy in here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess so, it can be problematic if there are 2 subsequent solve calls on the cache.

@ChrisRackauckas
Copy link
Member

In the following case the cache stores the correct gradients but they are not propagated to A and b. @ChrisRackauckas any idea how to fix this?

Is this not just an inherent limitation of Zygote with mutation? I would presume we just need to stay away from that and only support solve with CRC.

@avik-pal avik-pal force-pushed the ap/adjoint branch 2 times, most recently from 4198c86 to 2493dca Compare February 24, 2024 21:42
@ChrisRackauckas ChrisRackauckas merged commit 7b090b4 into main Feb 25, 2024
10 of 16 checks passed
@ChrisRackauckas ChrisRackauckas deleted the ap/adjoint branch February 25, 2024 17:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Dispatch on __solve instead of solve adjoint support
2 participants