Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-15556 Implement UpliftDRF MOJO #15615

Merged
merged 26 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1bb9df9
Implement ATE, ATT, ATC metrics
maurever Jun 19, 2023
9d9a277
Fix score with treatment column
maurever Jun 22, 2023
54c37f7
Enable custom metric for UpliftDRF
maurever Jun 22, 2023
af2106b
fix custom att calculation, add example to doc
maurever Jun 27, 2023
d1cf383
GH-6783 fix custom function definition
maurever Jul 20, 2023
42adee3
Add atc custom, test new metric in R
maurever Jul 24, 2023
c7f6334
fix test
maurever Jul 27, 2023
28ddaac
fix att and atc metric bug
maurever Aug 7, 2023
fcb832c
fix make metrics runit
maurever Aug 9, 2023
82d5d1e
Fix make_metrics bug
maurever Aug 11, 2023
9a2f84d
GH-15556 implement uplift mojo
maurever Jun 28, 2023
f745536
Prepare mojo/pojo writer and reader
maurever Aug 16, 2023
21d6235
Implement mojo structure, add junit test, add easy predict function
maurever Aug 23, 2023
2742fbd
rebase master
maurever Aug 23, 2023
8a0ff50
Fix make metrics functionality
maurever Aug 23, 2023
81b6b81
Fix UpliftOtput annotation
maurever Aug 24, 2023
60d5f88
Fix build
maurever Aug 24, 2023
a17fe60
fix custom auuc threshold api
maurever Aug 24, 2023
6173b3f
fix makeMetrics bug
maurever Aug 30, 2023
68dcad9
Fix make_metrics test
maurever Aug 31, 2023
0e5f59f
fix make metrics bug
maurever Sep 13, 2023
48afd94
fix mojo bug, fix make_metrics bug, clean code
maurever Sep 20, 2023
9b74707
Move custom threshold logic from make metrics to another PR
maurever Sep 20, 2023
b490cfa
Move custom threshold logic from make metrics to another PR
maurever Sep 20, 2023
821cb4e
Fix failed tests
maurever Sep 21, 2023
e21625d
Merge branch 'master' into maurever_GH-15556_upliftdrf_mojo
maurever Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/UpliftDRFModelV3.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hex.schemas;

import hex.tree.uplift.UpliftDRFModel;
import water.api.API;

public class UpliftDRFModelV3 extends SharedTreeModelV3<UpliftDRFModel,
UpliftDRFModelV3,
Expand All @@ -9,8 +10,18 @@ public class UpliftDRFModelV3 extends SharedTreeModelV3<UpliftDRFModel,
UpliftDRFModel.UpliftDRFOutput,
UpliftDRFModelV3.UpliftDRFModelOutputV3> {

public static final class UpliftDRFModelOutputV3 extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {}
/**
 * V3 output schema for UpliftDRF models. In addition to what the shared-tree output
 * schema fills in, it exposes the default AUUC thresholds stored on the model output.
 */
public static final class UpliftDRFModelOutputV3 extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {

    @API(help="Default thresholds to calculate AUUC metric. If validation is enabled, thresholds from validation metrics is saved here. Otherwise thresholds are from training metrics.")
    public double[] default_auuc_thresholds;

    /**
     * Delegates to the superclass to fill the schema, then copies the default AUUC
     * thresholds from the model output onto the schema returned by the superclass.
     *
     * @param impl model output to read from
     * @return the schema instance returned by the superclass, with thresholds set
     */
    @Override
    public UpliftDRFModelV3.UpliftDRFModelOutputV3 fillFromImpl(UpliftDRFModel.UpliftDRFOutput impl) {
        final UpliftDRFModelV3.UpliftDRFModelOutputV3 schema = super.fillFromImpl(impl);
        schema.default_auuc_thresholds = impl._defaultAuucThresholds;
        return schema;
    }
}

public UpliftDRFV3.UpliftDRFParametersV3 createParametersSchema() { return new UpliftDRFV3.UpliftDRFParametersV3(); }
public UpliftDRFModelOutputV3 createOutputSchema() { return new UpliftDRFModelOutputV3(); }

Expand Down
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/tree/Score.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBMModel;
import hex.tree.uplift.UpliftDRFModel;
import org.apache.log4j.Logger;
import water.Iced;
import water.Key;
Expand Down
1 change: 0 additions & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,6 @@ protected final boolean doScoringAndSaveModel(boolean finalScoring, boolean oob,
out._training_metrics = mm;
if (oob) out._training_metrics._description = "Metrics reported on Out-Of-Bag training samples";
out._scored_train[out._ntrees].fillFrom(mm);

// Score again on validation data
if( _parms._valid != null) {
Frame v = new Frame(valid());
Expand Down
3 changes: 2 additions & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTreeModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.glm.GLMModel;
import hex.tree.uplift.UpliftDRFModel;
import hex.util.LinearAlgebraUtils;
import org.apache.log4j.Logger;
import water.*;
Expand Down Expand Up @@ -166,6 +166,7 @@ public boolean forceStrictlyReproducibleHistograms() {
case Binomial: return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
case Multinomial: return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(),domain, _parms._auc_type);
case Regression: return new ModelMetricsRegression.MetricBuilderRegression();
case BinomialUplift: return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, ((UpliftDRFModel.UpliftDRFOutput)_output)._defaultAuucThresholds);
default: throw H2O.unimpl();
}
}
Expand Down
22 changes: 20 additions & 2 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRF.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package hex.tree.uplift;

