From 0098ddc79d6139cb0d8c835bba476a81a78f1f28 Mon Sep 17 00:00:00 2001 From: apkille Date: Thu, 8 Aug 2024 15:45:36 -0400 Subject: [PATCH 1/2] broadcast Base.:(/) for non-AbstractArray --- src/solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index d06354c8..98ac1031 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -283,7 +283,7 @@ function DiffEqBase.__init( uprev = recursivecopy(u) if !(uType <: AbstractArray) - rand_prototype = zero(u/u) # Strip units and type info + rand_prototype = zero(u ./ u) # Strip units and type info randType = typeof(rand_prototype) else randElType = uBottomEltypeNoUnits # Strip units and type info From c5750221fc7099cf947d104df45b97d2f49156de Mon Sep 17 00:00:00 2001 From: apkille Date: Thu, 8 Aug 2024 15:52:33 -0400 Subject: [PATCH 2/2] add test --- test/noindex_tests.jl | 88 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 1 deletion(-) diff --git a/test/noindex_tests.jl b/test/noindex_tests.jl index b98512a2..31223ddb 100644 --- a/test/noindex_tests.jl +++ b/test/noindex_tests.jl @@ -1,4 +1,5 @@ -using StochasticDiffEq, Test, Random, DiffEqNoiseProcess +using StochasticDiffEq, Test, Random, DiffEqNoiseProcess, + RecursiveArrayTools, LinearAlgebra Random.seed!(100) struct NoIndexArray{T, N} <: AbstractArray{T, N} @@ -50,3 +51,88 @@ for alg in algs @test_nowarn sol(0.1) @test_nowarn sol(similar(prob.u0), 0.1) end + + +struct CustomArray{T, N} + x::Array{T, N} +end +Base.size(x::CustomArray) = size(x.x) +Base.axes(x::CustomArray) = axes(x.x) +Base.ndims(x::CustomArray) = ndims(x.x) +Base.ndims(::Type{<:CustomArray{T,N}}) where {T,N} = N +Base.zero(x::CustomArray) = CustomArray(zero(x.x)) +Base.zero(::Type{<:CustomArray{T,N}}) where {T,N} = CustomArray(zero(Array{T,N})) +Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...) = CustomArray(similar(x.x, dims...)) +Base.copyto!(x::CustomArray, y::CustomArray) = CustomArray(copyto!(x.x, y.x)) +Base.copy(x::CustomArray) = CustomArray(copy(x.x)) +Base.length(x::CustomArray) = length(x.x) +Base.isempty(x::CustomArray) = isempty(x.x) +Base.eltype(x::CustomArray) = eltype(x.x) +Base.zero(x::CustomArray) = CustomArray(zero(x.x)) +Base.fill!(x::CustomArray, y) = CustomArray(fill!(x.x, y)) +Base.getindex(x::CustomArray, i) = getindex(x.x, i) +Base.setindex!(x::CustomArray, v, idx) = setindex!(x.x, v, idx) +Base.mapreduce(f, op, x::CustomArray; kwargs...) = mapreduce(f, op, x.x; kwargs...) +Base.any(f::Function, x::CustomArray; kwargs...) = any(f, x.x; kwargs...) +Base.all(f::Function, x::CustomArray; kwargs...) = all(f, x.x; kwargs...) +Base.similar(x::CustomArray, t) = CustomArray(similar(x.x, t)) +Base.:(==)(x::CustomArray, y::CustomArray) = x.x == y.x +Base.:(*)(x::Number, y::CustomArray) = CustomArray(x*y.x) +Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x/y) +LinearAlgebra.norm(x::CustomArray) = norm(x.x) + +struct CustomStyle{N} <: Broadcast.BroadcastStyle where {N} end +CustomStyle(::Val{N}) where N = CustomStyle{N}() +CustomStyle{M}(::Val{N}) where {N,M} = NoIndexStyle{N}() +Base.BroadcastStyle(::Type{<:CustomArray{T,N}}) where {T,N} = CustomStyle{N}() +Broadcast.BroadcastStyle(::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N} = CustomStyle{N}() +Base.similar(bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType}) where {N, ElType} = CustomArray(similar(Array{ElType, N}, axes(bc))) +Base.Broadcast._broadcast_getindex(x::CustomArray, i) = x.x[i] +Base.Broadcast.extrude(x::CustomArray) = x +Base.Broadcast.broadcastable(x::CustomArray) = x + +@inline function Base.copyto!(dest::CustomArray, bc::Base.Broadcast.Broadcasted{<:CustomStyle}) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.x + @simd for I in 1:length(dest′) + @inbounds dest′[I] = bc′[I] + end + return dest +end +@inline function Base.copy(bc::Base.Broadcast.Broadcasted{<:CustomStyle}) + bcf = Broadcast.flatten(bc) + x = find_x(bcf) + data = zeros(eltype(x), size(x)) + @inbounds @simd for I in 1:length(x) + data[I] = bcf[I] + end + return CustomArray(data) +end +find_x(bc::Broadcast.Broadcasted) = find_x(bc.args) +find_x(args::Tuple) = find_x(find_x(args[1]), Base.tail(args)) +find_x(x) = x +find_x(::Any, rest) = find_x(rest) +find_x(x::CustomArray, rest) = x.x + +RecursiveArrayTools.recursive_unitless_bottom_eltype(x::CustomArray) = eltype(x) +RecursiveArrayTools.recursivecopy!(dest::CustomArray, src::CustomArray) = copyto!(dest, src) +RecursiveArrayTools.recursivecopy(x::CustomArray) = copy(x) +RecursiveArrayTools.recursivefill!(x::CustomArray, a) = fill!(x, a) +DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,rand_vec::CustomArray) = randn!(rng,rand_vec.x) + +Base.show_vector(io::IO, x::CustomArray) = Base.show_vector(io, x.x) + +Base.show(io::IO, x::CustomArray) = (print(io, "CustomArray");show(io, x.x)) +function Base.show(io::IO, ::MIME"text/plain", x::CustomArray) + println(io, Base.summary(x), ":") + Base.print_array(io, x.x) +end + +prob = SDEProblem((du, u, p, t)->copyto!(du, u),(du, u, p, t)->copyto!(du, u), CustomArray(ones(10)), (0.0,1.0)) + +for alg in algs + sol_ca = @test_nowarn solve(prob, alg) + @test_nowarn sol_ca(0.1) + @test_nowarn sol_ca(similar(prob.u0), 0.1) + end \ No newline at end of file