diff --git a/Project.toml b/Project.toml index 47666254..17f1ef19 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +[weakdeps] +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" + [compat] ArrayInterface = "6, 7" DataStructures = "0.17, 0.18" @@ -39,6 +42,9 @@ TreeViews = "0.3" UnPack = "1.0.2" julia = "1.6" +[extensions] +JumpProcessFastBroadcastExt = "FastBroadcast" + [extras] DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -48,6 +54,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" [targets] -test = ["DiffEqCallbacks", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] +test = ["DiffEqCallbacks", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"] diff --git a/benchmarks/extended_jump_array.jl b/benchmarks/extended_jump_array.jl new file mode 100644 index 00000000..0de459a3 --- /dev/null +++ b/benchmarks/extended_jump_array.jl @@ -0,0 +1,27 @@ +# This file is not directly included in a test case, but is used to +# benchmark and compare changes to the broadcsting +using JumpProcesses, StableRNGs, FastBroadcast, BenchmarkTools + +rng = StableRNG(123) + +base_case_out = zeros(500000 * 2) +base_case_in = rand(rng, 500000 * 2) +benchmark_out = ExtendedJumpArray(zeros(500000), zeros(500000)) +benchmark_in = ExtendedJumpArray(rand(rng, 500000), rand(rng, 500000)) + +function test_single_dot(out, array) + @inbounds @. out = array + 1.23 * array +end + +function test_double_dot(out, array) + @inbounds @.. out = array + 1.23 * array +end + +println("Base-case normal broadcasting") +@benchmark test_single_dot(base_case_out, base_case_in) +println("EJA normal broadcasting") +@benchmark test_single_dot(benchmark_out, benchmark_in) +println("Base-case fast broadcasting") +@benchmark test_double_dot(base_case_out, base_case_in) +println("EJA fast broadcasting") +@benchmark test_double_dot(benchmark_out, benchmark_in) diff --git a/ext/JumpProcessFastBroadcastExt.jl b/ext/JumpProcessFastBroadcastExt.jl new file mode 100644 index 00000000..bdec9167 --- /dev/null +++ b/ext/JumpProcessFastBroadcastExt.jl @@ -0,0 +1,33 @@ +module JumpProcessFastBroadcastExt + +using JumpProcesses, FastBroadcast + +@inline function FastBroadcast.fast_materialize!(::FastBroadcast.False, ::DB, dst::EJA, + bc::Base.Broadcast.Broadcasted{S}) where { + S, + DB, + EJA <: + ExtendedJumpArray + } + FastBroadcast.fast_materialize!(FastBroadcast.False(), DB(), dst.u, + JumpProcesses.repack(bc, Val(:u))) + FastBroadcast.fast_materialize!(FastBroadcast.False(), DB(), dst.jump_u, + JumpProcesses.repack(bc, Val(:jump_u))) + dst +end + +@inline function FastBroadcast.fast_materialize!(::FastBroadcast.True, ::DB, dst::EJA, + bc::Base.Broadcast.Broadcasted{S}) where { + S, + DB, + EJA <: + ExtendedJumpArray + } + FastBroadcast.fast_materialize!(FastBroadcast.True(), DB(), dst.u, + JumpProcesses.repack(bc, Val(:u))) + FastBroadcast.fast_materialize!(FastBroadcast.True(), DB(), dst.jump_u, + JumpProcesses.repack(bc, Val(:jump_u))) + dst +end + +end # module JumpProcessFastBroadcastExt diff --git a/src/extended_jump_array.jl b/src/extended_jump_array.jl index c7c35a87..df7b9218 100644 --- a/src/extended_jump_array.jl +++ b/src/extended_jump_array.jl @@ -119,82 +119,95 @@ plot_indices(A::ExtendedJumpArray) = eachindex(A.u) ## broadcasting -struct ExtendedJumpArrayStyle{Style <: Broadcast.BroadcastStyle} <: - Broadcast.AbstractArrayStyle{Any} end -ExtendedJumpArrayStyle(::S) where {S} = ExtendedJumpArrayStyle{S}() -ExtendedJumpArrayStyle(::S, ::Val{N}) where {S, N} = ExtendedJumpArrayStyle(S(Val(N))) -function ExtendedJumpArrayStyle(::Val{N}) where {N} - ExtendedJumpArrayStyle{Broadcast.DefaultArrayStyle{N}}() -end - -# promotion rules -@inline function Broadcast.BroadcastStyle(::ExtendedJumpArrayStyle{AStyle}, - ::ExtendedJumpArrayStyle{BStyle}) where {AStyle, - BStyle} - ExtendedJumpArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle())) -end -function Broadcast.BroadcastStyle(::ExtendedJumpArrayStyle{Style}, - ::Broadcast.DefaultArrayStyle{0}) where { - Style <: - Broadcast.BroadcastStyle - } - ExtendedJumpArrayStyle{Style}() -end -function Broadcast.BroadcastStyle(::ExtendedJumpArrayStyle, - ::Broadcast.DefaultArrayStyle{N}) where {N} +# The jump array styles stores two sub-styles in the type, +# one for the `u` array and one for the `jump_u` array +struct ExtendedJumpArrayStyle{UStyle <: Broadcast.BroadcastStyle, + JumpUStyle <: Broadcast.BroadcastStyle} <: + Broadcast.BroadcastStyle end +# Init style based on type of u/jump_u +function ExtendedJumpArrayStyle(::US, ::JumpUS) where {US, JumpUS} + ExtendedJumpArrayStyle{US, JumpUS}() +end +function Base.BroadcastStyle(::Type{ExtendedJumpArray{T3, T1, UType, JumpUType}}) where {T3, + T1, + UType, + JumpUType + } + ExtendedJumpArrayStyle(Base.BroadcastStyle(UType), Base.BroadcastStyle(JumpUType)) +end + +# Combine with other styles by combining individually with u/jump_u styles +function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + ::Style) where {UStyle, JumpUStyle, + Style <: Base.Broadcast.BroadcastStyle} + ExtendedJumpArrayStyle(Broadcast.result_style(UStyle(), Style()), + Broadcast.result_style(JumpUStyle(), Style())) +end + +# Decay back to the DefaultArrayStyle for higher-order default styles, to support adding to raw vectors as needed +function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + ::Broadcast.DefaultArrayStyle{0}) where {UStyle, JumpUStyle} + ExtendedJumpArrayStyle(UStyle(), JumpUStyle()) +end + +function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle}, + ::Broadcast.DefaultArrayStyle{N}) where {N, UStyle, JumpUStyle} Broadcast.DefaultArrayStyle{N}() end -combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}() -@inline function combine_styles(args::Tuple{Any}) - Broadcast.result_style(Broadcast.BroadcastStyle(args[1])) +# Lookup the first ExtendedJumpArray to pick output container size +"`A = find_eja(args)` returns the first ExtendedJumpArray among the arguments." +find_eja(bc::Base.Broadcast.Broadcasted) = find_eja(bc.args) +find_eja(args::Tuple) = find_eja(find_eja(args[1]), Base.tail(args)) +find_eja(x) = x +find_eja(::Tuple{}) = nothing +find_eja(a::ExtendedJumpArray, rest) = a +find_eja(::Any, rest) = find_eja(rest) + +function Base.similar(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, + ::Type{ElType}) where {US, JumpUS, ElType} + A = find_eja(bc) + ExtendedJumpArray(similar(A.u, ElType), similar(A.jump_u, ElType)) end -@inline function combine_styles(args::Tuple{Any, Any}) - Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), - Broadcast.BroadcastStyle(args[2])) + +# Helper functions that repack broadcasted functions +@inline function repack(bc::Broadcast.Broadcasted{Style}, i) where {Style} + Broadcast.Broadcasted{Style}(bc.f, repack_args(i, bc.args)) end -@inline function combine_styles(args::Tuple) - Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), - combine_styles(Base.tail(args))) +@inline function repack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, + i::Val{:u}) where {US, JumpUS} + Broadcast.Broadcasted{US}(bc.f, repack_args(i, bc.args)) end - -function Broadcast.BroadcastStyle(::Type{ExtendedJumpArray{T, S}}) where {T, S} - ExtendedJumpArrayStyle(Broadcast.result_style(Broadcast.BroadcastStyle(T))) +@inline function repack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}, + i::Val{:jump_u}) where {US, JumpUS} + Broadcast.Broadcasted{JumpUS}(bc.f, repack_args(i, bc.args)) end -@inline function Base.copy(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{Style}}) where { - Style - } - ExtendedJumpArray(copy(unpack(bc, Val(:u))), copy(unpack(bc, Val(:jump_u)))) +# Helper functions that repack arguments +@inline repack(x, ::Any) = x +@inline repack(x::ExtendedJumpArray, ::Val{:u}) = x.u +@inline repack(x::ExtendedJumpArray, ::Val{:jump_u}) = x.jump_u + +# Repack args with generated function to do this in a type-stable way without recursion +@generated function repack_args(extract_symbol, args::NTuple{N, Any}) where {N} + # Extract over the arg tuple + extracted_args = [:(repack(args[$i], extract_symbol)) for i in 1:N] + # Splat extracted args to another args tuple + return quote + ($(extracted_args...),) + end end @inline function Base.copyto!(dest::ExtendedJumpArray, - bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{Style}}) where { - Style - } - copyto!(dest.u, unpack(bc, Val(:u))) - copyto!(dest.jump_u, unpack(bc, Val(:jump_u))) + bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}) where { + US, + JumpUS + } + copyto!(dest.u, repack(bc, Val(:u))) + copyto!(dest.jump_u, repack(bc, Val(:jump_u))) dest end -# drop axes because it is easier to recompute -@inline function unpack(bc::Broadcast.Broadcasted{Style}, i) where {Style} - Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) -end -@inline function unpack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{Style}}, - i) where {Style} - Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args)) -end -unpack(x, ::Any) = x -unpack(x::ExtendedJumpArray, ::Val{:u}) = x.u -unpack(x::ExtendedJumpArray, ::Val{:jump_u}) = x.jump_u - -@inline function unpack_args(i, args::Tuple) - (unpack(args[1], i), unpack_args(i, Base.tail(args))...) -end -unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),) -unpack_args(::Any, args::Tuple{}) = () - Base.:*(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(y .* x.u, y .* x.jump_u) Base.:*(y::Number, x::ExtendedJumpArray) = ExtendedJumpArray(y .* x.u, y .* x.jump_u) Base.:/(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(x.u ./ y, x.jump_u ./ y) diff --git a/test/extended_jump_array.jl b/test/extended_jump_array.jl index 065af9ce..53f647b4 100644 --- a/test/extended_jump_array.jl +++ b/test/extended_jump_array.jl @@ -1,4 +1,4 @@ -using Test, JumpProcesses, DiffEqBase +using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq using StableRNGs rng = StableRNG(123) @@ -24,3 +24,62 @@ new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0) norm_check_alloc(jump_array, t) = @allocated DiffEqBase.ODE_DEFAULT_NORM(jump_array, t) norm_check_alloc(rand_array, 0.0); @test 0 == norm_check_alloc(rand_array, 0.0) + +## Broadcasting +bc_eja_1 = ExtendedJumpArray(rand(rng, 10), rand(rng, 2)) +bc_eja_2 = ExtendedJumpArray(rand(rng, 10), rand(rng, 2)) +bc_out = ExtendedJumpArray(zeros(10), zeros(2)) + +# Test that broadcasting gives the same output as non-broadcasted math +@test bc_eja_1 + bc_eja_2 ≈ bc_eja_1 .+ bc_eja_2 +@test 3.14 * bc_eja_1 + 2.7 * bc_eja_2 ≈ 3.14 .* bc_eja_1 .+ 2.7 .* bc_eja_2 + +# Test that non-allocating (copyto!) gives the same result, both w/ and w/o the dot macro +bc_out .= 3.14 .* bc_eja_1 + 2.7 .* bc_eja_2 +@test bc_out ≈ 3.14 * bc_eja_1 + 2.7 * bc_eja_2 +@. bc_out = 3.14 * bc_eja_1 + 2.7 * bc_eja_2 +@test bc_out ≈ 3.14 * bc_eja_1 + 2.7 * bc_eja_2 + +# Test that mismatched arrays cannot be broadcasted +bc_mismatch = ExtendedJumpArray(rand(rng, 8), rand(rng, 4)) +@test_throws DimensionMismatch bc_mismatch+bc_eja_1 +@test_throws DimensionMismatch bc_mismatch.+bc_eja_1 + +# Test that datatype mixing persists through broadcasting +bc_dtype_1 = ExtendedJumpArray(rand(rng, 10), rand(rng, 1:10, 2)) +bc_dtype_2 = ExtendedJumpArray(rand(rng, 10), rand(rng, 1:10, 2)) +result = bc_dtype_1 + bc_dtype_2 * 2 +@test eltype(result.jump_u) == Int64 +out_result = ExtendedJumpArray(zeros(10), zeros(2)) +out_result .= bc_dtype_1 .+ bc_dtype_2 .* 2 +@test eltype(result.jump_u) == Int64 +@test out_result ≈ result + +# Test that fast broadcasting also gives the correct results +using FastBroadcast +@.. bc_out = 3.14 * bc_eja_1 + 2.7 * bc_eja_2 +@test bc_out ≈ 3.14 * bc_eja_1 + 2.7 * bc_eja_2 + +# Test both the in-place and allocating problems (https://github.com/SciML/JumpProcesses.jl/issues/321) +# to check that an ExtendedJumpArray is not getting downgraded into a Vector +oop_test_rate(u, p, t) = exp(t) +function oop_test_affect!(integrator) + integrator.u[1] += 1 + nothing +end +oop_test_jump = VariableRateJump(oop_test_rate, oop_test_affect!) + +# Test in-place +u₀ = [0.0] +inplace_prob = ODEProblem((du, u, p, t) -> (du .= 0), u₀, (0.0, 2.0), nothing) +jump_prob = JumpProblem(inplace_prob, Direct(), oop_test_jump) +sol = solve(jump_prob, Tsit5()) +@test sol.retcode == ReturnCode.Success +sol.u + +# Test out-of-place +u₀ = [0.0] +oop_prob = ODEProblem((u, p, t) -> [0.0], u₀, (0.0, 2.0), nothing) # only difference is use of OOP ode function +jump_prob = JumpProblem(oop_prob, Direct(), oop_test_jump) +sol = solve(jump_prob, Tsit5()) +@test sol.retcode == ReturnCode.Success