From 0993510170236c0ab9d75d8b898089494ee215c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 19 Apr 2022 20:20:12 -0400 Subject: [PATCH 1/2] Fix rrule for * and add support for constant operations Co-authored-by: Mitchell Harris --- src/chain_rules.jl | 78 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/src/chain_rules.jl b/src/chain_rules.jl index bec4d5a4..60877411 100644 --- a/src/chain_rules.jl +++ b/src/chain_rules.jl @@ -4,27 +4,95 @@ ChainRulesCore.@scalar_rule +(x::APL) true ChainRulesCore.@scalar_rule -(x::APL) -1 ChainRulesCore.@scalar_rule +(x::APL, y::APL) (true, true) +function plusconstant1_pullback(Δ) + return ChainRulesCore.NoTangent(), Δ, coefficient(Δ, constantmonomial(Δ)) +end +function ChainRulesCore.rrule(::typeof(plusconstant), p::APL, α) + return plusconstant(p, α), plusconstant1_pullback +end +function plusconstant2_pullback(Δ) + return ChainRulesCore.NoTangent(), coefficient(Δ, constantmonomial(Δ)), Δ +end +function ChainRulesCore.rrule(::typeof(plusconstant), α, p::APL) + return plusconstant(α, p), plusconstant2_pullback +end ChainRulesCore.@scalar_rule -(x::APL, y::APL) (true, -1) function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL) return p * q, MA.add_mul!!(p * Δq, q, Δp) end + +function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function} + for t in terms(p) + c = coefficient(t) + m = monomial(t) + for δ in Δ + if divides(m, δ) + coef = op(c, coefficient(δ)) + mono = _div(monomial(δ), m) + push!(ts, term(coef, mono)) + end + end + end + return polynomial(ts) +end +function adjoint_mult_left(p, Δ) + ts = MA.promote_operation(*, MA.promote_operation(adjoint, termtype(p)), termtype(Δ))[] + return _adjoint_mult(ts, p, Δ) do c, d + c' * d + end +end +function adjoint_mult_right(p, Δ) + ts = MA.promote_operation(*, termtype(Δ), MA.promote_operation(adjoint, termtype(p)))[] + return _adjoint_mult(ts, p, Δ) do c, d + d * c' + end +end + 
function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL) function times_pullback2(ΔΩ̇) - #ΔΩ = ChainRulesCore.unthunk(Ω̇) - #return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(p)(ΔΩ * q'), ChainRulesCore.ProjectTo(q)(p' * ΔΩ)) - return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇) + return (ChainRulesCore.NoTangent(), adjoint_mult_right(q, ΔΩ̇), adjoint_mult_left(p, ΔΩ̇)) end return p * q, times_pullback2 end +function ChainRulesCore.rrule(::typeof(multconstant), α, p::APL) + function times_pullback2(ΔΩ̇) + # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term. + Δα = adjoint_mult_right(p, ΔΩ̇) + return (ChainRulesCore.NoTangent(), coefficient(Δα, constantmonomial(Δα)), α' * ΔΩ̇) + end + return multconstant(α, p), times_pullback2 +end + +function ChainRulesCore.rrule(::typeof(multconstant), p::APL, α) + function times_pullback2(ΔΩ̇) + # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term. 
+ Δα = adjoint_mult_left(p, ΔΩ̇) + return (ChainRulesCore.NoTangent(), ΔΩ̇ * α', coefficient(Δα, constantmonomial(Δα))) + end + return multconstant(p, α), times_pullback2 +end + +notangent3(Δ) = ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() +function ChainRulesCore.rrule(::typeof(^), mono::AbstractMonomialLike, i::Integer) + return mono^i, notangent3 +end + function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x) return differentiate(p, x), differentiate(Δp, x) end -function pullback(Δdpdx, x) +function differentiate_pullback(Δdpdx, x) return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent() end function ChainRulesCore.rrule(::typeof(differentiate), p, x) dpdx = differentiate(p, x) - return dpdx, Base.Fix2(pullback, x) + return dpdx, Base.Fix2(differentiate_pullback, x) +end + +function coefficient_pullback(Δ, m::AbstractMonomialLike) + return ChainRulesCore.NoTangent(), polynomial(term(Δ, m)), ChainRulesCore.NoTangent() +end +function ChainRulesCore.rrule(::typeof(coefficient), p::APL, m::AbstractMonomialLike) + return coefficient(p, m), Base.Fix2(coefficient_pullback, m) end From 2c3df716b0c0d76c3b94cf00edac93571d07000b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 26 Dec 2022 15:35:07 +0100 Subject: [PATCH 2/2] Clarify issue with scalar products --- src/chain_rules.jl | 16 ++++++++++++--- test/chain_rules.jl | 47 +++++++++++++++++++++++++++------------------ 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/chain_rules.jl b/src/chain_rules.jl index 60877411..6e9d0ba8 100644 --- a/src/chain_rules.jl +++ b/src/chain_rules.jl @@ -1,3 +1,9 @@ +# The pullback depends on the scalar product on the polynomials +# With the scalar product `LinearAlgebra.dot(p, q) = p * q`, there is no pullback for `differentiate` +# With the scalar product `_dot(p, q)` of `test/chain_rules.jl`, there is a pullback for `differentiate` +# and the pullback
for `*` changes. +# We give the one for the scalar product `_dot`. + import ChainRulesCore ChainRulesCore.@scalar_rule +(x::APL) true @@ -22,7 +28,7 @@ function ChainRulesCore.frule((_, Δp, Δq), ::typeof(*), p::APL, q::APL) return p * q, MA.add_mul!!(p * Δq, q, Δp) end -function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function} +function _mult_pullback(op::F, ts, p, Δ) where {F<:Function} for t in terms(p) c = coefficient(t) m = monomial(t) @@ -38,20 +44,23 @@ function _adjoint_mult(op::F, ts, p, Δ) where {F<:Function} end function adjoint_mult_left(p, Δ) ts = MA.promote_operation(*, MA.promote_operation(adjoint, termtype(p)), termtype(Δ))[] - return _adjoint_mult(ts, p, Δ) do c, d + return _mult_pullback(ts, p, Δ) do c, d c' * d end end function adjoint_mult_right(p, Δ) ts = MA.promote_operation(*, termtype(Δ), MA.promote_operation(adjoint, termtype(p)))[] - return _adjoint_mult(ts, p, Δ) do c, d + return _mult_pullback(ts, p, Δ) do c, d d * c' end end function ChainRulesCore.rrule(::typeof(*), p::APL, q::APL) function times_pullback2(ΔΩ̇) + # This is for the scalar product `_dot`: return (ChainRulesCore.NoTangent(), adjoint_mult_right(q, ΔΩ̇), adjoint_mult_left(p, ΔΩ̇)) + # For the scalar product `dot`, it would be instead: + # return (ChainRulesCore.NoTangent(), ΔΩ̇ * q', p' * ΔΩ̇) end return p * q, times_pullback2 end @@ -82,6 +91,7 @@ end function ChainRulesCore.frule((_, Δp, _), ::typeof(differentiate), p, x) return differentiate(p, x), differentiate(Δp, x) end +# This is for the scalar product `_dot`, there is no pullback for the scalar product `dot` function differentiate_pullback(Δdpdx, x) return ChainRulesCore.NoTangent(), x * differentiate(x * Δdpdx, x), ChainRulesCore.NoTangent() end diff --git a/test/chain_rules.jl b/test/chain_rules.jl index d8e3cb71..307893f0 100644 --- a/test/chain_rules.jl +++ b/test/chain_rules.jl @@ -12,6 +12,20 @@ function test_chain_rule(dot, op, args, Δin, Δout) @test dot(Δin, rΔin[2:end]) ≈ dot(fΔout, Δout) end +function
_dot(p, q) + monos = monovec([monomials(p); monomials(q)]) + return dot(coefficient.(p, monos), coefficient.(q, monos)) +end +function _dot(px::Tuple, qx::Tuple) + return _dot(first(px), first(qx)) + _dot(Base.tail(px), Base.tail(qx)) +end +function _dot(::Tuple{}, ::Tuple{}) + return MultivariatePolynomials.MA.Zero() +end +function _dot(::NoTangent, ::NoTangent) + return MultivariatePolynomials.MA.Zero() +end + @testset "ChainRulesCore" begin Mod.@polyvar x y p = 1.1x + y @@ -42,30 +56,25 @@ end @test pullback(q) == (NoTangent(), (-0.2 + 2im) * x^2 - x*y, NoTangent()) @test pullback(1x) == (NoTangent(), 2x^2, NoTangent()) - test_chain_rule(dot, +, (p,), (q,), p) - test_chain_rule(dot, +, (q,), (p,), q) + for d in [dot, _dot] + test_chain_rule(d, +, (p,), (q,), p) + test_chain_rule(d, +, (q,), (p,), q) - test_chain_rule(dot, -, (p,), (q,), p) - test_chain_rule(dot, -, (p,), (p,), q) + test_chain_rule(d, -, (p,), (q,), p) + test_chain_rule(d, -, (p,), (p,), q) - test_chain_rule(dot, +, (p, q), (q, p), p) - test_chain_rule(dot, +, (p, q), (p, q), q) + test_chain_rule(d, +, (p, q), (q, p), p) + test_chain_rule(d, +, (p, q), (p, q), q) - test_chain_rule(dot, -, (p, q), (q, p), p) - test_chain_rule(dot, -, (p, q), (p, q), q) + test_chain_rule(d, -, (p, q), (q, p), p) + test_chain_rule(d, -, (p, q), (p, q), q) + end - test_chain_rule(dot, *, (p, q), (q, p), p * q) - test_chain_rule(dot, *, (p, q), (p, q), q * q) - test_chain_rule(dot, *, (q, p), (p, q), q * q) - test_chain_rule(dot, *, (p, q), (q, p), q * q) + test_chain_rule(_dot, *, (p, q), (q, p), p * q) + test_chain_rule(_dot, *, (p, q), (p, q), q * q) + test_chain_rule(_dot, *, (q, p), (p, q), q * q) + test_chain_rule(_dot, *, (p, q), (q, p), q * q) - function _dot(p, q) - monos = monomials(p + q) - return dot(coefficient.(p, monos), coefficient.(q, monos)) - end - function _dot(px::Tuple{<:AbstractPolynomial,NoTangent}, qx::Tuple{<:AbstractPolynomial,NoTangent}) - return _dot(px[1], qx[1]) - end 
test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), p) test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(p, x)) test_chain_rule(_dot, differentiate, (p, x), (q, NoTangent()), differentiate(q, x))