Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add non-default broadcasting of ExtendedJumpArray's. #340

Merged
merged 3 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"

[compat]
ArrayInterface = "6, 7"
DataStructures = "0.17, 0.18"
Expand All @@ -39,6 +42,9 @@ TreeViews = "0.3"
UnPack = "1.0.2"
julia = "1.6"

[extensions]
JumpProcessFastBroadcastExt = "FastBroadcast"

[extras]
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -48,6 +54,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"

[targets]
test = ["DiffEqCallbacks", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"]
test = ["DiffEqCallbacks", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]
27 changes: 27 additions & 0 deletions benchmarks/extended_jump_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# This file is not directly included in a test case, but is used to
# benchmark and compare changes to the broadcsting
using JumpProcesses, StableRNGs, FastBroadcast, BenchmarkTools

rng = StableRNG(123)

base_case_out = zeros(500000 * 2)
base_case_in = rand(rng, 500000 * 2)
benchmark_out = ExtendedJumpArray(zeros(500000), zeros(500000))
benchmark_in = ExtendedJumpArray(rand(rng, 500000), rand(rng, 500000))

function test_single_dot(out, array)
@inbounds @. out = array + 1.23 * array
end

function test_double_dot(out, array)
@inbounds @.. out = array + 1.23 * array
end

println("Base-case normal broadcasting")
@benchmark test_single_dot(base_case_out, base_case_in)
println("EJA normal broadcasting")
@benchmark test_single_dot(benchmark_out, benchmark_in)
println("Base-case fast broadcasting")
@benchmark test_double_dot(base_case_out, base_case_in)
println("EJA fast broadcasting")
@benchmark test_double_dot(benchmark_out, benchmark_in)
33 changes: 33 additions & 0 deletions ext/JumpProcessFastBroadcastExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module JumpProcessFastBroadcastExt

using JumpProcesses, FastBroadcast

@inline function FastBroadcast.fast_materialize!(::FastBroadcast.False, ::DB, dst::EJA,
bc::Base.Broadcast.Broadcasted{S}) where {
S,
DB,
EJA <:
ExtendedJumpArray
}
FastBroadcast.fast_materialize!(FastBroadcast.False(), DB(), dst.u,
JumpProcesses.repack(bc, Val(:u)))
FastBroadcast.fast_materialize!(FastBroadcast.False(), DB(), dst.jump_u,
JumpProcesses.repack(bc, Val(:jump_u)))
dst
end

@inline function FastBroadcast.fast_materialize!(::FastBroadcast.True, ::DB, dst::EJA,
bc::Base.Broadcast.Broadcasted{S}) where {
S,
DB,
EJA <:
ExtendedJumpArray
}
FastBroadcast.fast_materialize!(FastBroadcast.True(), DB(), dst.u,
JumpProcesses.repack(bc, Val(:u)))
FastBroadcast.fast_materialize!(FastBroadcast.True(), DB(), dst.jump_u,
JumpProcesses.repack(bc, Val(:jump_u)))
dst
end

end # module JumpProcessFastBroadcastExt
137 changes: 75 additions & 62 deletions src/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,82 +119,95 @@ plot_indices(A::ExtendedJumpArray) = eachindex(A.u)

## broadcasting

struct ExtendedJumpArrayStyle{Style <: Broadcast.BroadcastStyle} <:
Broadcast.AbstractArrayStyle{Any} end
ExtendedJumpArrayStyle(::S) where {S} = ExtendedJumpArrayStyle{S}()
ExtendedJumpArrayStyle(::S, ::Val{N}) where {S, N} = ExtendedJumpArrayStyle(S(Val(N)))
function ExtendedJumpArrayStyle(::Val{N}) where {N}
ExtendedJumpArrayStyle{Broadcast.DefaultArrayStyle{N}}()
end

# promotion rules
@inline function Broadcast.BroadcastStyle(::ExtendedJumpArrayStyle{AStyle},
::ExtendedJumpArrayStyle{BStyle}) where {AStyle,
BStyle}
ExtendedJumpArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
end
function Broadcast.BroadcastStyle(::ExtendedJumpArrayStyle{Style},
::Broadcast.DefaultArrayStyle{0}) where {
Style <:
Broadcast.BroadcastStyle
}
ExtendedJumpArrayStyle{Style}()
end
function Broadcast.BroadcastStyle(::ExtendedJumpArrayStyle,
::Broadcast.DefaultArrayStyle{N}) where {N}
# The jump array styles stores two sub-styles in the type,
# one for the `u` array and one for the `jump_u` array
struct ExtendedJumpArrayStyle{UStyle <: Broadcast.BroadcastStyle,
JumpUStyle <: Broadcast.BroadcastStyle} <:
Broadcast.BroadcastStyle end
# Init style based on type of u/jump_u
function ExtendedJumpArrayStyle(::US, ::JumpUS) where {US, JumpUS}
ExtendedJumpArrayStyle{US, JumpUS}()
end
function Base.BroadcastStyle(::Type{ExtendedJumpArray{T3, T1, UType, JumpUType}}) where {T3,
T1,
UType,
JumpUType
}
ExtendedJumpArrayStyle(Base.BroadcastStyle(UType), Base.BroadcastStyle(JumpUType))
end

# Combine with other styles by combining individually with u/jump_u styles
function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle},
::Style) where {UStyle, JumpUStyle,
Style <: Base.Broadcast.BroadcastStyle}
ExtendedJumpArrayStyle(Broadcast.result_style(UStyle(), Style()),
Broadcast.result_style(JumpUStyle(), Style()))
end

