From c868913f5b38bfaed0a1fd71b6b126928028f0ab Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 14 May 2024 15:02:36 -0400 Subject: [PATCH] Fix ReverseDiff downstream --- Project.toml | 2 +- ext/RecursiveArrayToolsReverseDiffExt.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 1673cecf..72608d0a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "3.17.0" +version = "3.18.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/RecursiveArrayToolsReverseDiffExt.jl b/ext/RecursiveArrayToolsReverseDiffExt.jl index 79b59418..d3526da1 100644 --- a/ext/RecursiveArrayToolsReverseDiffExt.jl +++ b/ext/RecursiveArrayToolsReverseDiffExt.jl @@ -3,12 +3,13 @@ module RecursiveArrayToolsReverseDiffExt using RecursiveArrayTools using ReverseDiff using Zygote: @adjoint +using RecursiveArrayTools.ArrayInterface function trackedarraycopyto!(dest, src) for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims = ndims(src))) if dest.u[i] isa AbstractArray - dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i])) - else + dest.u[i] = reshape(ArrayInterface.aos_to_soa(slice), size(dest.u[i])) + elseif dest.u[i] trackedarraycopyto!(dest.u[i], slice) end end