From 596e685fece8444f28905e7204fcf3387687a99b Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 10 Jan 2025 20:06:02 +0100 Subject: [PATCH] fix ad type instabilities in new zygote --- ext/TensorOperationsChainRulesCoreExt.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index d3d88b8..207320a 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -86,19 +86,19 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) function pullback(ΔC′) ΔC = unthunk(ΔC′) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin + dA = @thunk let ipA = invperm(linearize(pA)) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end - dα = @thunk begin + dα = @thunk let _dα = tensorscalar(tensorcontract(A, ((), linearize(pA)), !conjA, ΔC, (trivtuple(numind(pA)), ()), false, ((), ()), One(), ba...)) return projectα(_dα) end - dβ = @thunk begin + dβ = @thunk let # TODO: consider using `inner` _dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pA))), true, ΔC, (trivtuple(numind(pA)), ()), false, @@ -165,7 +165,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) pΔC = (TupleTools.getindices(ipAB, trivtuple(numout(pA))), TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB)))) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin + dA = @thunk let ipA = (invperm(linearize(pA)), ()) conjΔC = conjA conjB′ = conjA ? conjB : !conjB @@ -177,7 +177,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end - dB = @thunk begin + dB = @thunk let ipB = (invperm(linearize(pB)), ()) conjΔC = conjB conjA′ = conjB ? conjA : !conjA @@ -189,7 +189,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) conjB ? α : conj(α), Zero(), ba...) return projectB(_dB) end - dα = @thunk begin + dα = @thunk let C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) # TODO: consider using `inner` _dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(pAB))), true, @@ -197,7 +197,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ((), ()), One(), ba...)) return projectα(_dα) end - dβ = @thunk begin + dβ = @thunk let # TODO: consider using `inner` _dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pAB))), true, ΔC, (trivtuple(numind(pAB)), ()), false, @@ -249,7 +249,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) function pullback(ΔC′) ΔC = unthunk(ΔC′) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin + dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 return one(TensorOperations.tensoralloc_add(scalartype(A), A, @@ -263,7 +263,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end - dα = @thunk begin + dα = @thunk let C_αβ = tensortrace(A, p, q, false, One(), ba...) _dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(p))), !conjA, @@ -271,7 +271,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) ((), ()), One(), ba...)) return projectα(_dα) end - dβ = @thunk begin + dβ = @thunk let _dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(p))), true, ΔC, (trivtuple(numind(p)), ()), false, ((), ()), One(), ba...))