diff --git a/Project.toml b/Project.toml index 185d3f722..836b1838d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,15 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.25" +version = "0.9.26" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] BenchmarkTools = "0.5" +Compat = "2, 3" FiniteDifferences = "0.10" StaticArrays = "0.11, 0.12" julia = "1" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 5361b9a81..aac178600 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,6 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using LinearAlgebra: LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC +using Compat: hasfield export on_new_rule, refresh_rules # generation tools export frule, rrule # core function diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 40231d1b6..36d21853e 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -73,10 +73,9 @@ Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx) # for Tuple Base.getproperty(comp::Composite, idx::Int) = unthunk(getproperty(backing(comp), idx)) function Base.getproperty( - comp::Composite{P, <:NamedTuple{L}}, idx::Symbol -) where {P, L} - # Need to check L directly, or else this does not constant-fold - idx ∈ L || return Zero() + comp::Composite{P, T}, idx::Symbol +) where {P, T<:NamedTuple} + hasfield(T, idx) || return Zero() return unthunk(getproperty(backing(comp), idx)) end @@ -96,6 +95,10 @@ function Base.reverse(comp::Composite) Composite{typeof(rev_backing), typeof(rev_backing)}(rev_backing) end +function Base.indexed_iterate(comp::Composite{P,<:Tuple}, i::Int, state=1) where {P} + return Base.indexed_iterate(backing(comp), i, state) +end + function Base.map(f, comp::Composite{P, <:Tuple}) where P vals::Tuple = map(f, backing(comp)) return Composite{P, typeof(vals)}(vals) diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index b9645327d..599c50b22 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -18,7 +18,6 @@ struct StructWithInvariant StructWithInvariant(x) = new(x, 2x) end - @testset "Composite" begin @testset "empty types" begin @test typeof(Composite{Tuple{}}()) == Composite{Tuple{}, Tuple{}} @@ -77,6 +76,22 @@ end # Testing iterate via collect @test collect(Composite{Foo}(x=2.5)) == [2.5] @test collect(Composite{Tuple{Float64,}}(2.0)) == [2.0] + + # Test indexed_iterate + ctup = Composite{Tuple{Float64,Int64}}(2.0, 3) + _unpack2tuple = function(comp) + a, b = comp + return (a, b) + end + @inferred _unpack2tuple(ctup) + @test _unpack2tuple(ctup) === (2.0, 3) + + # Test getproperty is inferrable + _unpacknamedtuple = comp -> (comp.x, comp.y) + if VERSION ≥ v"1.2" + @inferred _unpacknamedtuple(Composite{Foo}(x=2, y=3.0)) + @inferred _unpacknamedtuple(Composite{Foo}(y=3.0)) + end end @testset "reverse" begin