diff --git a/src/classification/main.jl b/src/classification/main.jl index ce1ed4e1..146f3d06 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -449,11 +449,21 @@ function apply_forest(forest::Ensemble{S, T}, features::AbstractVector{S}) where end end -function apply_forest(forest::Ensemble{S, T}, features::AbstractMatrix{S}) where {S, T} +function apply_forest( + forest::Ensemble{S, T}, + features::AbstractMatrix{S}; + use_multithreading = false + ) where {S, T} N = size(features,1) predictions = Array{T}(undef, N) - for i in 1:N - predictions[i] = apply_forest(forest, features[i, :]) + if use_multithreading + Threads.@threads for i in 1:N + predictions[i] = apply_forest(forest, @view(features[i, :])) + end + else + for i in 1:N + predictions[i] = apply_forest(forest, @view(features[i, :])) + end end return predictions end diff --git a/test/classification/adult.jl b/test/classification/adult.jl index 00ee7878..6c61d1b3 100644 --- a/test/classification/adult.jl +++ b/test/classification/adult.jl @@ -22,6 +22,10 @@ cm = confusion_matrix(labels, preds) f1 = impurity_importance(model) p1 = permutation_importance(model, labels, features, (model, y, X)->accuracy(y, apply_forest(model, X)), rng=StableRNG(1)).mean +preds_MT = apply_forest(model, features, use_multithreading = true) +cm_MT = confusion_matrix(labels, preds_MT) +@test cm_MT.accuracy > 0.9 + n_iterations = 15 model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1)); preds = apply_adaboost_stumps(model, coeffs, features); diff --git a/test/classification/digits.jl b/test/classification/digits.jl index d8c6bc06..096f1947 100644 --- a/test/classification/digits.jl +++ b/test/classification/digits.jl @@ -86,6 +86,10 @@ preds = apply_forest(model, X) cm = confusion_matrix(Y, preds) @test cm.accuracy > 0.95 +preds_MT = apply_forest(model, X, use_multithreading = true) +cm_MT = confusion_matrix(Y, preds_MT) +@test cm_MT.accuracy > 0.95 + n_iterations = 100 
model, coeffs = DecisionTree.build_adaboost_stumps( Y, X, diff --git a/test/classification/heterogeneous.jl b/test/classification/heterogeneous.jl index 004d89b2..f3afbfa9 100644 --- a/test/classification/heterogeneous.jl +++ b/test/classification/heterogeneous.jl @@ -26,6 +26,10 @@ preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 +preds_MT = apply_forest(model, features, use_multithreading = true) +cm_MT = confusion_matrix(labels, preds_MT) +@test cm_MT.accuracy > 0.9 + n_subfeatures = 7 model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1)) preds = apply_adaboost_stumps(model, coeffs, features) diff --git a/test/classification/iris.jl b/test/classification/iris.jl index a8ffb3b8..a4eb0b05 100644 --- a/test/classification/iris.jl +++ b/test/classification/iris.jl @@ -79,6 +79,12 @@ cm = confusion_matrix(labels, preds) probs = apply_forest_proba(model, features, classes) @test reshape(sum(probs, dims=2), n) ≈ ones(n) +preds_MT = apply_forest(model, features, use_multithreading = true) +cm_MT = confusion_matrix(labels, preds_MT) +@test cm_MT.accuracy > 0.95 +@test typeof(preds_MT) == Vector{String} +@test sum(preds .!= preds_MT) == 0 + # run n-fold cross validation for forests println("\n##### nfoldCV Classification Forest #####") n_subfeatures = 2 diff --git a/test/classification/low_precision.jl b/test/classification/low_precision.jl index 4bdcfbf8..4ee31de6 100644 --- a/test/classification/low_precision.jl +++ b/test/classification/low_precision.jl @@ -48,6 +48,11 @@ cm = confusion_matrix(labels, preds) @test typeof(preds) == Vector{Int32} @test cm.accuracy > 0.9 +preds_MT = apply_forest(model, features, use_multithreading = true) +cm_MT = confusion_matrix(labels, preds_MT) +@test typeof(preds_MT) == Vector{Int32} +@test cm_MT.accuracy > 0.9 + n_iterations = Int32(25) model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1)); preds = 
apply_adaboost_stumps(model, coeffs, features); @@ -116,6 +121,10 @@ model = build_forest(labels, features) preds = apply_forest(model, features) @test typeof(preds) == Vector{Int8} +preds_MT = apply_forest(model, features, use_multithreading = true) +@test typeof(preds_MT) == Vector{Int8} +@test sum(abs.(preds .- preds_MT)) == zero(Int8) + model = build_tree(labels, features) preds = apply_tree(model, features) @test typeof(preds) == Vector{Int8} diff --git a/test/classification/random.jl b/test/classification/random.jl index a91ebac6..8b6502ed 100644 --- a/test/classification/random.jl +++ b/test/classification/random.jl @@ -55,6 +55,12 @@ cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 @test typeof(preds) == Vector{Int} +preds_MT = apply_forest(model, features, use_multithreading = true) +cm_MT = confusion_matrix(labels, preds_MT) +@test cm_MT.accuracy > 0.9 +@test typeof(preds_MT) == Vector{Int} +@test sum(abs.(preds .- preds_MT)) == zero(Int) + n_subfeatures = 3 n_trees = 9 partial_sampling = 0.7 @@ -77,6 +83,10 @@ cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.6 @test length(model) == n_trees +preds_MT = apply_forest(model, features, use_multithreading = true) +cm_MT = confusion_matrix(labels, preds_MT) +@test cm_MT.accuracy > 0.6 + # test n_subfeatures n_subfeatures = 0 m_partial = build_forest(labels, features; rng=StableRNG(1)) # default sqrt(n_features) diff --git a/test/regression/digits.jl b/test/regression/digits.jl index fdf6105d..43c307f3 100644 --- a/test/regression/digits.jl +++ b/test/regression/digits.jl @@ -80,6 +80,10 @@ model = build_forest( preds = apply_forest(model, X) @test R2(Y, preds) > 0.8 +preds_MT = apply_forest(model, X, use_multithreading = true) +@test R2(Y, preds_MT) > 0.8 +@test sum(abs.(preds .- preds_MT)) < 1e-8 + println("\n##### 3 foldCV Regression Tree #####") n_folds = 5 r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false); diff --git a/test/regression/low_precision.jl 
b/test/regression/low_precision.jl index 2a5ec11c..9b404ea0 100644 --- a/test/regression/low_precision.jl +++ b/test/regression/low_precision.jl @@ -47,6 +47,11 @@ preds = apply_forest(model, features) @test R2(labels, preds) > 0.9 @test typeof(preds) <: Vector{Float64} +preds_MT = apply_forest(model, features, use_multithreading=true) +@test R2(labels, preds_MT) > 0.9 +@test typeof(preds_MT) <: Vector{Float64} +@test sum(abs.(preds .- preds_MT)) < 1.0e-8 + println("\n##### nfoldCV Regression Tree #####") n_folds = Int32(3) pruning_purity = 1.0 @@ -102,6 +107,10 @@ model = build_forest(labels, features) preds = apply_forest(model, features) @test typeof(preds) == Vector{Float16} +preds_MT = apply_forest(model, features, use_multithreading = true) +@test typeof(preds_MT) == Vector{Float16} +@test sum(abs.(preds .- preds_MT)) < 1.0e-8 + model = build_tree(labels, features) preds = apply_tree(model, features) @test typeof(preds) == Vector{Float16}