# Decay back to the DefaultArrayStyle for higher-order default styles, to support adding to raw vectors as needed
function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle},
::Broadcast.DefaultArrayStyle{0}) where {UStyle, JumpUStyle}
ExtendedJumpArrayStyle(UStyle(), JumpUStyle())
end

function Base.BroadcastStyle(::ExtendedJumpArrayStyle{UStyle, JumpUStyle},
::Broadcast.DefaultArrayStyle{N}) where {N, UStyle, JumpUStyle}
Broadcast.DefaultArrayStyle{N}()
end

combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
@inline function combine_styles(args::Tuple{Any})
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
# Lookup the first ExtendedJumpArray to pick output container size
"`A = find_eja(args)` returns the first ExtendedJumpArray among the arguments."
find_eja(bc::Base.Broadcast.Broadcasted) = find_eja(bc.args)
find_eja(args::Tuple) = find_eja(find_eja(args[1]), Base.tail(args))
find_eja(x) = x
find_eja(::Tuple{}) = nothing
find_eja(a::ExtendedJumpArray, rest) = a
find_eja(::Any, rest) = find_eja(rest)

function Base.similar(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}},
::Type{ElType}) where {US, JumpUS, ElType}
A = find_eja(bc)
ExtendedJumpArray(similar(A.u, ElType), similar(A.jump_u, ElType))
end
@inline function combine_styles(args::Tuple{Any, Any})
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
Broadcast.BroadcastStyle(args[2]))

# Helper functions that repack broadcasted functions
@inline function repack(bc::Broadcast.Broadcasted{Style}, i) where {Style}
Broadcast.Broadcasted{Style}(bc.f, repack_args(i, bc.args))
end
@inline function combine_styles(args::Tuple)
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
combine_styles(Base.tail(args)))
@inline function repack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}},
i::Val{:u}) where {US, JumpUS}
Broadcast.Broadcasted{US}(bc.f, repack_args(i, bc.args))
end

function Broadcast.BroadcastStyle(::Type{ExtendedJumpArray{T, S}}) where {T, S}
ExtendedJumpArrayStyle(Broadcast.result_style(Broadcast.BroadcastStyle(T)))
@inline function repack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}},
i::Val{:jump_u}) where {US, JumpUS}
Broadcast.Broadcasted{JumpUS}(bc.f, repack_args(i, bc.args))
end

@inline function Base.copy(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{Style}}) where {
Style
}
ExtendedJumpArray(copy(unpack(bc, Val(:u))), copy(unpack(bc, Val(:jump_u))))
# Helper functions that repack arguments
@inline repack(x, ::Any) = x
@inline repack(x::ExtendedJumpArray, ::Val{:u}) = x.u
@inline repack(x::ExtendedJumpArray, ::Val{:jump_u}) = x.jump_u

# Repack args with generated function to do this in a type-stable way without recursion
@generated function repack_args(extract_symbol, args::NTuple{N, Any}) where {N}
# Extract over the arg tuple
extracted_args = [:(repack(args[$i], extract_symbol)) for i in 1:N]
# Splat extracted args to another args tuple
return quote
($(extracted_args...),)
end
end

@inline function Base.copyto!(dest::ExtendedJumpArray,
bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{Style}}) where {
Style
}
copyto!(dest.u, unpack(bc, Val(:u)))
copyto!(dest.jump_u, unpack(bc, Val(:jump_u)))
bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{US, JumpUS}}) where {
US,
JumpUS
}
copyto!(dest.u, repack(bc, Val(:u)))
copyto!(dest.jump_u, repack(bc, Val(:jump_u)))
dest
end

