Skip to content

Commit

Permalink
Unit tests for multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
salbert83 authored and ablaom committed Jul 18, 2022
1 parent 4f19884 commit f358671
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/regression/digits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ model = build_forest(
preds = apply_forest(model, X)
@test R2(Y, preds) > 0.8

preds_MT = apply_forest(model, X, use_multithreading = true)
@test R2(Y, preds_MT) > 0.8
@test sum(abs.(preds .- preds_MT)) < 1e-8

println("\n##### 3 foldCV Regression Tree #####")
n_folds = 5
r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false);
Expand Down
9 changes: 9 additions & 0 deletions test/regression/low_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ preds = apply_forest(model, features)
@test R2(labels, preds) > 0.9
@test typeof(preds) <: Vector{Float64}

preds_MT = apply_forest(model, features, use_multithreading=true)
@test R2(labels, preds_MT) > 0.9
@test typeof(preds_MT) <: Vector{Float64}
@test sum(abs.(preds .- preds_MT)) < 1.0e-8

println("\n##### nfoldCV Regression Tree #####")
n_folds = Int32(3)
pruning_purity = 1.0
Expand Down Expand Up @@ -102,6 +107,10 @@ model = build_forest(labels, features)
preds = apply_forest(model, features)
@test typeof(preds) == Vector{Float16}

preds_MT = apply_forest(model, features, use_multithreading = true)
@test typeof(preds_MT) == Vector{Float16}
@test sum(abs.(preds .- preds_MT)) < 1.0e-8

model = build_tree(labels, features)
preds = apply_tree(model, features)
@test typeof(preds) == Vector{Float16}
Expand Down

0 comments on commit f358671

Please sign in to comment.