From 0ae39159fe6539d21a52dd214f830c67009dc83e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Oct 2023 13:46:01 +0530 Subject: [PATCH] refactor: use new RecursiveArrayTools --- src/common_defaults.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/common_defaults.jl b/src/common_defaults.jl index 4088984b8..b044409c6 100644 --- a/src/common_defaults.jl +++ b/src/common_defaults.jl @@ -6,6 +6,9 @@ end @inline function UNITLESS_ABS2(x::AbstractArray) mapreduce(UNITLESS_ABS2, abs2_and_sum, x, init = zero(real(value(eltype(x))))) end +@inline function UNITLESS_ABS2(x::RecursiveArrayTools.AbstractVectorOfArray) + mapreduce(UNITLESS_ABS2, abs2_and_sum, x.u, init = zero(real(value(eltype(x))))) +end @inline function UNITLESS_ABS2(x::RecursiveArrayTools.ArrayPartition) mapreduce(UNITLESS_ABS2, abs2_and_sum, x.x, init = zero(real(value(eltype(x))))) end @@ -37,7 +40,11 @@ end Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1)) end -@inline function ODE_DEFAULT_NORM(u::AbstractArray, t) +@inline function ODE_DEFAULT_NORM(u::Union{ + AbstractArray, + RecursiveArrayTools.AbstractVectorOfArray, + }, + t) Base.FastMath.sqrt_fast(UNITLESS_ABS2(u) / max(recursive_length(u), 1)) end @inline ODE_DEFAULT_NORM(u, t) = norm(u) @@ -57,7 +64,8 @@ end @inline NAN_CHECK(x::Number) = isnan(x) @inline NAN_CHECK(x::Float64) = isnan(x) || (x > 1e50) @inline NAN_CHECK(x::Enum) = false -@inline NAN_CHECK(x::AbstractArray) = any(NAN_CHECK, x) +@inline NAN_CHECK(x::Union{AbstractArray, RecursiveArrayTools.AbstractVectorOfArray}) = any(NAN_CHECK, + x) @inline NAN_CHECK(x::RecursiveArrayTools.ArrayPartition) = any(NAN_CHECK, x.x) @inline function NAN_CHECK(x::SparseArrays.AbstractSparseMatrixCSC) any(NAN_CHECK, SparseArrays.nonzeros(x))