Skip to content

Commit

Permalink
Clarify issue with scalar products
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 26, 2022
1 parent 0993510 commit 15315c8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.4.6"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"

Expand Down
16 changes: 13 additions & 3 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# The publlback depends on the scalar product on the polynomials
# With the scalar product `LinearAlgebra.dot(p, q) = p * q`, there is no pullback for `differentiate`
# With the scalar product `_dot(p, q)` of `test/chain_rules.jl`, there is a pullback for `differentiate`
# and the pullback for `*` changes.
# We give the one for the scalar product `_dot`.

import ChainRulesCore

ChainRulesCore.@scalar_rule +(x::APL) true
Expand All @@ -22,7 +28,7 @@ function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL)
return p * q, MA.add_mul!!(p * Δq, q, Δp)
end

function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function}
function _mult_pullback(op::F, ts, p, Δ) where {F<:Function}
for t in terms(p)
c = coefficient(t)
m = monomial(t)
Expand All @@ -38,20 +44,23 @@ function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function}
end
function adjoint_mult_left(p, Δ)
ts = MA.promote_operation(*, MA.promote_operation(adjoint, termtype(p)), termtype(Δ))[]
return _adjoint_mult(ts, p, Δ) do c, d
return _mult_pullback(ts, p, Δ) do c, d
c' * d
end
end
function adjoint_mult_right(p, Δ)
ts = MA.promote_operation(*, termtype(Δ), MA.promote_operation(adjoint, termtype(p)))[]
return _adjoint_mult(ts, p, Δ) do c, d
return _mult_pullback(ts, p, Δ) do c, d
d * c'
end
end

function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL)
function times_pullback2(ΔΩ̇)
# This is for the scalar product `_dot`:
return (ChainRulesCore.NoTangent(), adjoint_mult_right(q, ΔΩ̇), adjoint_mult_left(p, ΔΩ̇))
# For the scalar product `dot`, it would be instead:
return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇)
end
return p * q, times_pullback2
end
Expand Down Expand Up @@ -82,6 +91,7 @@ end
function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x)
return differentiate(p, x), differentiate(Δp, x)
end
# This is for the scalar product `_dot`, there is no pullback for the scalar product `dot`
function differentiate_pullback(Δdpdx, x)
return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent()
end
Expand Down
47 changes: 28 additions & 19 deletions test/chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout)
@test dot(Δin, rΔin[2:end]) dot(fΔout, Δout)
end

function _dot(p, q)
monos = monovec([monomials(p); monomials(q)])
return dot(coefficient.(p, monos), coefficient.(q, monos))
end
function _dot(px::Tuple, qx::Tuple)
return _dot(first(px), first(qx)) + _dot(Base.tail(px), Base.tail(qx))
end
function _dot(::Tuple{}, ::Tuple{})
return MultivariatePolynomials.MA.Zero()
end
function _dot(::NoTangent, ::NoTangent)
return MultivariatePolynomials.MA.Zero()
end

@testset "ChainRulesCore" begin
Mod.@polyvar x y
p = 1.1x + y
Expand Down Expand Up @@ -42,30 +56,25 @@ end
@test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent())
@test pullback(1x) == (NoTangent(), 2x^2, NoTangent())

test_chain_rule(dot, +, (p,), (q,), p)
test_chain_rule(dot, +, (q,), (p,), q)
for d in [dot, _dot]
test_chain_rule(d, +, (p,), (q,), p)
test_chain_rule(d, +, (q,), (p,), q)

test_chain_rule(dot, -, (p,), (q,), p)
test_chain_rule(dot, -, (p,), (p,), q)
test_chain_rule(d, -, (p,), (q,), p)
test_chain_rule(d, -, (p,), (p,), q)

test_chain_rule(dot, +, (p, q), (q, p), p)
test_chain_rule(dot, +, (p, q), (p, q), q)
test_chain_rule(d, +, (p, q), (q, p), p)
test_chain_rule(d, +, (p, q), (p, q), q)

test_chain_rule(dot, -, (p, q), (q, p), p)
test_chain_rule(dot, -, (p, q), (p, q), q)
test_chain_rule(d, -, (p, q), (q, p), p)
test_chain_rule(d, -, (p, q), (p, q), q)
end

test_chain_rule(dot, *, (p, q), (q, p), p * q)
test_chain_rule(dot, *, (p, q), (p, q), q * q)
test_chain_rule(dot, *, (q, p), (p, q), q * q)
test_chain_rule(dot, *, (p, q), (q, p), q * q)
test_chain_rule(_dot, *, (p, q), (q, p), p * q)
test_chain_rule(_dot, *, (p, q), (p, q), q * q)
test_chain_rule(_dot, *, (q, p), (p, q), q * q)
test_chain_rule(_dot, *, (p, q), (q, p), q * q)

function _dot(p, q)
monos = monomials(p + q)
return dot(coefficient.(p, monos), coefficient.(q, monos))
end
function _dot(px::Tuple{<:AbstractPolynomial,NoTangent}, qx::Tuple{<:AbstractPolynomial,NoTangent})
return _dot(px[1], qx[1])
end
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), p)
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(p, x))
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(q, x))
Expand Down

0 comments on commit 15315c8

Please sign in to comment.