import hex.*;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.upliftdrf.UpliftDrfMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.*;
import org.apache.log4j.Logger;
Expand Down Expand Up @@ -50,12 +52,12 @@ public UpliftDRF(boolean startup_once) {

@Override
public boolean haveMojo() {
return false;
return true;
}

@Override
public boolean havePojo() {
return false;
return true;
}

@Override
Expand Down Expand Up @@ -472,6 +474,22 @@ static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output,
return table;
}

/**
 * Builds a POJO writer for a generic model that is backed by an UpliftDRF MOJO.
 *
 * @param genericModel model handed to the POJO writer
 * @param mojoModel    MOJO to convert; it is cast to {@code UpliftDrfMojoModel}
 * @return POJO writer built from the MOJO's compressed trees
 */
@Override
public PojoWriter makePojoWriter(Model<?, ?, ?> genericModel, MojoModel mojoModel) {
    final UpliftDrfMojoModel upliftMojo = (UpliftDrfMojoModel) mojoModel;
    final CompressedTree[][] forest = MojoUtils.extractCompressedTrees(upliftMojo);
    // The binomialOpt argument is always passed as false on this path.
    return new UpliftDrfPojoWriter(
            genericModel,
            upliftMojo.getCategoricalEncoding(),
            false,
            forest,
            upliftMojo._balanceClasses);
}

/**
 * Stores the default AUUC thresholds on the model output. The thresholds come from
 * the validation metrics when they are present, otherwise from the training metrics.
 *
 * @param out model output to update
 */
@Override
protected void addCustomInfo(UpliftDRFModel.UpliftDRFOutput out) {
    final boolean hasValidation = out._validation_metrics != null;
    final ModelMetricsBinomialUplift chosen =
            (ModelMetricsBinomialUplift) (hasValidation ? out._validation_metrics : out._training_metrics);
    out._defaultAuucThresholds = chosen._auuc._ths;
}

