Efficiency of ExtendedJumpArray broadcasting in ode_interpolant #335
Comments
I don't think we'd want to add ODE-method-dependent dispatches. That would be a nightmare to put together and maintain. Dispatches that are not ODE-method dependent would be fine (i.e. things that are used by all or many methods, like your last PR, so more dispatches on things in DiffEqBase, SciMLBase, or Julia's Base). But maybe there is an ExtendedJumpArray broadcast feature that just isn't implemented and, if dispatched, would fix the issue here? (Unfortunately I have no real familiarity with the broadcast API, so I can't help with suggestions. Maybe @ChrisRackauckas has some.)
https://github.com/YingboMa/FastBroadcast.jl/blob/master/src/FastBroadcast.jl#L19 Try overriding this to be false so it doesn't use the linear indexing?
Yes, it should just be one thing about broadcast lowering.
Overriding that to false doesn't seem to solve the problem; there's still a bunch of calls to […]
I've been thinking through the broadcast interface, and it feels like what we want is a broadcast style that ends up flattening the broadcast kernel, then applies it on both […]
I'll update if I make any progress. Right now my debugging frustratingly seems to show that on broadcast operations it's not using the defined […]
shows: […]
which is actually just wrong; this should not just return a Vector{Float64}. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only […] Running […]
This means it's hitting the fallback broadcast style, and when it's using the AbstractArray style it defaults to returning an […]
The thing to look at would probably be https://github.com/jonniedie/ComponentArrays.jl. It solves these kinds of problems quite nicely, so we may just want to lift some of its broadcast implementation. I've also considered completely removing the ExtendedJumpArray and just using a ComponentArray, though the issue is that it's somewhat magical to the user if they solve using an […]
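To make the fallback behaviour concrete, here is a minimal sketch using a hypothetical TwoPart type as a stand-in for ExtendedJumpArray (the type and its layout are assumptions for illustration, not the package's actual definition). With no broadcast style defined for the wrapper, Julia's default array style allocates the output with the generic similar, so the result comes back as a plain Vector:

```julia
# Hypothetical stand-in for ExtendedJumpArray (illustration only).
struct TwoPart{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end

Base.size(A::TwoPart) = (length(A.u) + length(A.jump_u),)
Base.getindex(A::TwoPart, i::Int) =
    i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]

x = TwoPart([1.0, 2.0], [3.0])

# No BroadcastStyle method exists for TwoPart, so this uses the default
# array style and the wrapper type is lost in the result.
y = x .+ x
println(typeof(y))
```

Keeping the wrapper type would presumably require defining a custom broadcast style plus a matching similar method for it, which is the part of the interface discussed in this thread.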
Thanks for the reference! It looks like ComponentArrays actually has very little broadcasting magic, but much more advanced indexing magic. I'll see if there is a straightforward way to get the ExtendedJumpArray to actually use the broadcasting code that has already been written; otherwise I'll see if there is some indexing magic to use.
I think the fact that ComponentArrays forces everything to have the same type, thus you cannot end up in the situation where […]
Yeah, I did test out getting broadcasting working and ran into something similar. The current broadcast code wasn't actually doing anything, but once I added two more interface functions in […]
I might try out the hacky thing first of just directly calling […]
I haven't figured out a solution to this so far, though I'd like to; even a simple adding two […]
Notes so far: […]
At the very least, I'm trying to address 1) and 2). There are currently correctness issues with the current fallback to […]
For 3), @ChrisRackauckas, do you think it's ok to add a dependency on […]
I did it, at least for normal broadcasting! I should have a PR in today. I rewrote the broadcast rules and slightly changed how broadcast repacking works. I'm working on the FastBroadcast overloads right now and confirming that they work as expected. For the benchmark, I compared against a plain linear vector. The old fallback mechanism used to be ~3-5x as slow, but now it is effectively the same. Checking with Cthulhu shows that efficient SIMD instructions are being emitted now.
I’d be fine with adding a dependency on FastBroadcast if that enables you to fix the issue. Please go ahead and add it if needed.
Background
Continuing to optimize a system with VariableRateJumps and callbacks, I've found that performing ODE interpolations on ExtendedJumpArrays is consuming around 80% of the total runtime of the solve. About 10% is actually in the various do_step backtraces, with most of the runtime in the find_first_continuous_callback call. Of that, the vast majority is happening in Base.getindex(A::ExtendedJumpArray).
Problem
In particular, it looks like the combination of ExtendedJumpArrays, @muladd, and FastBroadcast is causing suboptimal code generation. For my use case, the problematic function is here: […]
but it seems to affect others as well.
Examining this with Cthulhu and friends, looking at the LLVM and native code, etc etc shows an absolute morass of branching, with something like 4800 (!) branches in the assembly. The branches are coming from inlining where it's using the ExtendedJumpArray index check. This churn means that the compiler doesn't seem to be able to loop unroll, since it has to repeatedly check "is index < length(jump_array.u)", over and over and over.
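The per-element index check described above can be sketched as follows. TwoPart and combine_indexed! are hypothetical names used only for illustration (they are not the package's definitions), and the kernel is a simplified a + c * b rather than the real interpolant:

```julia
# Hypothetical stand-in for ExtendedJumpArray: two vectors behind one linear index.
struct TwoPart{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end

Base.size(A::TwoPart) = (length(A.u) + length(A.jump_u),)

# Every scalar read has to decide which member the index belongs to.
function Base.getindex(A::TwoPart, i::Int)
    n = length(A.u)
    return i <= n ? A.u[i] : A.jump_u[i - n]
end

# An elementwise kernel written against the combined index pays that
# comparison once per operand, per element.
function combine_indexed!(out::Vector, a::TwoPart, b::TwoPart, c)
    for i in eachindex(out)
        out[i] = a[i] + c * b[i]
    end
    return out
end

a = TwoPart([1.0, 2.0], [3.0])
b = TwoPart([10.0, 20.0], [30.0])
out = combine_indexed!(zeros(3), a, b, 0.5)
```

With several operands in a fused expression, each one contributes its own comparison inside the loop body, which is consistent with the large branch count reported here.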
Solution
I'm guessing that this would be a lot faster, and would not churn the branch predictor with a bunch of mostly useless (<=) calls to switch between the .u and .jump_u members, if we could somehow turn these calls into separate calls like […]
I fixed this by adding this dispatch to my user code (I unpacked one of the macros because getting the imports right was annoying, but it's just the original ode_interpolant with the .u and .jump_u parts unrolled): […]
This dramatically reduces the runtime, and inspecting the LLVM code shows "only" 123 branches with good loop unrolling. There's still a boatload of allocations happening in handle_callbacks, but I'll deal with that in a separate issue/PR.
Questions
[…] @.. macro to do this type of code generation automatically?
[…] OrdinaryDiffEq.
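As a rough illustration of the "separate calls" idea from the Solution section, the sketch below applies one fused kernel to each member in turn. TwoPart and combine_split! are hypothetical names and the kernel is a simplified a + c * b; this is not the dispatch from the issue, whose code was not preserved here:

```julia
# Hypothetical two-member container; the field names mirror .u and .jump_u.
struct TwoPart{T}
    u::Vector{T}
    jump_u::Vector{T}
end

# Apply the same fused kernel to each member separately, so each inner loop
# runs over plain Vectors with no per-element membership check.
function combine_split!(out::TwoPart, a::TwoPart, b::TwoPart, c)
    out.u .= a.u .+ c .* b.u
    out.jump_u .= a.jump_u .+ c .* b.jump_u
    return out
end

a = TwoPart([1.0, 2.0], [3.0])
b = TwoPart([10.0, 20.0], [30.0])
out = combine_split!(TwoPart(zeros(2), zeros(1)), a, b, 0.5)
```

The trade-off is that the kernel is written out twice. A broadcast style that performs this split generically would avoid the duplication, which is what the first question above is asking about.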