Use seed! to put every copy of rng into a unique state #198

Merged: 1 commit merged into JuliaAI:dev on Nov 29, 2022

Conversation

@dhanak (Contributor) commented Nov 28, 2022

Fixes #194.

Using `rand(_rng, i)` didn't really put all copies of `rng` into a unique state; the states were still interlocked (all the generators produced the same sequence of random numbers, just with some offset). Calling `seed!` with a deterministic, pseudo-random seed for each thread produces much better results, which is also visible in the classification and regression accuracies produced by the tests.
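The failure mode is language-agnostic, so here is a minimal sketch of it in Python rather than Julia (all names and the 64-bit seed width are illustrative, not taken from DecisionTree.jl): copies that share a starting state and are merely advanced by `i` draws emit the same stream with an offset, while copies reseeded with a distinct, state-derived seed do not.

```python
import random

base = random.Random(42)

# Flawed scheme: every copy starts from the same state and is merely
# advanced by i draws (akin to rand(_rng, i)) -- the streams interlock.
offset_streams = []
for i in range(3):
    r = random.Random()
    r.setstate(base.getstate())
    for _ in range(i):
        r.random()                      # advance copy i by i draws
    offset_streams.append([r.random() for _ in range(5)])

# Copy 1's stream is exactly copy 0's stream shifted by one draw:
assert offset_streams[1][:4] == offset_streams[0][1:5]

# Fix: reseed each copy with a distinct seed derived from the base
# generator's state (the shared_seed + i idea).
shared_seed = base.getrandbits(64)
seeded_streams = [
    [random.Random(shared_seed + i).random() for _ in range(5)]
    for i in range(3)
]
# The shifted-overlap pattern is gone:
assert seeded_streams[1][:4] != seeded_streams[0][1:5]
```

The same shifted-overlap check fails for any offset, which is what "interlocked" means here: the copies are distinct objects but not statistically independent streams.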

Here is the diff output of the unit tests before and after the change:

89c89
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.8998998998998999
91c91
< Mean Accuracy: 0.8448448448448449
---
> Mean Accuracy: 0.9019019019019018
93c93
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.8998998998998999
95c95
< Mean Accuracy: 0.8748748748748749
---
> Mean Accuracy: 0.8998998998998999
97c97
< Mean Accuracy: 0.8608608608608609
---
> Mean Accuracy: 0.9039039039039038
99c99
< Mean Accuracy: 0.8608608608608609
---
> Mean Accuracy: 0.9039039039039038
101c101
< Mean Accuracy: 0.8358358358358359
---
> Mean Accuracy: 0.908908908908909
103c103
< Mean Accuracy: 0.8358358358358359
---
> Mean Accuracy: 0.908908908908909
105c105
< Mean Accuracy: 0.8818818818818818
---
> Mean Accuracy: 0.8988988988988988
107c107
< Mean Accuracy: 0.8818818818818818
---
> Mean Accuracy: 0.8988988988988988
109c109
< Mean Accuracy: 0.8788788788788789
---
> Mean Accuracy: 0.9009009009009009
111c111
< Mean Accuracy: 0.8788788788788789
---
> Mean Accuracy: 0.9009009009009009
116,119c116,119
<  6    3    0  0
<  0  121   29  0
<  0   26  136  7
<  0    0    4  1
---
>  5    4    0  0
>  0  128   22  0
>  0    3  165  1
>  0    0    5  0
121,122c121,122
< Accuracy: 0.7927927927927928
< Kappa:    0.6153446948136739
---
> Accuracy: 0.8948948948948949
> Kappa:    0.7995390516158992
127,130c127,130
<  9    1    0  0
<  5  123   15  0
<  0   16  156  2
<  0    0    6  0
---
>  5    5    0  0
>  0  132   11  0
>  0   12  162  0
>  0    0    4  2
132,133c132,133
< Accuracy: 0.8648648648648649
< Kappa:    0.7499123817153157
---
> Accuracy: 0.9039039039039038
> Kappa:    0.818534791049351
138,141c138,141
<  7    3    0  0
<  3  121   33  0
<  0   11  144  4
<  0    0    6  1
---
>  1    9    0  0
>  0  140   17  0
>  0   10  149  0
>  0    0    7  0
143,144c143,144
< Accuracy: 0.8198198198198198
< Kappa:    0.6695445072938374
---
> Accuracy: 0.8708708708708709
> Kappa:    0.754849423890154
146c146
< Mean Accuracy: 0.8258258258258259
---
> Mean Accuracy: 0.8898898898898899
220,223c220,223
<  14    9    0   0
<   2  130    7   0
<   0   10  140   2
<   0    1    2  16
---
>  12   11    0   0
>   0  133    6   0
>   0    4  148   0
>   0    0    7  12
225,226c225,226
< Accuracy: 0.9009009009009009
< Kappa:    0.83520043190714
---
> Accuracy: 0.9159159159159159
> Kappa:    0.8573024594052737
231,234c231,234
<  10   13    0   0
<   1  139   16   0
<   0   13  120   1
<   0    0   10  10
---
>  14    9    0   0
>   0  140   16   0
>   0    4  128   2
>   0    0    3  17
236,237c236,237
< Accuracy: 0.8378378378378378
< Kappa:    0.7238297088094361
---
> Accuracy: 0.8978978978978979
> Kappa:    0.8300535867068942
242,245c242,245
<  16    1    0   0
<   1  126   10   0
<   0    1  150   0
<   0    0    7  21
---
>  14    3    0   0
>   0  132    5   0
>   0    2  145   4
>   0    0    4  24
247,248c247,248
< Accuracy: 0.93993993993994
< Kappa:    0.9009797945256397
---
> Accuracy: 0.9459459459459459
> Kappa:    0.911650256470727
250c250
< Mean Accuracy: 0.8928928928928929
---
> Mean Accuracy: 0.91991991991992
310,312c310,312
<         ├─ Feature 2 < 3.1 ?
<             ├─ Iris-virginica : 2/2
<             └─ Iris-versicolor : 1/1
---
>         ├─ Feature 1 < 5.95 ?
>             ├─ Iris-versicolor : 1/1
>             └─ Iris-virginica : 2/2
355,356c355,356
<   0  20  1
<   0   1  8
---
>   0  19  2
>   0   0  9
359c359
< Kappa:    0.9366286438529784
---
> Kappa:    0.9375780274656679
375,376c375,376
<   0  12   1
<   0   7  15
---
>   0  13   0
>   0   6  16
378,379c378,379
< Accuracy: 0.84
< Kappa:    0.7613365155131264
---
> Accuracy: 0.88
> Kappa:    0.8210023866348449
381c381
< Mean Accuracy: 0.9066666666666666
---
> Mean Accuracy: 0.9199999999999999
426c426
< Mean Accuracy: 0.8270217144261188
---
> Mean Accuracy: 0.8444669676587119
463,465c463,465
< Mean Squared Error:     2.0183096134238294
< Correlation Coeff:      0.8903914722230327
< Coeff of Determination: 0.7924911697044006
---
> Mean Squared Error:     1.2353634596621743
> Correlation Coeff:      0.9467692443993198
> Coeff of Determination: 0.8729883538187402
468,470c468,470
< Mean Squared Error:     1.9714838724549328
< Correlation Coeff:      0.910241766877058
< Coeff of Determination: 0.8011434924520122
---
> Mean Squared Error:     1.3297177364601998
> Correlation Coeff:      0.9564527935419953
> Coeff of Determination: 0.8658761409152053
473,475c473,475
< Mean Squared Error:     1.6739772387561769
< Correlation Coeff:      0.9029059136519314
< Coeff of Determination: 0.813068012307753
---
> Mean Squared Error:     1.1170134745442588
> Correlation Coeff:      0.9507514465365866
> Coeff of Determination: 0.8752638063163086
477c477
< Mean Coeff of Determination: 0.8022342248213886
---
> Mean Coeff of Determination: 0.8713761003500847
488c488
< Mean Coeff of Determination: 0.5825527898815513
---
> Mean Coeff of Determination: 0.6324059967649163

@ablaom (Member) commented Nov 28, 2022

Thanks @dhanak for this valuable contribution. @rikhuijzer I really think you are in the best position to review this PR, if you don't mind?

@rikhuijzer (Member) commented Nov 29, 2022

I think it's safe to assume that Random.seed!(rng, a) is not correlated with Random.seed!(rng, b) when a != b. That means that it should be safe to drop the shared_seed.

Below are the accuracy comparisons between what is currently the dev branch of DecisionTree and a version where

_rng = Random.seed!(copy(rng), i)
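The assumption above (that distinct integer seeds yield uncorrelated generators) is easy to sanity-check empirically. A quick Python illustration of the same idea (purely illustrative, not part of the Julia test suite):

```python
import random

# Two generators seeded with different integers, as in
# Random.seed!(rng, a) vs Random.seed!(rng, b) with a != b.
ra, rb = random.Random(1), random.Random(2)
sa = [ra.random() for _ in range(1000)]
sb = [rb.random() for _ in range(1000)]

# The streams do not coincide, nor does any small shift of one
# reproduce the other (unlike the interlocked copies from #194).
assert sa != sb
assert all(sa[k:] != sb[:len(sb) - k] for k in range(1, 10))
assert all(sb[k:] != sa[:len(sa) - k] for k in range(1, 10))
```

This only checks the absence of the specific offset-overlap pathology; it is not a proof of statistical independence, which is why the claim is phrased as a safe assumption rather than a guarantee.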
diff --git a/tree-old.txt b/tree-new.txt
index 5086247..f7fe469 100644
--- a/tree-old.txt
+++ b/tree-new.txt
@@ -48,64 +48,64 @@ Mean Accuracy: 0.8688688688688688
 ##### nfoldCV Classification Forest #####
 Testing nfoldCV_forest
 
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
 
-Mean Accuracy: 0.8448448448448449
+Mean Accuracy: 0.908908908908909
 
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
 
-Mean Accuracy: 0.8748748748748749
+Mean Accuracy: 0.8998998998998999
 
-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914
 
-Mean Accuracy: 0.8608608608608609
+Mean Accuracy: 0.913913913913914
 
-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059
 
-Mean Accuracy: 0.8358358358358359
+Mean Accuracy: 0.9059059059059059
 
-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089
 
-Mean Accuracy: 0.8818818818818818
+Mean Accuracy: 0.9089089089089089
 
-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018
 
-Mean Accuracy: 0.8788788788788789
+Mean Accuracy: 0.9019019019019018
 
 Fold 1
 Classes:  [-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 6    3    0  0
- 0  121   29  0
- 0   26  136  7
- 0    0    4  1
+ 4    5    0  0
+ 1  124   25  0
+ 0    4  165  0
+ 0    0    5  0
 
-Accuracy: 0.7927927927927928
-Kappa:    0.6153446948136739
+Accuracy: 0.8798798798798799
+Kappa:    0.7701030394035107
 
 Fold 2
 Classes:  [-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 9    1    0  0
- 5  123   15  0
- 0   16  156  2
- 0    0    6  0
+ 8    2    0  0
+ 0  128   15  0
+ 0   10  164  0
+ 0    0    4  2
 
-Accuracy: 0.8648648648648649
-Kappa:    0.7499123817153157
+Accuracy: 0.9069069069069069
+Kappa:    0.8248409264443879
 
 Fold 3
 Classes:  [-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 7    3    0  0
- 3  121   33  0
- 0   11  144  4
+ 2    8    0  0
+ 0  141   16  0
+ 0    4  155  0
  0    0    6  1
 
-Accuracy: 0.8198198198198198
-Kappa:    0.6695445072938374
+Accuracy: 0.8978978978978979
+Kappa:    0.807114382091383
 
-Mean Accuracy: 0.8258258258258259
+Mean Accuracy: 0.8948948948948949
 
 ##### nfoldCV Adaboosted Stumps #####
 Testing nfoldCV_stumps
@@ -179,37 +179,37 @@ Mean Accuracy: 0.9629629629629629
 Fold 1
 Classes:  Int32[-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 14    9    0   0
-  2  130    7   0
-  0   10  140   2
-  0    1    2  16
+ 17    6    0   0
+  0  135    4   0
+  0    8  144   0
+  0    0    4  15
 
-Accuracy: 0.9009009009009009
-Kappa:    0.83520043190714
+Accuracy: 0.933933933933934
+Kappa:    0.8896653513660049
 
 Fold 2
 Classes:  Int32[-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
- 10   13    0   0
-  1  139   16   0
-  0   13  120   1
-  0    0   10  10
+ 13   10    0   0
+  0  143   13   0
+  0    7  125   2
+  0    0    1  19
 
-Accuracy: 0.8378378378378378
-Kappa:    0.7238297088094361
+Accuracy: 0.9009009009009009
+Kappa:    0.8349603508350355
 
 Fold 3
 Classes:  Int32[-2, -1, 0, 1]
 Matrix:   4×4 Matrix{Int64}:
  16    1    0   0
-  1  126   10   0
+  0  127   10   0
   0    1  150   0
-  0    0    7  21
+  0    0   10  18
 
-Accuracy: 0.93993993993994
-Kappa:    0.9009797945256397
+Accuracy: 0.933933933933934
+Kappa:    0.8902800658978584
 
-Mean Accuracy: 0.8928928928928929
+Mean Accuracy: 0.9229229229229229
 
 ##### nfoldCV Adaboosted Stumps #####
 
@@ -265,13 +265,13 @@ Feature 3 < 2.45 ?
             └─ Iris-virginica : 1/1
         └─ Feature 4 < 1.55 ?
             ├─ Iris-virginica : 3/3
-            └─ Feature 1 < 6.95 ?
+            └─ Feature 3 < 5.45 ?
                 ├─ Iris-versicolor : 2/2
                 └─ Iris-virginica : 1/1
     └─ Feature 3 < 4.85 ?
-        ├─ Feature 2 < 3.1 ?
-            ├─ Iris-virginica : 2/2
-            └─ Iris-versicolor : 1/1
+        ├─ Feature 1 < 5.95 ?
+            ├─ Iris-versicolor : 1/1
+            └─ Iris-virginica : 2/2
         └─ Iris-virginica : 43/43
 
 ##### nfoldCV Classification Tree #####
@@ -314,33 +314,33 @@ Fold 1
 Classes:  ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
 Matrix:   3×3 Matrix{Int64}:
  20   0  0
-  0  20  1
+  0  18  3
   0   1  8
 
-Accuracy: 0.96
-Kappa:    0.9366286438529784
+Accuracy: 0.92
+Kappa:    0.8751560549313357
 
 Fold 2
 Classes:  ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
 Matrix:   3×3 Matrix{Int64}:
  15   0   0
   0  15   1
-  0   3  16
+  0   2  17
 
-Accuracy: 0.92
-Kappa:    0.8798076923076925
+Accuracy: 0.94
+Kappa:    0.9096929560505719
 
 Fold 3
 Classes:  ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
 Matrix:   3×3 Matrix{Int64}:
  15   0   0
-  0  12   1
-  0   7  15
+  0  13   0
+  0   3  19
 
-Accuracy: 0.84
-Kappa:    0.7613365155131264
+Accuracy: 0.94
+Kappa:    0.9090357792601576
 
-Mean Accuracy: 0.9066666666666666
+Mean Accuracy: 0.9333333333333332
 
 ##### nfoldCV Classification Adaboosted Stumps #####
 
@@ -385,7 +385,7 @@ Mean Accuracy: 0.8109892809975735
 
 ##### 3 foldCV Classification Forest #####
 
-Mean Accuracy: 0.8270217144261188
+Mean Accuracy: 0.8429005804846587
 
 ##### nfoldCV Classification Adaboosted Stumps #####
 
@@ -422,21 +422,21 @@ Mean Coeff of Determination: 0.821479058935842
 ##### nfoldCV Regression Forest #####
 
 Fold 1
-Mean Squared Error:     2.0183096134238294
-Correlation Coeff:      0.8903914722230327
-Coeff of Determination: 0.7924911697044006
+Mean Squared Error:     1.3577742526795888
+Correlation Coeff:      0.9396271935146402
+Coeff of Determination: 0.8604029108789377
 
 Fold 2
-Mean Squared Error:     1.9714838724549328
-Correlation Coeff:      0.910241766877058
-Coeff of Determination: 0.8011434924520122
+Mean Squared Error:     1.3034832328733625
+Correlation Coeff:      0.9529278684745566
+Coeff of Determination: 0.8685223212027657
 
 Fold 3
-Mean Squared Error:     1.6739772387561769
-Correlation Coeff:      0.9029059136519314
-Coeff of Determination: 0.813068012307753
+Mean Squared Error:     1.1485186853278506
+Correlation Coeff:      0.9420191589030741
+Coeff of Determination: 0.8717456392002396
 
-Mean Coeff of Determination: 0.8022342248213886
+Mean Coeff of Determination: 0.8668902904273144
 ==================================================
 TEST: regression/digits.jl
 
@@ -447,7 +447,7 @@ Mean Coeff of Determination: 0.6349826429860214
 
 ##### 3 foldCV Regression Forest #####
 
-Mean Coeff of Determination: 0.5825527898815513
+Mean Coeff of Determination: 0.6477805012747754
 ==================================================
 TEST: regression/scikitlearn.jl
 
@@ -496,5 +496,5 @@ TEST: miscellaneous/feature_importance_test.jl
 
 ==================================================
 Test Summary: | Pass  Total   Time
-Test Suites   | 9658   9658  53.0s
+Test Suites   | 9612   9612  53.6s
      Testing DecisionTree tests passed

What do you think @dhanak?

@dhanak (Contributor, Author) commented Nov 29, 2022

> I think it's safe to assume that Random.seed!(rng, a) is not correlated with Random.seed!(rng, b) when a != b. That means that it should be safe to drop the shared_seed.
> What do you think @dhanak?

I agree with the assumption; that is why using `shared_seed + i` is good enough. I disagree with the conclusion, however. The role of `shared_seed` is not to disconnect the various copies of `rng` (`i` takes care of that), but to make the seeds depend on the current state of `rng`, and thus make them deterministically different for every unique state of `rng`.

In your version, every tree with a specific index draws the same sequence of numbers on each invocation, given a specific class of `rng`, irrespective of the specific state `rng` is in. That is, the 1st tree always uses one set of numbers, the 2nd tree always uses another, and so on. They are different from one another, but not different on each invocation.
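The distinction can be made concrete with a small Python sketch (illustrative only; `per_tree_stream` is a hypothetical helper, not DecisionTree.jl code): seeding copy `i` with just `i` yields the same numbers on every invocation regardless of the base generator's state, whereas a seed derived from that state changes each time.

```python
import random

def per_tree_stream(base, i, use_shared_seed):
    """Hypothetical helper mirroring the two seeding schemes under discussion."""
    if use_shared_seed:
        # Seed depends on base's current state (the shared_seed + i scheme).
        shared_seed = base.getrandbits(64)
        r = random.Random(shared_seed + i)
    else:
        # Seed is just the tree index: base's state is ignored entirely.
        r = random.Random(i)
    return [r.random() for _ in range(3)]

# Fixed-index seeding: tree 1 draws identical numbers on every invocation,
# even though the base generator has moved on in between.
base = random.Random(0)
first = per_tree_stream(base, 1, use_shared_seed=False)
base.random()
second = per_tree_stream(base, 1, use_shared_seed=False)
assert first == second

# State-derived seeding: each invocation consumes base's state, so the
# per-tree streams differ from one invocation to the next.
base = random.Random(0)
first = per_tree_stream(base, 1, use_shared_seed=True)
second = per_tree_stream(base, 1, use_shared_seed=True)
assert first != second
```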

@rikhuijzer (Member) commented Nov 29, 2022

Now I get it. Thanks, David 😄

@ablaom Can you merge this and create a release? I don't yet understand how to create releases in the MLJ style, unfortunately.

@ablaom (Member) commented Nov 29, 2022

Sure, I'll take care of it. FYI: new release instructions.

@ablaom merged commit 2efcb75 into JuliaAI:dev on Nov 29, 2022
@ablaom mentioned this pull request on Nov 29, 2022
@ablaom (Member) commented Nov 29, 2022

Thanks @dhanak for this valuable contribution. Thank you @rikhuijzer for your generous engagement and review. 🙏🏾

Successfully merging this pull request may close these issues.

RNG “shuffling” introduced in #174 is fundamentally flawed