@Override
protected UpliftScoreExtension makeScoreExtension() {
return new UpliftScoreExtension();
Expand Down
29 changes: 22 additions & 7 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRFModel.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package hex.tree.uplift;

import hex.*;
import hex.tree.CompressedForest;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.tree.SharedTreePojoWriter;
import hex.util.EffectiveParametersUtils;
import water.H2O;
import water.Key;

public class UpliftDRFModel extends SharedTreeModel<UpliftDRFModel, UpliftDRFModel.UpliftDRFParameters, UpliftDRFModel.UpliftDRFOutput> {
Expand All @@ -13,7 +14,6 @@ public static class UpliftDRFParameters extends SharedTreeModel.SharedTreeParame
public String algoName() { return "UpliftDRF"; }
public String fullName() { return "Uplift Distributed Random Forest"; }
public String javaName() { return UpliftDRFModel.class.getName(); }
public boolean _binomial_double_trees = false;


public enum UpliftMetricType { AUTO, KL, ChiSquared, Euclidean }
Expand All @@ -36,6 +36,9 @@ public long progressUnits() {
}

public static class UpliftDRFOutput extends SharedTreeModelWithContributions.SharedTreeOutput {

public double[] _defaultAuucThresholds; // thresholds for AUUC to calculate metrics

/** Creates the output for the given UpliftDRF builder by delegating to the superclass constructor. */
public UpliftDRFOutput(UpliftDRF b) {
    super(b);
}

@Override
Expand All @@ -45,7 +48,11 @@ public ModelCategory getModelCategory() {

@Override
public boolean isBinomialClassifier() {
return false;
return true;
}

/**
 * Sets the thresholds used to calculate AUUC metrics.
 * The array reference is stored as given; no copy is made.
 *
 * @param defaultAuucThresholds thresholds to store
 */
public void setDefaultAuucThresholds(double[] defaultAuucThresholds) {
    _defaultAuucThresholds = defaultAuucThresholds;
}
}

Expand Down Expand Up @@ -77,10 +84,18 @@ public void initActualParamValues() {
}

@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
if (_output.getModelCategory() == ModelCategory.BinomialUplift) {
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain);
}
throw H2O.unimpl();
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, _output._defaultAuucThresholds);
}

/**
 * Creates the MOJO writer for this model.
 *
 * @return a new {@code UpliftDrfMojoWriter} bound to this model
 */
@Override
public UpliftDrfMojoWriter getMojo() {
    final UpliftDrfMojoWriter writer = new UpliftDrfMojoWriter(this);
    return writer;
}

/**
 * Creates the POJO writer for this model from its compressed forest.
 * The forest is built from the output's tree keys and domains and then fetched
 * to obtain the trees handed to the writer.
 *
 * @return POJO writer for this model's trees
 */
@Override
protected SharedTreePojoWriter makeTreePojoWriter() {
    final CompressedForest.LocalCompressedForest fetched =
            new CompressedForest(_output._treeKeys, _output._domains).fetch();
    return new UpliftDrfPojoWriter(this, fetched._trees);
}
}
23 changes: 23 additions & 0 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDrfMojoWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package hex.tree.uplift;

import hex.tree.SharedTreeMojoWriter;

import java.io.IOException;

/**
 * MOJO writer for UpliftDRF models. On top of the data written by the shared-tree
 * writer it stores the model's default AUUC thresholds.
 */
public class UpliftDrfMojoWriter extends SharedTreeMojoWriter<UpliftDRFModel, UpliftDRFModel.UpliftDRFParameters, UpliftDRFModel.UpliftDRFOutput> {

    @SuppressWarnings("unused") // Called through reflection in ModelBuildersHandler
    public UpliftDrfMojoWriter() {}

    /** Creates a writer for the given trained model. */
    public UpliftDrfMojoWriter(UpliftDRFModel model) {
        super(model);
    }

    /** @return the MOJO format version written by this writer */
    @Override
    public String mojoVersion() {
        return "1.40";
    }

    /**
     * Writes the shared-tree model data first, then the default AUUC thresholds
     * under the {@code default_auuc_thresholds} key.
     */
    @Override
    protected void writeModelData() throws IOException {
        super.writeModelData();
        final double[] defaultThresholds = model._output._defaultAuucThresholds;
        writekv("default_auuc_thresholds", defaultThresholds);
    }
}
28 changes: 28 additions & 0 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDrfPojoWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package hex.tree.uplift;

