From f358671afa46dc3d3d234e9701c5e27749f3f64c Mon Sep 17 00:00:00 2001 From: salbert83 Date: Fri, 8 Jul 2022 20:49:39 -0400 Subject: [PATCH] Unit tests for multithreading --- test/regression/digits.jl | 4 ++++ test/regression/low_precision.jl | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/test/regression/digits.jl b/test/regression/digits.jl index fdf6105d..43c307f3 100644 --- a/test/regression/digits.jl +++ b/test/regression/digits.jl @@ -80,6 +80,10 @@ model = build_forest( preds = apply_forest(model, X) @test R2(Y, preds) > 0.8 +preds_MT = apply_forest(model, X, use_multithreading = true) +@test R2(Y, preds_MT) > 0.8 +@test sum(abs.(preds .- preds_MT)) < 1e-8 + println("\n##### 3 foldCV Regression Tree #####") n_folds = 5 r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false); diff --git a/test/regression/low_precision.jl b/test/regression/low_precision.jl index 2a5ec11c..9b404ea0 100644 --- a/test/regression/low_precision.jl +++ b/test/regression/low_precision.jl @@ -47,6 +47,11 @@ preds = apply_forest(model, features) @test R2(labels, preds) > 0.9 @test typeof(preds) <: Vector{Float64} +preds_MT = apply_forest(model, features, use_multithreading=true) +@test R2(labels, preds_MT) > 0.9 +@test typeof(preds_MT) <: Vector{Float64} +@test sum(abs.(preds .- preds_MT)) < 1.0e-8 + println("\n##### nfoldCV Regression Tree #####") n_folds = Int32(3) pruning_purity = 1.0 @@ -102,6 +107,10 @@ model = build_forest(labels, features) preds = apply_forest(model, features) @test typeof(preds) == Vector{Float16} +preds_MT = apply_forest(model, features, use_multithreading = true) +@test typeof(preds_MT) == Vector{Float16} +@test sum(abs.(preds .- preds_MT)) < 1.0e-8 + model = build_tree(labels, features) preds = apply_tree(model, features) @test typeof(preds) == Vector{Float16}