diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 08428c190..e851893f4 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,6 +1,10 @@ @inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) @inline unthunk_tangent(x::AbstractArray{<:AbstractThunk}) = map(unthunk_tangent, x) unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +function ChainRulesCore.rrule(::typeof(unthunk_tangent), d::IdDict) + unthunk_iddict_pullback(_) = (NoTangent(), ChainRulesCore.@not_implemented "unthunking IdDict") + return d, unthunk_iddict_pullback +end @non_differentiable unthunk_tangent(::IdDict)