Skip to content

Commit

Permalink
Unit tests for multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
salbert83 authored and ablaom committed Jul 18, 2022
1 parent f358671 commit 5811020
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/classification/adult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ cm = confusion_matrix(labels, preds)
f1 = impurity_importance(model)
p1 = permutation_importance(model, labels, features, (model, y, X)->accuracy(y, apply_forest(model, X)), rng=StableRNG(1)).mean

preds_MT = apply_forest(model, features, use_multithreading = true)
cm_MT = confusion_matrix(labels, preds_MT)
@test cm_MT.accuracy > 0.9

n_iterations = 15
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
preds = apply_adaboost_stumps(model, coeffs, features);
Expand Down
4 changes: 4 additions & 0 deletions test/classification/digits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ preds = apply_forest(model, X)
cm = confusion_matrix(Y, preds)
@test cm.accuracy > 0.95

preds_MT = apply_forest(model, X, use_multithreading = true)
cm_MT = confusion_matrix(Y, preds_MT)
@test cm_MT.accuracy > 0.95

n_iterations = 100
model, coeffs = DecisionTree.build_adaboost_stumps(
Y, X,
Expand Down
4 changes: 4 additions & 0 deletions test/classification/heterogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ preds = apply_forest(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9

preds_MT = apply_forest(model, features, use_multithreading = true)
cm_MT = confusion_matrix(labels, preds_MT)
@test cm_MT.accuracy > 0.9

n_subfeatures = 7
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
preds = apply_adaboost_stumps(model, coeffs, features)
Expand Down
6 changes: 6 additions & 0 deletions test/classification/iris.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ cm = confusion_matrix(labels, preds)
probs = apply_forest_proba(model, features, classes)
@test reshape(sum(probs, dims=2), n) ones(n)

preds_MT = apply_forest(model, features, use_multithreading = true)
cm_MT = confusion_matrix(labels, preds_MT)
@test cm_MT.accuracy > 0.95
@test typeof(preds_MT) == Vector{String}
@test sum(preds .!= preds_MT) == 0

# run n-fold cross validation for forests
println("\n##### nfoldCV Classification Forest #####")
n_subfeatures = 2
Expand Down
9 changes: 9 additions & 0 deletions test/classification/low_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ cm = confusion_matrix(labels, preds)
@test typeof(preds) == Vector{Int32}
@test cm.accuracy > 0.9

preds_MT = apply_forest(model, features, use_multithreading = true)
cm_MT = confusion_matrix(labels, preds_MT)
@test typeof(preds_MT) == Vector{Int32}
@test cm_MT.accuracy > 0.9

n_iterations = Int32(25)
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
preds = apply_adaboost_stumps(model, coeffs, features);
Expand Down Expand Up @@ -116,6 +121,10 @@ model = build_forest(labels, features)
preds = apply_forest(model, features)
@test typeof(preds) == Vector{Int8}

preds_MT = apply_forest(model, features, use_multithreading = true)
@test typeof(preds_MT) == Vector{Int8}
@test sum(abs.(preds .- preds_MT)) == zero(Int8)

model = build_tree(labels, features)
preds = apply_tree(model, features)
@test typeof(preds) == Vector{Int8}
Expand Down
10 changes: 10 additions & 0 deletions test/classification/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9
@test typeof(preds) == Vector{Int}

preds_MT = apply_forest(model, features, use_multithreading = true)
cm_MT = confusion_matrix(labels, preds_MT)
@test cm_MT.accuracy > 0.9
@test typeof(preds_MT) == Vector{Int}
@test sum(abs.(preds .- preds_MT)) == zero(Int)

n_subfeatures = 3
n_trees = 9
partial_sampling = 0.7
Expand All @@ -77,6 +83,10 @@ cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.6
@test length(model) == n_trees

preds_MT = apply_forest(model, features, use_multithreading = true)
cm_MT = confusion_matrix(labels, preds_MT)
@test cm_MT.accuracy > 0.9

# test n_subfeatures
n_subfeatures = 0
m_partial = build_forest(labels, features; rng=StableRNG(1)) # default sqrt(n_features)
Expand Down

0 comments on commit 5811020

Please sign in to comment.