Seeding ensemble problems #421

Open
kaandocal opened this issue May 14, 2024 · 8 comments
@kaandocal

I have been going through the code & documentation of JumpProcesses in order to find out how to set up reproducible (RNG seeded) ensemble runs. The following snippet yields the same output each time, as intended:

jprob::JumpProblem
eprob = EnsembleProblem(jprob)
    
solve(eprob, SSAStepper(), EnsembleSerial(); trajectories=10, seed=42)

If I replace EnsembleSerial by EnsembleThreads, the outputs stop being reproducible.

I am aware that JumpProblem seems to have an rng field, so one naive solution would be to use the prob_func argument to EnsembleProblem and remake the jump problem with a different seed for each trajectory. That is not directly possible since remake does not accept a seed or rng argument, so I currently recreate the JumpProblem from scratch (which is less than ideal).

There might be a solution to this online, but I haven't found anything, which seems surprising. Is there any better (intended) way of doing this?
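The workaround described above, rebuilding the JumpProblem inside prob_func with a per-trajectory rng, might look roughly like the following. This is only a sketch: the toy birth process, the helper `make_jprob`, and the `1000 + i` seed offset are hypothetical choices made for illustration, and the example sticks to EnsembleSerial because how a solver treats the stored rng under threading is not something this sketch can promise.

```julia
using JumpProcesses, Random

# Hypothetical toy model: a pure birth process with constant rate p[1].
rate(u, p, t) = p[1]
affect!(integrator) = (integrator.u[1] += 1; nothing)
jump = ConstantRateJump(rate, affect!)
dprob = DiscreteProblem([0], (0.0, 10.0), [2.0])

# Hypothetical helper: build a fresh JumpProblem whose rng is seeded explicitly.
make_jprob(seed) = JumpProblem(dprob, Direct(), jump; rng = Xoshiro(seed))

# prob_func ignores the template problem and rebuilds it with a seed
# derived from the trajectory index i (the offset 1000 is arbitrary).
eprob = EnsembleProblem(make_jprob(0);
    prob_func = (prob, i, repeat) -> make_jprob(1000 + i))

final_counts(sol) = [s.u[end][1] for s in sol.u]

sol1 = solve(eprob, SSAStepper(), EnsembleSerial(); trajectories = 10)
sol2 = solve(eprob, SSAStepper(), EnsembleSerial(); trajectories = 10)
```

Under these assumptions, `final_counts(sol1)` and `final_counts(sol2)` should agree, since every trajectory starts from an rng seeded only by its index. The cost is one full JumpProblem construction per trajectory.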

@isaacsas
Member

RNG management and seeding are a bit of a mess in general. Here we pass rngs via the problems and store them in the aggregators; StochasticDiffEq doesn't allow setting an rng in standard usage and still uses Xoroshiro128Plus instead of the newer native Julia rng; and as far as I am aware, EnsembleProblems have no mechanism for handling seeding and rngs based on the parallelism mode. Ideally, if one is threading a JumpProblem, one would use a thread-compatible rng and initialize it once per thread or once globally, depending on how it works. Each thread would then create only one JumpProblem, on which solve is called repeatedly for that thread's share of the simulations.

You can see what we do for SSAStepper here:

        alias_jump = Threads.threadid() == 1,
        saveat = nothing,
        callback = nothing,
        tstops = eltype(jump_prob.prob.tspan)[],
        numsteps_hint = 100)
    if !(jump_prob.prob isa DiscreteProblem)
        error("SSAStepper only supports DiscreteProblems.")
    end
    @assert isempty(jump_prob.jump_callback.continuous_callbacks)
    if alias_jump
        cb = jump_prob.jump_callback.discrete_callbacks[end]
        if seed !== nothing
            Random.seed!(cb.condition.rng, seed)
        end
    else
        cb = deepcopy(jump_prob.jump_callback.discrete_callbacks[end])
        if seed === nothing
            Random.seed!(cb.condition.rng, seed_multiplier() * rand(UInt64))
        else
            Random.seed!(cb.condition.rng, seed)
        end
    end

In serial, alias_jump == true, whereas with threads it is false (its default is Threads.threadid() == 1, so strictly only thread 1 aliases). Also note that with threading the whole aggregator is deep copied for every simulation on each thread, instead of just once per thread, which is probably quite inefficient.

Probably we are in need of some kind of uniform SciML-wide approach for handling and storing user-selected rngs that all these libraries could query and use in a common manner.

@isaacsas
Member

@ChrisRackauckas would creating a SciML rng interface that works across serial/parallelism modes make sense as a small grant project? Or is it too complicated to expect someone to tackle in that context?

@kaandocal
Author

Thanks for the explanation! For reproducibility purposes I think it would be nice to have a consistent RNG interface. Amusingly, I only needed to set seeds to debug some weird threading behaviour. Would passing an RNG to solve be an option?

Ideally the threading mode should not affect results, so EnsembleProblem should probably use a different RNG for each trajectory - although I admit this is slower than your proposed solution. I could look into that, but as you mentioned StochasticDiffEq.jl seems to have entirely different RNG handling, which might complicate things...
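The idea that a separate RNG per trajectory makes results independent of the threading mode can be illustrated without any SciML packages. The sketch below uses only the Random standard library; `simulate` is a hypothetical stand-in for one trajectory, and deriving the seed from `hash((base_seed, i))` is an assumption made for brevity, not a statement about the statistical quality of the resulting streams.

```julia
using Random

# Hypothetical stand-in for one stochastic trajectory:
# the sum of 100 exponential waiting times drawn from the given rng.
simulate(rng::AbstractRNG) = sum(randexp(rng) for _ in 1:100)

# Assumption of this sketch: one Xoshiro per trajectory,
# seeded from the pair (base_seed, trajectory index).
trajectory_rng(base_seed, i) = Xoshiro(hash((base_seed, i)))

# Serial loop over trajectories.
run_serial(base_seed, n) = [simulate(trajectory_rng(base_seed, i)) for i in 1:n]

# Threaded loop: each iteration builds its own rng, so the result
# does not depend on which thread happens to run trajectory i.
function run_threaded(base_seed, n)
    out = Vector{Float64}(undef, n)
    Threads.@threads for i in 1:n
        out[i] = simulate(trajectory_rng(base_seed, i))
    end
    return out
end
```

With this layout, `run_serial(42, 16)` and `run_threaded(42, 16)` return the same vector for any thread count, at the price of constructing and seeding one generator per trajectory.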

@isaacsas
Member

To make passing via solve work we would need all integrators in OrdinaryDiffEq and StochasticDiffEq to store the rng and have a standard way to get it from them (i.e. we could make it work for SSAStepper but not in a robust way that it would work for all solvers that support jumps).

@TorkelE
Member

TorkelE commented May 21, 2024

There is an issue for this over at DiffEq, however, it hasn't really gotten any traction: SciML/DifferentialEquations.jl#1034

@isaacsas
Member

A problem with that approach, i.e. user-generated per-thread/per-process seeds, is that it isn’t how all (many?) parallel rngs are designed to work. For example, some generators set up uncorrelated streams themselves from a single global seed plus the process/thread id. Having users select and pass per-thread seeds could actually result in correlated streams for some generators.

(Note, I don’t actually know the recommended way of using the default Julia generator when multi-threading as that is an uncommon workflow for me.)
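For the default Julia generator specifically, the Random standard library documents it as task-local since Julia 1.7: each spawned task receives its own state, derived from its parent task when the task is created. A minimal sketch of what that implies, assuming Julia 1.7 or later (`spawn_draws` is a hypothetical helper, not part of any library):

```julia
using Random

# Seed the current task's default RNG, then spawn n child tasks.
# Each child draws one number from its own task-local stream.
function spawn_draws(seed, n)
    Random.seed!(seed)
    tasks = map(_ -> Threads.@spawn(rand(UInt64)), 1:n)
    return fetch.(tasks)
end

a = spawn_draws(42, 8)
b = spawn_draws(42, 8)
```

Here `a == b` holds, because the child streams depend on the parent's seed and the order in which tasks are spawned, not on which thread runs them. This says nothing about generators stored inside a problem or aggregator, which are separate objects.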

@TorkelE
Member

TorkelE commented May 21, 2024

I think I am relatively agnostic to the actual implementation (as I don't know anything about it), but like @kaandocal I think there is a feature here that should be implemented somehow (possibly using new/old options).

@ChrisRackauckas
Member

> @ChrisRackauckas would creating a SciML rng interface that works across serial/parallelism modes make sense as a small grant project? Or is it too complicated to expect someone to tackle in that context?

Probably too complicated.

> To make passing via solve work we would need all integrators in OrdinaryDiffEq and StochasticDiffEq to store the rng and have a standard way to get it from them (i.e. we could make it work for SSAStepper but not in a robust way that it would work for all solvers that support jumps).

Yes, but you'd still want to do it on the prob in the prob_func, because otherwise the whole ensemble would have the same seed.