import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import water.util.SBPrintStream;

/**
 * POJO writer for UpliftDRF models. It reuses the shared-tree POJO generation and
 * only customizes how the accumulated per-tree predictions are unified.
 */
public class UpliftDrfPojoWriter extends SharedTreePojoWriter {

    /**
     * Creates a writer for a trained UpliftDRF model.
     *
     * @param model trained model; its key, output, encoding, binomial option and tree stats are passed on
     * @param trees compressed trees to emit
     */
    UpliftDrfPojoWriter(UpliftDRFModel model, CompressedTree[][] trees) {
        super(model._key, model._output, model.getGenModelEncoding(), model.binomialOpt(),
                trees, model._output._treeStats);
    }

    /**
     * Creates a writer for a generic model; no tree stats are passed to the superclass.
     *
     * @param model          model whose key and output are passed on
     * @param encoding       categorical encoding passed to the superclass
     * @param binomialOpt    binomial option flag passed to the superclass
     * @param trees          compressed trees to emit
     * @param balanceClasses accepted for signature compatibility; currently not used by this writer
     */
    UpliftDrfPojoWriter(Model<?, ?, ?> model, CategoricalEncoding encoding,
                        boolean binomialOpt, CompressedTree[][] trees,
                        boolean balanceClasses) {
        super(model._key, model._output, encoding, binomialOpt, trees, null);
    }

    /**
     * Emits the Java statements that turn the summed tree outputs into final predictions:
     * {@code preds[1]} and {@code preds[2]} are divided by the number of trees and
     * {@code preds[0]} is set to their difference.
     *
     * @param body stream receiving the generated Java source
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        final int treeCount = _trees.length;
        body.ip("preds[1] /= " + treeCount + ";").nl();
        body.ip("preds[2] /= " + treeCount + ";").nl();
        // The emitted statement must end with ';' and a newline like the ones above,
        // otherwise the generated POJO source is not valid Java.
        body.ip("preds[0] = preds[1] - preds[2];").nl();
    }
}
47 changes: 47 additions & 0 deletions h2o-algos/src/test/java/hex/tree/gbm/GBMTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4458,6 +4458,53 @@ public void testResetThreshold() throws Exception {
}
}

/**
 * Trains a single-tree bernoulli GBM with a validation frame and verifies, via
 * testJavaScoring, that the predictions made on the training frame are reproduced.
 */
@Test
public void testMojoMetrics() throws Exception {
    try {
        Scope.enter();

        // Training data; the response column is converted to categorical below.
        Frame frame = new TestFrameBuilder()
                .withName("data")
                .withColNames("ColA", "ColB", "Response")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM)
                .withDataForCol(0, ard(0, 1, 0, 1, 0, 1, 0))
                .withDataForCol(1, ard(Double.NaN, 1, 2, 3, 4, 5.6, 7))
                .withDataForCol(2, ard(1, 0, 1, 1, 1, 0, 1))
                .build();
        frame = frame.toCategoricalCol(2);

        // Validation data with the same layout.
        Frame frameVal = new TestFrameBuilder()
                .withName("dataVal")
                .withColNames("ColA", "ColB", "Response")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM)
                .withDataForCol(0, ard(0, 1, 1, 1, 0, 0, 1))
                .withDataForCol(1, ard(Double.NaN, 1, 3, 2, 4, 8, 7))
                .withDataForCol(2, ard(1, 1, 1, 0, 0, 1, 1))
                .build();
        frameVal = frameVal.toCategoricalCol(2);

        final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = frame._key;
        params._valid = frameVal._key;
        params._response_column = "Response";
        params._ntrees = 1;
        params._min_rows = 0.1;
        params._distribution = bernoulli;

        final GBMModel gbm = new GBM(params).trainModel().get();
        Scope.track_generic(gbm);

        final Frame trainPredictions = gbm.score(frame);
        Scope.track_generic(trainPredictions);

        assertTrue(gbm.testJavaScoring(frame, trainPredictions, 1e-15));
    } finally {
        Scope.exit();
    }
}

