Skip to content

Commit

Permalink
refactor: use new RecursiveArrayTools
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 27, 2023
1 parent 78074f3 commit 0ae3915
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/common_defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ end
@inline function UNITLESS_ABS2(x::AbstractArray)
mapreduce(UNITLESS_ABS2, abs2_and_sum, x, init = zero(real(value(eltype(x)))))
end
@inline function UNITLESS_ABS2(x::RecursiveArrayTools.AbstractVectorOfArray)
mapreduce(UNITLESS_ABS2, abs2_and_sum, x.u, init = zero(real(value(eltype(x)))))
end
@inline function UNITLESS_ABS2(x::RecursiveArrayTools.ArrayPartition)
mapreduce(UNITLESS_ABS2, abs2_and_sum, x.x, init = zero(real(value(eltype(x)))))
end
Expand Down Expand Up @@ -37,7 +40,11 @@ end
Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1))
end

@inline function ODE_DEFAULT_NORM(u::AbstractArray, t)
@inline function ODE_DEFAULT_NORM(u::Union{
AbstractArray,
RecursiveArrayTools.AbstractVectorOfArray,
},
t)
Base.FastMath.sqrt_fast(UNITLESS_ABS2(u) / max(recursive_length(u), 1))
end
@inline ODE_DEFAULT_NORM(u, t) = norm(u)
Expand All @@ -57,7 +64,8 @@ end
@inline NAN_CHECK(x::Number) = isnan(x)
@inline NAN_CHECK(x::Float64) = isnan(x) || (x > 1e50)
@inline NAN_CHECK(x::Enum) = false
@inline NAN_CHECK(x::AbstractArray) = any(NAN_CHECK, x)
@inline NAN_CHECK(x::Union{AbstractArray, RecursiveArrayTools.AbstractVectorOfArray}) = any(NAN_CHECK,
x)
@inline NAN_CHECK(x::RecursiveArrayTools.ArrayPartition) = any(NAN_CHECK, x.x)
@inline function NAN_CHECK(x::SparseArrays.AbstractSparseMatrixCSC)
any(NAN_CHECK, SparseArrays.nonzeros(x))
Expand Down

0 comments on commit 0ae3915

Please sign in to comment.