# drop axes because it is easier to recompute
@inline function unpack(bc::Broadcast.Broadcasted{Style}, i) where {Style}
Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
end
@inline function unpack(bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle{Style}},
i) where {Style}
Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
end
unpack(x, ::Any) = x
unpack(x::ExtendedJumpArray, ::Val{:u}) = x.u
unpack(x::ExtendedJumpArray, ::Val{:jump_u}) = x.jump_u

@inline function unpack_args(i, args::Tuple)
(unpack(args[1], i), unpack_args(i, Base.tail(args))...)
end
unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
unpack_args(::Any, args::Tuple{}) = ()

Base.:*(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(y .* x.u, y .* x.jump_u)
Base.:*(y::Number, x::ExtendedJumpArray) = ExtendedJumpArray(y .* x.u, y .* x.jump_u)
Base.:/(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(x.u ./ y, x.jump_u ./ y)
Expand Down
61 changes: 60 additions & 1 deletion test/extended_jump_array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, JumpProcesses, DiffEqBase
using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq
using StableRNGs

rng = StableRNG(123)
Expand All @@ -24,3 +24,62 @@ new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0)
norm_check_alloc(jump_array, t) = @allocated DiffEqBase.ODE_DEFAULT_NORM(jump_array, t)
norm_check_alloc(rand_array, 0.0);
@test 0 == norm_check_alloc(rand_array, 0.0)

## Broadcasting
bc_eja_1 = ExtendedJumpArray(rand(rng, 10), rand(rng, 2))
bc_eja_2 = ExtendedJumpArray(rand(rng, 10), rand(rng, 2))
bc_out = ExtendedJumpArray(zeros(10), zeros(2))

# Test that broadcasting gives the same output as non-broadcasted math
@test bc_eja_1 + bc_eja_2 ≈ bc_eja_1 .+ bc_eja_2
@test 3.14 * bc_eja_1 + 2.7 * bc_eja_2 ≈ 3.14 .* bc_eja_1 .+ 2.7 .* bc_eja_2

# Test that non-allocating (copyto!) gives the same result, both w/ and w/o the dot macro
bc_out .= 3.14 .* bc_eja_1 + 2.7 .* bc_eja_2
@test bc_out ≈ 3.14 * bc_eja_1 + 2.7 * bc_eja_2
@. bc_out = 3.14 * bc_eja_1 + 2.7 * bc_eja_2
@test bc_out ≈ 3.14 * bc_eja_1 + 2.7 * bc_eja_2

# Test that mismatched arrays cannot be broadcasted
bc_mismatch = ExtendedJumpArray(rand(rng, 8), rand(rng, 4))
@test_throws DimensionMismatch bc_mismatch+bc_eja_1
@test_throws DimensionMismatch bc_mismatch.+bc_eja_1

# Test that datatype mixing persists through broadcasting
bc_dtype_1 = ExtendedJumpArray(rand(rng, 10), rand(rng, 1:10, 2))
bc_dtype_2 = ExtendedJumpArray(rand(rng, 10), rand(rng, 1:10, 2))
result = bc_dtype_1 + bc_dtype_2 * 2
@test eltype(result.jump_u) == Int64
out_result = ExtendedJumpArray(zeros(10), zeros(2))
out_result .= bc_dtype_1 .+ bc_dtype_2 .* 2
@test eltype(result.jump_u) == Int64
@test out_result ≈ result

# Test that fast broadcasting also gives the correct results
using FastBroadcast
@.. bc_out = 3.14 * bc_eja_1 + 2.7 * bc_eja_2
@test bc_out ≈ 3.14 * bc_eja_1 + 2.7 * bc_eja_2

# Test both the in-place and allocating problems (https://github.com/SciML/JumpProcesses.jl/issues/321)
# to check that an ExtendedJumpArray is not getting downgraded into a Vector
oop_test_rate(u, p, t) = exp(t)
function oop_test_affect!(integrator)
integrator.u[1] += 1
nothing
end
oop_test_jump = VariableRateJump(oop_test_rate, oop_test_affect!)

# Test in-place
u₀ = [0.0]
inplace_prob = ODEProblem((du, u, p, t) -> (du .= 0), u₀, (0.0, 2.0), nothing)
jump_prob = JumpProblem(inplace_prob, Direct(), oop_test_jump)
sol = solve(jump_prob, Tsit5())
@test sol.retcode == ReturnCode.Success
sol.u

# Test out-of-place
u₀ = [0.0]
oop_prob = ODEProblem((u, p, t) -> [0.0], u₀, (0.0, 2.0), nothing) # only difference is use of OOP ode function
jump_prob = JumpProblem(oop_prob, Direct(), oop_test_jump)
sol = solve(jump_prob, Tsit5())
@test sol.retcode == ReturnCode.Success