@Test
public void testGBMFeatureInteractions() {
Scope.enter();
Expand Down
94 changes: 94 additions & 0 deletions h2o-algos/src/test/java/hex/tree/uplift/UpliftDRFTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package hex.tree.uplift;

import hex.ScoreKeeper;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.UpliftBinomialModelPrediction;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.H2O;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
Expand All @@ -14,7 +20,14 @@
import water.runner.CloudSize;
import water.runner.H2ORunner;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static org.junit.Assert.*;

Expand Down Expand Up @@ -356,4 +369,85 @@ public void testPredictCorrectOutput() {
Scope.exit();
}
}

/**
 * Trains a small UpliftDRF model and verifies, via testJavaScoring, that the
 * predictions made on the training frame are reproduced.
 */
@Test
public void testMojo() {
    try {
        Scope.enter();
        final Frame train = new TestFrameBuilder()
                .withColNames("C0", "C1", "treatment", "conversion")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT, Vec.T_CAT)
                .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(2, ar("T", "C", "T", "T", "T", "C", "C", "C", "C", "C"))
                .withDataForCol(3, ar("Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "Yes"))
                .build();
        train.toCategoricalCol("treatment");
        train.toCategoricalCol("conversion");

        final UpliftDRFModel.UpliftDRFParameters params = new UpliftDRFModel.UpliftDRFParameters();
        params._train = train._key;
        params._response_column = "conversion";
        params._treatment_column = "treatment";
        params._ntrees = 4;

        final UpliftDRFModel model = new UpliftDRF(params).trainModel().get();
        Scope.track_generic(model);

        final Frame predictions = model.score(train);
        Scope.track_generic(predictions);

        assertTrue(model.testJavaScoring(train, predictions, 1e-15));
    } finally {
        Scope.exit();
    }
}

/**
 * Trains a small UpliftDRF model, converts it to a MOJO and scores every training row
 * through EasyPredictModelWrapper. For each row the prediction must contain three
 * values and the first must equal the second minus the third.
 */
@Test
public void testEasyPredictMojo() throws Exception {
    try {
        Scope.enter();
        Frame train = new TestFrameBuilder()
                .withColNames("C0", "C1", "treatment", "conversion")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT, Vec.T_CAT)
                .withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
                .withDataForCol(2, ar("T", "C", "T", "T", "T", "C", "C", "C", "C", "C"))
                .withDataForCol(3, ar("Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "Yes"))
                .build();
        train.toCategoricalCol("treatment");
        train.toCategoricalCol("conversion");
        Scope.track_generic(train);

        UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
        p._train = train._key;
        p._response_column = "conversion";
        p._treatment_column = "treatment";
        p._ntrees = 4;

        UpliftDRF udrf = new UpliftDRF(p);
        UpliftDRFModel model = udrf.trainModel().get();
        Scope.track_generic(model);

        MojoModel mojo = model.toMojo();
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(
                new EasyPredictModelWrapper.Config()
                        .setModel(mojo)
                        .setEnableContributions(false)
        );

        // Score only the columns the MOJO declares as features.
        Frame featureFr = train.subframe(mojo.features());
        Scope.track_generic(featureFr);
        for (int i = 0; i < featureFr.numRows(); i++) {
            RowData row = new RowData();
            for (String feat : featureFr.names()) {
                Vec featVec = featureFr.vec(feat);
                // Missing cells are left out of the row instead of being passed as values.
                if (!featVec.isNA(i)) {
                    double value = featVec.at(i);
                    row.put(feat, value);
                }
            }
            UpliftBinomialModelPrediction pred = wrapper.predictUpliftBinomial(row);
            // JUnit convention: expected value first, actual value second.
            assertEquals(3, pred.predictions.length);
            assertEquals(pred.predictions[1] - pred.predictions[2], pred.predictions[0], 0);
        }
    } finally {
        Scope.exit();
    }
}
}
Loading