Skip to content

Commit

Permalink
Add rules for division by Cholesky (#631)
Browse files Browse the repository at this point in the history
* Add rules for division by Cholesky

* Add tests for new rules

* Increment minor version number

* Increment minor version number
  • Loading branch information
sethaxen authored Aug 3, 2022
1 parent 36f3ce7 commit 8587a07
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.40.0"
version = "1.41.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
35 changes: 35 additions & 0 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,38 @@ function _x_divide_conj_y(x, y)
z = x / conj(y)
return iszero(x) ? zero(z) : z
end

# these rules exists because the primals mutates using `ldiv!` and `rdiv!`
function rrule(::typeof(\), A::Cholesky, B::AbstractVecOrMat{<:Union{Real,Complex}})
U, getproperty_back = rrule(getproperty, A, :U)
Z = U' \ B
Y = U \ Z
project_B = ProjectTo(B)
function ldiv_Cholesky_AbsVecOrMat_pullback(ΔY)
∂Z = U' \ ΔY
∂B = U \ ∂Z
∂A = Thunk() do
_, Ā = getproperty_back(-add!!(∂Z * Y', Z * ∂B'))
return
end
return NoTangent(), ∂A, project_B(∂B)
end
return Y, ldiv_Cholesky_AbsVecOrMat_pullback
end

function rrule(::typeof(/), B::AbstractMatrix{<:Union{Real,Complex}}, A::Cholesky)
U, getproperty_back = rrule(getproperty, A, :U)
Z = B / U
Y = Z / U'
project_B = ProjectTo(B)
function rdiv_AbstractMatrix_Cholesky_pullback(ΔY)
∂Z = ΔY / U
∂B = ∂Z / U'
∂A = Thunk() do
_, Ā = getproperty_back(-add!!(∂Z' * Y, Z' * ∂B))
return
end
return NoTangent(), project_B(∂B), ∂A
end
return Y, rdiv_AbstractMatrix_Cholesky_pullback
end
24 changes: 24 additions & 0 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,5 +521,29 @@ end
@test ΔX.factors isa Diagonal && all(iszero, ΔX.factors)
end
end

@testset "\\(::Cholesky, ::AbstractVecOrMat)" begin
n = 10
for T in (Float64, ComplexF64), sz in (n, (n, 5))
A = generate_well_conditioned_matrix(T, n)
C = cholesky(A)
B = randn(T, sz)
# because the rule calls the rrule for getproperty, its rrule is not
# completely type-inferrable
test_rrule(\, C, B; check_inferred=false)
end
end

@testset "/(::AbstractMatrix, ::Cholesky)" begin
n = 10
for T in (Float64, ComplexF64)
A = generate_well_conditioned_matrix(T, n)
C = cholesky(A)
B = randn(T, 5, n)
# because the rule calls the rrule for getproperty, its rrule is not
# completely type-inferrable
test_rrule(/, B, C; check_inferred=false)
end
end
end
end

8 comments on commit 8587a07

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65542

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.41.0 -m "<description of version>" 8587a07a3bd27f4bc06071eefa0979674cc33a07
git push origin v1.41.0

Also, note the warning: Version 1.41.0 skips over 1.40.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/65542

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.41.0 -m "<description of version>" 8587a07a3bd27f4bc06071eefa0979674cc33a07
git push origin v1.41.0

Also, note the warning: Version 1.41.0 skips over 1.40.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/65542

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.41.0 -m "<description of version>" 8587a07a3bd27f4bc06071eefa0979674cc33a07
git push origin v1.41.0

Also, note the warning: Version 1.41.0 skips over 1.40.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/65542

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.41.0 -m "<description of version>" 8587a07a3bd27f4bc06071eefa0979674cc33a07
git push origin v1.41.0

Please sign in to comment.