Skip to content

Commit

Permalink
test cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
valenad1 committed Sep 19, 2023
1 parent c08bef6 commit 18e679e
Showing 1 changed file with 29 additions and 32 deletions.
61 changes: 29 additions & 32 deletions h2o-algos/src/test/java/hex/adaboost/AdaBoostTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import hex.Model;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRFModel;
import org.apache.commons.io.FileUtils;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand All @@ -22,7 +21,6 @@
import water.util.FrameUtils;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -102,30 +100,7 @@ public void testBasicTrainGLM() {
} finally {
Scope.exit();
}
}

@Test
public void testBasicTrainGLMWeakLerner() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
String response = "CAPSULE";
train.toCategoricalCol(response);
GLMModel.GLMParameters p = new GLMModel.GLMParameters();
p._train = train._key;
p._seed = 0xDECAF;
p._response_column = response;

GLM adaBoost = new GLM(p);
GLMModel adaBoostModel = adaBoost.trainModel().get();
Scope.track_generic(adaBoostModel);
assertNotNull(adaBoostModel);
Frame score = adaBoostModel.score(train);
Scope.track(score);
} finally {
Scope.exit();
}
}
}

@Test
public void testBasicTrainLarge() {
Expand Down Expand Up @@ -467,7 +442,6 @@ public void testBasicTrainAndScoreGLM() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
Frame test = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
String response = "CAPSULE";
train.toCategoricalCol(response);
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
Expand All @@ -482,7 +456,7 @@ public void testBasicTrainAndScoreGLM() {
Scope.track_generic(adaBoostModel);
assertNotNull(adaBoostModel);

Frame score = adaBoostModel.score(test);
Frame score = adaBoostModel.score(train);
Scope.track(score);
} finally {
Scope.exit();
Expand All @@ -494,7 +468,6 @@ public void testBasicTrainAndScoreGBM() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
Frame test = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
String response = "CAPSULE";
train.toCategoricalCol(response);
AdaBoostModel.AdaBoostParameters p = new AdaBoostModel.AdaBoostParameters();
Expand All @@ -509,7 +482,31 @@ public void testBasicTrainAndScoreGBM() {
Scope.track_generic(adaBoostModel);
assertNotNull(adaBoostModel);

Frame score = adaBoostModel.score(test);
Frame score = adaBoostModel.score(train);
Scope.track(score);
} finally {
Scope.exit();
}
}

@Test
public void test() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/prostate/prostate.csv"));
String response = "CAPSULE";
train.toCategoricalCol(response);
GBMModel.GBMParameters p = new GBMModel.GBMParameters();
p._train = train._key;
p._seed = 0xDECAF;
p._response_column = response;

GBM adaBoost = new GBM(p);
GBMModel adaBoostModel = adaBoost.trainModel().get();
Scope.track_generic(adaBoostModel);
assertNotNull(adaBoostModel);

Frame score = adaBoostModel.score(train);
Scope.track(score);
} finally {
Scope.exit();
Expand Down

0 comments on commit 18e679e

Please sign in to comment.