diff --git a/Project.toml b/Project.toml index 7f3b0ba04..1c7ebb77e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.56.0" +version = "1.57.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 28e73c166..6d33a22e7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Base/broadcast.jl") +include("rulesets/Base/CoreLogging.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/CoreLogging.jl b/src/rulesets/Base/CoreLogging.jl new file mode 100644 index 000000000..ae97f4e40 --- /dev/null +++ b/src/rulesets/Base/CoreLogging.jl @@ -0,0 +1,20 @@ +# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib) + +function rrule( + rc::RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.CoreLogging.with_logger), + f::Function, + logger::Base.CoreLogging.AbstractLogger, +) + y, f_pb = Base.CoreLogging.with_logger(logger) do + rrule_via_ad(rc, f) + end + with_logger_pullback(ȳ) = (NoTangent(), only(f_pb(ȳ)), NoTangent()) + return y, with_logger_pullback +end + +@non_differentiable Base.CoreLogging.current_logger(args...) +@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) +@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) +@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) +@non_differentiable Base.CoreLogging.handle_message(::Any...) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 58298f068..d35024163 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -483,10 +483,6 @@ end @non_differentiable Broadcast.result_style(::Any) @non_differentiable Broadcast.result_style(::Any, ::Any) -@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) -@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) -@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) -@non_differentiable Base.CoreLogging.handle_message(::Any...) @non_differentiable Libc.free(::Any) @non_differentiable Libc.getpid() diff --git a/test/rulesets/Base/CoreLogging.jl b/test/rulesets/Base/CoreLogging.jl new file mode 100644 index 000000000..28c0b74a8 --- /dev/null +++ b/test/rulesets/Base/CoreLogging.jl @@ -0,0 +1,11 @@ +# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib) +@testset "CoreLogging.jl" begin + @testset "with_logger" begin + test_rrule( + Base.CoreLogging.with_logger, + () -> 2.0 * 3.0, + Base.CoreLogging.NullLogger(); + check_inferred=false, + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a9f25c55c..768f7c208 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ end test_method_tables() # Check the global method tables are consistent # Each file puts all tests inside one or more @testset blocks + include_test("rulesets/Base/CoreLogging.jl") include_test("rulesets/Base/base.jl") include_test("rulesets/Base/fastmath_able.jl") include_test("rulesets/Base/evalpoly.jl")