From 7309b48661ca796abeb437cf58aad4d5965b1210 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Jan 2021 17:40:28 -0800 Subject: [PATCH 1/7] Add breaking tests --- test/differentials/composite.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index b9645327d..32209c1d0 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -18,6 +18,15 @@ struct StructWithInvariant StructWithInvariant(x) = new(x, 2x) end +function _unpack2tuple(comp) + a, b = comp + return (a, b) +end + +function _unpacknamedtuple(comp) + x, y = comp.x, comp.y + return (x, y) +end @testset "Composite" begin @testset "empty types" begin @@ -77,6 +86,17 @@ end # Testing iterate via collect @test collect(Composite{Foo}(x=2.5)) == [2.5] @test collect(Composite{Tuple{Float64,}}(2.0)) == [2.0] + + # Test indexed_iterate + ctup = Composite{Tuple{Float64,Int64}}(2.0, 3) + @inferred _unpack2tuple(ctup) + @test _unpack2tuple(ctup) === (2.0, 3) + + # Test getproperty is inferrable + if VERSION ≥ v"1.2" + @inferred _unpacknamedtuple(Composite{Foo}(x=2, y=3.0)) + @inferred _unpacknamedtuple(Composite{Foo}(y=3.0)) + end end @testset "reverse" begin From 00b97f029150ea28365b26514fa1818fec3a8703 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Jan 2021 17:40:46 -0800 Subject: [PATCH 2/7] Impelement indexed_iterate for tuple composite --- src/differentials/composite.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 40231d1b6..ba2774fd6 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -96,6 +96,10 @@ function Base.reverse(comp::Composite) Composite{typeof(rev_backing), typeof(rev_backing)}(rev_backing) end +function Base.indexed_iterate(comp::Composite{P,<:Tuple}, i::Int, state=1) where {P} + return Base.indexed_iterate(backing(comp), i, state) +end + function Base.map(f, comp::Composite{P, <:Tuple}) where P vals::Tuple = map(f, backing(comp)) return Composite{P, typeof(vals)}(vals) From ed995e7dc616357a32a085ba3f03111d0968578b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 13 Jan 2021 17:41:04 -0800 Subject: [PATCH 3/7] Use hasfield if available --- src/differentials/composite.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index ba2774fd6..73cbc17d7 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -73,10 +73,16 @@ Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx) # for Tuple Base.getproperty(comp::Composite, idx::Int) = unthunk(getproperty(backing(comp), idx)) function Base.getproperty( - comp::Composite{P, <:NamedTuple{L}}, idx::Symbol -) where {P, L} - # Need to check L directly, or else this does not constant-fold - idx ∈ L || return Zero() + comp::Composite{P, T}, idx::Symbol +) where {P, L, T<:NamedTuple{L}} + # hasfield was added in v1.2 + if VERSION ≥ v"1.2.0-DEV.272" + # hasfield more reliably constant-folds than checking L directly + hasfield(T, idx) || return Zero() + else + # Need to check L directly, or else this does not constant-fold + idx ∈ L || return Zero() + end return unthunk(getproperty(backing(comp), idx)) end From 053718c34a5dc734044b961dff90425a91535c98 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 14 Jan 2021 13:23:43 -0800 Subject: [PATCH 4/7] Use Compat --- Project.toml | 2 ++ src/differentials/composite.jl | 11 ++--------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 185d3f722..b37338610 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,13 @@ uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "0.9.25" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] BenchmarkTools = "0.5" +Compat = "2, 3" FiniteDifferences = "0.10" StaticArrays = "0.11, 0.12" julia = "1" diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 73cbc17d7..36d21853e 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -74,15 +74,8 @@ Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx) Base.getproperty(comp::Composite, idx::Int) = unthunk(getproperty(backing(comp), idx)) function Base.getproperty( comp::Composite{P, T}, idx::Symbol -) where {P, L, T<:NamedTuple{L}} - # hasfield was added in v1.2 - if VERSION ≥ v"1.2.0-DEV.272" - # hasfield more reliably constant-folds than checking L directly - hasfield(T, idx) || return Zero() - else - # Need to check L directly, or else this does not constant-fold - idx ∈ L || return Zero() - end +) where {P, T<:NamedTuple} + hasfield(T, idx) || return Zero() return unthunk(getproperty(backing(comp), idx)) end From 810928bc1a22d03471f17e25f042f7a598cbc8a4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 14 Jan 2021 13:23:59 -0800 Subject: [PATCH 5/7] Make anonymous functions --- test/differentials/composite.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index 32209c1d0..599c50b22 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -18,16 +18,6 @@ struct StructWithInvariant StructWithInvariant(x) = new(x, 2x) end -function _unpack2tuple(comp) - a, b = comp - return (a, b) -end - -function _unpacknamedtuple(comp) - x, y = comp.x, comp.y - return (x, y) -end - @testset "Composite" begin @testset "empty types" begin @test typeof(Composite{Tuple{}}()) == Composite{Tuple{}, Tuple{}} @@ -89,10 +79,15 @@ end # Test indexed_iterate ctup = Composite{Tuple{Float64,Int64}}(2.0, 3) + _unpack2tuple = function(comp) + a, b = comp + return (a, b) + end @inferred _unpack2tuple(ctup) @test _unpack2tuple(ctup) === (2.0, 3) # Test getproperty is inferrable + _unpacknamedtuple = comp -> (comp.x, comp.y) if VERSION ≥ v"1.2" @inferred _unpacknamedtuple(Composite{Foo}(x=2, y=3.0)) @inferred _unpacknamedtuple(Composite{Foo}(y=3.0)) From ce8af578f31e641986032d145fd70b4015b006bb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 14 Jan 2021 13:24:11 -0800 Subject: [PATCH 6/7] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b37338610..836b1838d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.25" +version = "0.9.26" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 1acdd3f8e28e6182cffeb75e651994af66b32bbd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 14 Jan 2021 13:28:48 -0800 Subject: [PATCH 7/7] Load hasfield from Compat --- src/ChainRulesCore.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 5361b9a81..aac178600 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,6 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using LinearAlgebra: LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC +using Compat: hasfield export on_new_rule, refresh_rules # generation tools export frule, rrule # core function