Merge pull request #188 from JuliaAI/salbert83-master
Multithreading for random forest predictions
ablaom authored Jul 18, 2022
2 parents 33bbec4 + 5811020 commit 9e77eac
Showing 9 changed files with 63 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/classification/main.jl
@@ -449,11 +449,21 @@ function apply_forest(forest::Ensemble{S, T}, features::AbstractVector{S}) where
     end
 end
 
-function apply_forest(forest::Ensemble{S, T}, features::AbstractMatrix{S}) where {S, T}
+function apply_forest(
+    forest::Ensemble{S, T},
+    features::AbstractMatrix{S};
+    use_multithreading = false
+) where {S, T}
     N = size(features,1)
     predictions = Array{T}(undef, N)
-    for i in 1:N
-        predictions[i] = apply_forest(forest, features[i, :])
+    if use_multithreading
+        Threads.@threads for i in 1:N
+            predictions[i] = apply_forest(forest, @view(features[i, :]))
+        end
+    else
+        for i in 1:N
+            predictions[i] = apply_forest(forest, @view(features[i, :]))
+        end
     end
     return predictions
 end
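
For reference, a minimal usage sketch of the new keyword argument. The data, labels, and forest hyperparameters below are made up for illustration and are not part of this PR; only apply_forest and its use_multithreading keyword come from the change above. Note that Threads.@threads falls back to serial execution unless Julia is started with more than one thread (e.g. julia --threads 4 or JULIA_NUM_THREADS=4).

using DecisionTree

# Hypothetical toy data: 1_000 samples, 10 numeric features, integer class labels.
features = rand(1_000, 10)
labels   = rand([0, 1], 1_000)

model = build_forest(labels, features)      # default forest hyperparameters

# The serial (default) and multithreaded paths compute the same per-row predictions;
# the threaded path only splits the row loop across the available threads.
preds    = apply_forest(model, features)
preds_MT = apply_forest(model, features, use_multithreading = true)
@assert preds == preds_MT

The switch to @view(features[i, :]) in both branches also avoids allocating a copy of each row before dispatching to the per-sample apply_forest method.
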
4 changes: 4 additions & 0 deletions test/classification/adult.jl
@@ -22,6 +22,10 @@ cm = confusion_matrix(labels, preds)
 f1 = impurity_importance(model)
 p1 = permutation_importance(model, labels, features, (model, y, X)->accuracy(y, apply_forest(model, X)), rng=StableRNG(1)).mean
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+cm_MT = confusion_matrix(labels, preds_MT)
+@test cm_MT.accuracy > 0.9
+
 n_iterations = 15
 model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
 preds = apply_adaboost_stumps(model, coeffs, features);
4 changes: 4 additions & 0 deletions test/classification/digits.jl
@@ -86,6 +86,10 @@ preds = apply_forest(model, X)
 cm = confusion_matrix(Y, preds)
 @test cm.accuracy > 0.95
 
+preds_MT = apply_forest(model, X, use_multithreading = true)
+cm_MT = confusion_matrix(Y, preds_MT)
+@test cm_MT.accuracy > 0.95
+
 n_iterations = 100
 model, coeffs = DecisionTree.build_adaboost_stumps(
     Y, X,
4 changes: 4 additions & 0 deletions test/classification/heterogeneous.jl
@@ -26,6 +26,10 @@ preds = apply_forest(model, features)
 cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.9
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+cm_MT = confusion_matrix(labels, preds_MT)
+@test cm_MT.accuracy > 0.9
+
 n_subfeatures = 7
 model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
 preds = apply_adaboost_stumps(model, coeffs, features)
6 changes: 6 additions & 0 deletions test/classification/iris.jl
@@ -79,6 +79,12 @@ cm = confusion_matrix(labels, preds)
 probs = apply_forest_proba(model, features, classes)
 @test reshape(sum(probs, dims=2), n) ≈ ones(n)
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+cm_MT = confusion_matrix(labels, preds_MT)
+@test cm_MT.accuracy > 0.95
+@test typeof(preds_MT) == Vector{String}
+@test sum(preds .!= preds_MT) == 0
+
 # run n-fold cross validation for forests
 println("\n##### nfoldCV Classification Forest #####")
 n_subfeatures = 2
9 changes: 9 additions & 0 deletions test/classification/low_precision.jl
@@ -48,6 +48,11 @@ cm = confusion_matrix(labels, preds)
 @test typeof(preds) == Vector{Int32}
 @test cm.accuracy > 0.9
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+cm_MT = confusion_matrix(labels, preds_MT)
+@test typeof(preds_MT) == Vector{Int32}
+@test cm_MT.accuracy > 0.9
+
 n_iterations = Int32(25)
 model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
 preds = apply_adaboost_stumps(model, coeffs, features);
@@ -116,6 +121,10 @@ model = build_forest(labels, features)
 preds = apply_forest(model, features)
 @test typeof(preds) == Vector{Int8}
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+@test typeof(preds_MT) == Vector{Int8}
+@test sum(abs.(preds .- preds_MT)) == zero(Int8)
+
 model = build_tree(labels, features)
 preds = apply_tree(model, features)
 @test typeof(preds) == Vector{Int8}
10 changes: 10 additions & 0 deletions test/classification/random.jl
@@ -55,6 +55,12 @@ cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.9
 @test typeof(preds) == Vector{Int}
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+cm_MT = confusion_matrix(labels, preds_MT)
+@test cm_MT.accuracy > 0.9
+@test typeof(preds_MT) == Vector{Int}
+@test sum(abs.(preds .- preds_MT)) == zero(Int)
+
 n_subfeatures = 3
 n_trees = 9
 partial_sampling = 0.7
@@ -77,6 +83,10 @@ cm = confusion_matrix(labels, preds)
 @test cm.accuracy > 0.6
 @test length(model) == n_trees
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+cm_MT = confusion_matrix(labels, preds_MT)
+@test cm_MT.accuracy > 0.9
+
 # test n_subfeatures
 n_subfeatures = 0
 m_partial = build_forest(labels, features; rng=StableRNG(1)) # default sqrt(n_features)
4 changes: 4 additions & 0 deletions test/regression/digits.jl
@@ -80,6 +80,10 @@ model = build_forest(
 preds = apply_forest(model, X)
 @test R2(Y, preds) > 0.8
 
+preds_MT = apply_forest(model, X, use_multithreading = true)
+@test R2(Y, preds_MT) > 0.8
+@test sum(abs.(preds .- preds_MT)) < 1e-8
+
 println("\n##### 3 foldCV Regression Tree #####")
 n_folds = 5
 r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false);
9 changes: 9 additions & 0 deletions test/regression/low_precision.jl
@@ -47,6 +47,11 @@ preds = apply_forest(model, features)
 @test R2(labels, preds) > 0.9
 @test typeof(preds) <: Vector{Float64}
 
+preds_MT = apply_forest(model, features, use_multithreading=true)
+@test R2(labels, preds_MT) > 0.9
+@test typeof(preds_MT) <: Vector{Float64}
+@test sum(abs.(preds .- preds_MT)) < 1.0e-8
+
 println("\n##### nfoldCV Regression Tree #####")
 n_folds = Int32(3)
 pruning_purity = 1.0
@@ -102,6 +107,10 @@ model = build_forest(labels, features)
 preds = apply_forest(model, features)
 @test typeof(preds) == Vector{Float16}
 
+preds_MT = apply_forest(model, features, use_multithreading = true)
+@test typeof(preds_MT) == Vector{Float16}
+@test sum(abs.(preds .- preds_MT)) < 1.0e-8
+
 model = build_tree(labels, features)
 preds = apply_tree(model, features)
 @test typeof(preds) == Vector{Float16}
