diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 2f8f48fd..b0fbe5ba 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -313,13 +313,19 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N} end end +# Modified version of the function used in Zygote.∇getindex +# Works around cases where `a` is immutable or has an eltype that does not define `zero` (e.g. Any) +_zero(xs::AbstractArray{<:Number}, T::Type{<:AbstractZero}) = fill!(similar(xs), zero(eltype(xs))) +_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false) +_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{ZeroTangent, T}), ZeroTangent()) + # TODO: Temporary - make better function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N} getindex(a, inds...), let EvenOddOdd{1, c_order(N)}( (@Base.constprop :aggressive Δ->begin Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...) - BB = zero(a) + BB = _zero(a, eltype(Δ)) BB[inds...] = Δ (NoTangent(), BB, map(x->NoTangent(), inds)...) end), diff --git a/test/runtests.jl b/test/runtests.jl index 4b1832e6..74e0a769 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -218,3 +218,6 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) #include("pinn.jl") + +# PR #82 - getindex on non-numeric arrays +@test Diffractor.gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}