Use seed! to put every copy of rng into a unique state
Using `rand(_rng, i)` didn't really put the copies of `rng` into unique
states: the states were still interlocked (all the generators produced
the same sequence of random numbers, each at a different offset). Calling
`seed!` with a deterministic, pseudo-random seed for each tree produces
much better results, which is also visible in the classification and
regression accuracies produced by the tests.
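
The interlocking described above can be reproduced in isolation. The following is a minimal sketch, not code from this commit: `offset_stream` and `seeded_stream` are hypothetical helpers, the seed `42` is arbitrary, and the old behaviour is imitated with scalar `rand(_rng)` calls rather than the exact `rand(_rng, i)` call, on the assumption that each scalar draw advances the generator by one value.

```julia
using Random

# Hypothetical helper imitating the old approach: copy the RNG, discard
# `skip` scalar values, then draw `n` values from the copy.
function offset_stream(rng::Random.AbstractRNG, skip::Integer, n::Integer)
    _rng = copy(rng)
    for _ in 1:skip
        rand(_rng)  # discard one value at a time
    end
    return [rand(_rng) for _ in 1:n]
end

# Hypothetical helper imitating the new approach: copy the RNG and
# reseed the copy, so its state no longer depends on the original's.
function seeded_stream(rng::Random.AbstractRNG, seed::Integer, n::Integer)
    _rng = Random.seed!(copy(rng), seed)
    return [rand(_rng) for _ in 1:n]
end

rng = Random.MersenneTwister(42)  # arbitrary seed, for illustration only
shared_seed = rand(rng, UInt)

# Old approach: the stream for tree 2 is the stream for tree 1 shifted by one.
a = offset_stream(rng, 1, 5)
b = offset_stream(rng, 2, 5)
interlocked = a[2:end] == b[1:end-1]

# New approach: each tree gets its own deterministic seed.
c = seeded_stream(rng, shared_seed + 1, 5)
d = seeded_stream(rng, shared_seed + 2, 5)

println("offset streams interlocked: ", interlocked)
println("seeded streams identical:   ", c == d)
```

Because `shared_seed` is drawn once from `rng` before the loop, the per-tree seeds remain reproducible for a given initial `rng` state, regardless of which thread processes which tree.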
dhanak committed Nov 28, 2022
1 parent c199322 commit 10843eb
Showing 2 changed files with 4 additions and 8 deletions.
6 changes: 2 additions & 4 deletions src/classification/main.jl
@@ -370,12 +370,10 @@ function build_forest(
     loss = (ns, n) -> util.entropy(ns, n, entropy_terms)

     if rng isa Random.AbstractRNG
+        shared_seed = rand(rng, UInt)
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
-            _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree. This
-            # is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            _rng = Random.seed!(copy(rng), shared_seed + i)
             inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
6 changes: 2 additions & 4 deletions src/regression/main.jl
@@ -95,12 +95,10 @@ function build_forest(
     forest = impurity_importance ? Vector{Root{S, T}}(undef, n_trees) : Vector{LeafOrNode{S, T}}(undef, n_trees)

     if rng isa Random.AbstractRNG
+        shared_seed = rand(rng, UInt)
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
-            _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree.
-            # This is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            _rng = Random.seed!(copy(rng), shared_seed + i)
             inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
