From 913a00cec2305955c20e041907d2ae575ae514a0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 7 Dec 2024 18:49:12 +0000 Subject: [PATCH 1/4] feat: pretty printing for gradient operators --- ext/DynamicExpressionsZygoteExt.jl | 35 +++++++++++++------ src/ExtensionInterface.jl | 17 ++++++++++ test/test_zygote_gradient_wrapper.jl | 51 ++++++++++++++++++++++++++++ test/unittest.jl | 1 + 4 files changed, 93 insertions(+), 11 deletions(-) create mode 100644 test/test_zygote_gradient_wrapper.jl diff --git a/ext/DynamicExpressionsZygoteExt.jl b/ext/DynamicExpressionsZygoteExt.jl index 590cc402..5654c27e 100644 --- a/ext/DynamicExpressionsZygoteExt.jl +++ b/ext/DynamicExpressionsZygoteExt.jl @@ -1,19 +1,32 @@ module DynamicExpressionsZygoteExt -import Zygote: gradient -import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient +using Zygote: gradient +import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGradient function _zygote_gradient(op::F, ::Val{1}) where {F} - function (x) - out = gradient(op, x)[1] - return out === nothing ? zero(x) : out - end + return ZygoteGradient{F,1,1}(op) end -function _zygote_gradient(op::F, ::Val{2}) where {F} - function (x, y) - (∂x, ∂y) = gradient(op, x, y) - return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y) - end +function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side} + # side should be either nothing (for both), 1, or 2 + @assert side === nothing || side in (1, 2) + return ZygoteGradient{F,2,side}(op) +end + +function (g::ZygoteGradient{F,1,1})(x) where {F} + out = only(gradient(g.op, x)) + return out === nothing ? zero(x) : out +end +function (g::ZygoteGradient{F,2,nothing})(x, y) where {F} + (∂x, ∂y) = gradient(g.op, x, y) + return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y) +end +function (g::ZygoteGradient{F,2,1})(x, y) where {F} + ∂x = only(gradient(Base.Fix2(g.op, y), x)) + return ∂x === nothing ? zero(x) : ∂x +end +function (g::ZygoteGradient{F,2,2})(x, y) where {F} + ∂y = only(gradient(Base.Fix1(g.op, x), y)) + return ∂y === nothing ? zero(y) : ∂y end end diff --git a/src/ExtensionInterface.jl b/src/ExtensionInterface.jl index 521b0c88..5c84efd6 100644 --- a/src/ExtensionInterface.jl +++ b/src/ExtensionInterface.jl @@ -7,6 +7,23 @@ function symbolic_to_node(args...; kws...) return error("Please load the `SymbolicUtils` package to use `symbolic_to_node`.") end +struct ZygoteGradient{F,degree,arg} <: Function + op::F +end + +function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg} + print(io, "∂") + if degree == 2 + if arg == 1 + print(io, "₁") + elseif arg == 2 + print(io, "₂") + end + end + print(io, g.op) + return nothing +end + function _zygote_gradient(args...) return error("Please load the Zygote.jl package.") end diff --git a/test/test_zygote_gradient_wrapper.jl b/test/test_zygote_gradient_wrapper.jl new file mode 100644 index 00000000..67c8de58 --- /dev/null +++ b/test/test_zygote_gradient_wrapper.jl @@ -0,0 +1,51 @@ +@testitem "ZygoteGradient string representation" begin + using DynamicExpressions + using DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient + using Zygote + + # Test unary gradient + f(x) = x^2 + @test repr(_zygote_gradient(f, Val(1))) == "∂f" + + # Test binary gradient (both partials) + g(x, y) = x * y + @test repr(_zygote_gradient(g, Val(2))) == "∂g" + + # Test binary gradient (first partial) + @test repr(_zygote_gradient(g, Val(2), Val(1))) == "∂₁g" + + # Test binary gradient (second partial) + @test repr(_zygote_gradient(g, Val(2), Val(2))) == "∂₂g" + + # Test with standard operators + @test repr(_zygote_gradient(+, Val(2))) == "∂+" + @test repr(_zygote_gradient(*, Val(2), Val(1))) == "∂₁*" + @test repr(_zygote_gradient(*, Val(2), Val(2))) == "∂₂*" + + first_partial = _zygote_gradient(log, Val(2), Val(1)) + nested = _zygote_gradient(first_partial, Val(1)) + @test repr(nested) == "∂∂₁log" +end + +@testitem "ZygoteGradient evaluation" begin + using DynamicExpressions + using DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient + using Zygote + + x = 2.0 + y = 3.0 + + # Test unary gradient + f(x) = x^2 + @test (_zygote_gradient(f, Val(1)))(x) == 4.0 + + # Test binary gradient (both partials) + g(x, y) = x * y + @test (_zygote_gradient(g, Val(2)))(x, y) == (3.0, 2.0) + + # Test binary gradient (first partial) + @test (_zygote_gradient(g, Val(2), Val(1)))(x, y) == 3.0 + + # Test second partial + @test (_zygote_gradient(g, Val(2), Val(2)))(x, y) == 2.0 +end diff --git a/test/unittest.jl b/test/unittest.jl index f3ff96c8..26c2bdd0 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -130,3 +130,4 @@ include("test_node_interface.jl") include("test_expression_math.jl") include("test_structured_expression.jl") include("test_readonlynode.jl") +include("test_zygote_gradient_wrapper.jl") From c0da506d8276f91a4c07dbeb816bbfb15efe939d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 7 Dec 2024 18:49:33 +0000 Subject: [PATCH 2/4] chore: bump version with pretty gradient operators --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fc9ed3e7..1751515f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "1.5.1" +version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 699fded1c67c42c6815742af5d1cf83d77b5451f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 7 Dec 2024 18:51:51 +0000 Subject: [PATCH 3/4] feat: also use pretty printing for regular `show` --- src/ExtensionInterface.jl | 1 + test/test_zygote_gradient_wrapper.jl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/ExtensionInterface.jl b/src/ExtensionInterface.jl index 5c84efd6..1628683d 100644 --- a/src/ExtensionInterface.jl +++ b/src/ExtensionInterface.jl @@ -23,6 +23,7 @@ function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg} print(io, g.op) return nothing end +Base.show(io::IO, ::MIME"text/plain", g::ZygoteGradient) = show(io, g) function _zygote_gradient(args...) return error("Please load the Zygote.jl package.") diff --git a/test/test_zygote_gradient_wrapper.jl b/test/test_zygote_gradient_wrapper.jl index 67c8de58..45e7e319 100644 --- a/test/test_zygote_gradient_wrapper.jl +++ b/test/test_zygote_gradient_wrapper.jl @@ -25,6 +25,9 @@ first_partial = _zygote_gradient(log, Val(2), Val(1)) nested = _zygote_gradient(first_partial, Val(1)) @test repr(nested) == "∂∂₁log" + + # Also should work with text/plain + @test repr(MIME"text/plain", nested) == "∂∂₁log" end @testitem "ZygoteGradient evaluation" begin From e27214a2b86e7ce72499b01c47044e3c8179b9eb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 7 Dec 2024 19:04:21 +0000 Subject: [PATCH 4/4] test: fix passing of MIME type --- test/test_zygote_gradient_wrapper.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_zygote_gradient_wrapper.jl b/test/test_zygote_gradient_wrapper.jl index 45e7e319..7eed34bb 100644 --- a/test/test_zygote_gradient_wrapper.jl +++ b/test/test_zygote_gradient_wrapper.jl @@ -27,7 +27,7 @@ @test repr(nested) == "∂∂₁log" # Also should work with text/plain - @test repr(MIME"text/plain", nested) == "∂∂₁log" + @test repr("text/plain", nested) == "∂∂₁log" end @testitem "ZygoteGradient evaluation" begin