Skip to content

Commit

Permalink
Implement mojo structure, add junit test, add easy predict function
Browse files Browse the repository at this point in the history
  • Loading branch information
maurever committed Aug 23, 2023
1 parent f0e9727 commit 0ced502
Show file tree
Hide file tree
Showing 13 changed files with 214 additions and 51 deletions.
10 changes: 9 additions & 1 deletion h2o-algos/src/main/java/hex/schemas/UpliftDRFModelV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@ public class UpliftDRFModelV3 extends SharedTreeModelV3<UpliftDRFModel,
UpliftDRFModel.UpliftDRFOutput,
UpliftDRFModelV3.UpliftDRFModelOutputV3> {

public static final class UpliftDRFModelOutputV3 extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {}
public static final class UpliftDRFModelOutputV3 extends SharedTreeModelV3.SharedTreeModelOutputV3<UpliftDRFModel.UpliftDRFOutput, UpliftDRFModelOutputV3> {
public double[] _metricThresholds;

@Override public UpliftDRFModelV3.UpliftDRFModelOutputV3 fillFromImpl(UpliftDRFModel.UpliftDRFOutput impl) {
UpliftDRFModelV3.UpliftDRFModelOutputV3 uov3 = super.fillFromImpl(impl);
uov3._metricThresholds = impl._metricThresholds;
return uov3;
}
}

public UpliftDRFV3.UpliftDRFParametersV3 createParametersSchema() { return new UpliftDRFV3.UpliftDRFParametersV3(); }
public UpliftDRFModelOutputV3 createOutputSchema() { return new UpliftDRFModelOutputV3(); }

Expand Down
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/tree/Score.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBMModel;
import hex.tree.uplift.UpliftDRFModel;
import org.apache.log4j.Logger;
import water.Iced;
import water.Key;
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ protected final boolean doScoringAndSaveModel(boolean finalScoring, boolean oob,
out._training_metrics = mm;
if (oob) out._training_metrics._description = "Metrics reported on Out-Of-Bag training samples";
out._scored_train[out._ntrees].fillFrom(mm);

// Score again on validation data
if( _parms._valid != null) {
Frame v = new Frame(valid());
Expand Down
9 changes: 9 additions & 0 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,15 @@ public PojoWriter makePojoWriter(Model<?, ?, ?> genericModel, MojoModel mojoMode
return new UpliftDrfPojoWriter(genericModel, upliftDrfMojoModel.getCategoricalEncoding(), false, trees, upliftDrfMojoModel._balanceClasses);
}

@Override
protected void addCustomInfo(UpliftDRFModel.UpliftDRFOutput out) {
if(out._validation_metrics != null){
out._metricThresholds = ((ModelMetricsBinomialUplift)out._validation_metrics)._auuc._ths;
} else {
out._metricThresholds = ((ModelMetricsBinomialUplift)out._training_metrics)._auuc._ths;
}
}

@Override
protected UpliftScoreExtension makeScoreExtension() {
return new UpliftScoreExtension();
Expand Down
3 changes: 3 additions & 0 deletions h2o-algos/src/main/java/hex/tree/uplift/UpliftDRFModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ public void initActualParamValues() {

@Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
if (_output.getModelCategory() == ModelCategory.BinomialUplift) {
if(_output._metricThresholds == null){
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, new double[]{0});
}
return new ModelMetricsBinomialUplift.MetricBuilderBinomialUplift(domain, _output._metricThresholds);
}
throw H2O.unimpl();
Expand Down
47 changes: 47 additions & 0 deletions h2o-algos/src/test/java/hex/tree/gbm/GBMTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4457,6 +4457,53 @@ public void testResetThreshold() throws Exception {
}
}

@Test
public void testMojoMetrics() throws Exception {
GBMModel gbm = null;
try {
Scope.enter();
Frame frame = new TestFrameBuilder()
.withName("data")
.withColNames("ColA", "ColB", "Response")
.withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM)
.withDataForCol(0, ard(0, 1, 0, 1, 0, 1, 0))
.withDataForCol(1, ard(Double.NaN, 1, 2, 3, 4, 5.6, 7))
.withDataForCol(2, ard(1, 0, 1, 1, 1, 0, 1))
.build();

frame = frame.toCategoricalCol(2);

Frame frameVal = new TestFrameBuilder()
.withName("dataVal")
.withColNames("ColA", "ColB", "Response")
.withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM)
.withDataForCol(0, ard(0, 1, 1, 1, 0, 0, 1))
.withDataForCol(1, ard(Double.NaN, 1, 3, 2, 4, 8, 7))
.withDataForCol(2, ard(1, 1, 1, 0, 0, 1, 1))
.build();

frameVal = frameVal.toCategoricalCol(2);

GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
parms._train = frame._key;
parms._valid = frameVal._key;
parms._response_column = "Response";
parms._ntrees = 1;
parms._min_rows = 0.1;
parms._distribution = bernoulli;

gbm = new GBM(parms).trainModel().get();
Scope.track_generic(gbm);
Frame train_score = gbm.score(frame);
Scope.track_generic(train_score);

assertTrue(gbm.testJavaScoring(frame, train_score, 1e-15));

} finally {
Scope.exit();
}
}

@Test
public void testGBMFeatureInteractions() {
Scope.enter();
Expand Down
35 changes: 35 additions & 0 deletions h2o-algos/src/test/java/hex/tree/uplift/UpliftDRFTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package hex.tree.uplift;

import hex.ScoreKeeper;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import org.junit.Test;
Expand All @@ -14,6 +16,7 @@
import water.runner.CloudSize;
import water.runner.H2ORunner;

import java.io.IOException;
import java.util.Arrays;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -356,4 +359,36 @@ public void testPredictCorrectOutput() {
Scope.exit();
}
}

@Test
public void testMojo() {
try {
Scope.enter();
Frame train = new TestFrameBuilder()
.withColNames("C0", "C1", "treatment", "conversion")
.withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT, Vec.T_CAT)
.withDataForCol(0, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
.withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0))
.withDataForCol(2, ar("T", "C", "T", "T", "T", "C", "C", "C", "C", "C"))
.withDataForCol(3, ar("Yes", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes", "Yes"))
.build();
train.toCategoricalCol("treatment");
train.toCategoricalCol("conversion");
UpliftDRFModel.UpliftDRFParameters p = new UpliftDRFModel.UpliftDRFParameters();
p._train = train._key;
p._response_column = "conversion";
p._treatment_column = "treatment";
p._ntrees = 4;

UpliftDRF udrf = new UpliftDRF(p);
UpliftDRFModel model = udrf.trainModel().get();
Scope.track_generic(model);
Frame preds = model.score(train);
Scope.track_generic(preds);

assertTrue(model.testJavaScoring(train, preds,1e-15));
} finally {
Scope.exit();
}
}
}
118 changes: 73 additions & 45 deletions h2o-core/src/main/java/hex/AUUC.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,43 +88,63 @@ private AUUC(AUUCBuilder bldr, boolean trueProbabilities, AUUCType auucType) {
_auucType = auucType;
_auucTypeIndx = getIndexByAUUCType(_auucType);
_nBins = bldr._nBins;
assert _nBins >= 1 : "Must have >= 1 bins for AUUC calculation, but got " + _nBins;
assert trueProbabilities || bldr._thresholds[_nBins - 1] == 1 : "Bins need to contain pred = 1 when 0-1 probabilities are used";
_n = bldr._n;
_ths = Arrays.copyOf(bldr._thresholds,_nBins);
_treatment = Arrays.copyOf(bldr._treatment,_nBins);
_control = Arrays.copyOf(bldr._control,_nBins);
_yTreatment = Arrays.copyOf(bldr._yTreatment,_nBins);
_yControl = Arrays.copyOf(bldr._yControl,_nBins);
_frequency = Arrays.copyOf(bldr._frequency, _nBins);
_frequencyCumsum = Arrays.copyOf(bldr._frequency, _nBins);
_uplift = new double[AUUCType.values().length][_nBins];
_upliftRandom = new double[AUUCType.values().length][_nBins];
_upliftNormalized = new double[AUUCType.values().length][_nBins];

// Rollup counts
long tmpt=0, tmpc=0, tmptp=0, tmpcp=0, tmpf=0;
for( int i=0; i<_nBins; i++ ) {
tmpt += _treatment[i]; _treatment[i] = tmpt;
tmpc += _control[i]; _control[i] = tmpc;
tmptp += _yTreatment[i]; _yTreatment[i] = tmptp;
tmpcp += _yControl[i]; _yControl[i] = tmpcp;
tmpf += _frequencyCumsum[i]; _frequencyCumsum[i] = tmpf;
}

// these methods need to be call in this order
setUplift();
setUpliftRandom();
setUpliftNormalized();

if (trueProbabilities) {
_auucs = computeAuucs();
_auucsRandom = computeAuucsRandom();
_aecu = computeAecu();
_auucsNormalized = computeAuucsNormalized();
_maxIdx = _auucType.maxCriterionIdx(this);
//assert _nBins >= 1 : "Must have >= 1 bins for AUUC calculation, but got " + _nBins;
if (_nBins > 0) {
assert trueProbabilities || bldr._thresholds[_nBins - 1] == 1 : "Bins need to contain pred = 1 when 0-1 probabilities are used";
_n = bldr._n;
_ths = Arrays.copyOf(bldr._thresholds, _nBins);
_treatment = Arrays.copyOf(bldr._treatment, _nBins);
_control = Arrays.copyOf(bldr._control, _nBins);
_yTreatment = Arrays.copyOf(bldr._yTreatment, _nBins);
_yControl = Arrays.copyOf(bldr._yControl, _nBins);
_frequency = Arrays.copyOf(bldr._frequency, _nBins);
_frequencyCumsum = Arrays.copyOf(bldr._frequency, _nBins);
_uplift = new double[AUUCType.values().length][_nBins];
_upliftRandom = new double[AUUCType.values().length][_nBins];
_upliftNormalized = new double[AUUCType.values().length][_nBins];

// Rollup counts
long tmpt = 0, tmpc = 0, tmptp = 0, tmpcp = 0, tmpf = 0;
for (int i = 0; i < _nBins; i++) {
tmpt += _treatment[i];
_treatment[i] = tmpt;
tmpc += _control[i];
_control[i] = tmpc;
tmptp += _yTreatment[i];
_yTreatment[i] = tmptp;
tmpcp += _yControl[i];
_yControl[i] = tmpcp;
tmpf += _frequencyCumsum[i];
_frequencyCumsum[i] = tmpf;
}

// these methods need to be call in this order
setUplift();
setUpliftRandom();
setUpliftNormalized();

if (trueProbabilities) {
_auucs = computeAuucs();
_auucsRandom = computeAuucsRandom();
_aecu = computeAecu();
_auucsNormalized = computeAuucsNormalized();
_maxIdx = _auucType.maxCriterionIdx(this);
} else {
_maxIdx = 0;
}
} else {
_maxIdx = 0;
_maxIdx = -1;
_n = 0;
_ths = null;
_treatment = null;
_control = null;
_yTreatment = null;
_yControl = null;
_frequency = null;
_frequencyCumsum = null;
_uplift = null;
_upliftRandom = null;
_upliftNormalized = null;
}
}

Expand Down Expand Up @@ -227,17 +247,23 @@ private double[] computeAuucsRandom(){
return computeAuucs(_upliftRandom);
}

private double[] computeAuucsNormalized() {return computeAuucs(_upliftNormalized);}
private double[] computeAuucsNormalized() {
return computeAuucs(_upliftNormalized);
}

private double[] computeAuucs(double[][] uplift){
AUUCType[] auucTypes = AUUCType.VALUES;
double[] auucs = new double[auucTypes.length];
for(int i = 0; i < auucTypes.length; i++ ) {
double area = 0;
for(int j = 0; j < _nBins; j++) {
area += uplift[i][j] * frequency(j);
if(_n == 0){
auucs[i] = Double.NaN;
} else {
double area = 0;
for (int j = 0; j < _nBins; j++) {
area += uplift[i][j] * frequency(j);
}
auucs[i] = area / (_n + 1);
}
auucs[i] = area/(_n+1);
}
return auucs;
}
Expand Down Expand Up @@ -270,21 +296,23 @@ public double auucNormalizedByType(AUUCType type){
return auucNormalized(idx);
}

public double auuc(int idx){ return _auucs[idx]; }
public double auuc (int idx){
return _n == 0 || idx < 0 ? Double.NaN : _auucs[idx];
}

public double auuc(){ return auuc(_auucTypeIndx); }

public double auucRandom(int idx){
return _auucsRandom[idx];
return _n == 0 || idx < 0 ? Double.NaN : _auucsRandom[idx];
}

public double auucRandom(){ return auucRandom(_auucTypeIndx); }

public double aecu(int idx) { return _aecu[idx];}
public double aecu(int idx) { return _n == 0 || idx < 0 ? Double.NaN : _aecu[idx];}

public double qini(){ return aecuByType(AUUCType.qini);}

public double auucNormalized(int idx){ return _auucsNormalized[idx]; }
public double auucNormalized(int idx){ return _n == 0 || idx < 0 ? Double.NaN : _auucsNormalized[idx]; }

public double auucNormalized(){ return auucNormalized(_auucTypeIndx); }

Expand Down
9 changes: 8 additions & 1 deletion h2o-core/src/main/java/hex/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -2214,11 +2214,15 @@ protected void setupLocal() {
if (isCancelled() || _j != null && _j.stop_requested()) return;
Chunk weightsChunk = _hasWeights && _computeMetrics ? chks[_output.weightsIdx()] : null;
Chunk offsetChunk = _output.hasOffset() ? chks[_output.offsetIdx()] : null;
Chunk treatmentChunk = _output.hasTreatment() ? chks[_output.treatmentIdx()] : null;
Chunk responseChunk = null;
float [] actual = null;
_mb = Model.this.makeMetricBuilder(_domain);
if (_computeMetrics) {
if (_output.hasResponse()) {
if (_output.hasTreatment()) {
actual = new float[2];
responseChunk = chks[_output.responseIdx()];
} else if (_output.hasResponse()) {
actual = new float[1];
responseChunk = chks[_output.responseIdx()];
} else
Expand All @@ -2245,6 +2249,9 @@ protected void setupLocal() {
for (int i = 0; i < actual.length; ++i)
actual[i] = (float) data(chks, row, i);
}
if (treatmentChunk != null) {
actual[1] = (float) treatmentChunk.atd(row);
}
_mb.perRow(preds, actual, weight, offset, Model.this);
// Handle custom metric
customMetricPerRow(preds, actual, weight, offset, Model.this);
Expand Down
7 changes: 5 additions & 2 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ public static class MetricBuilderBinomialUplift extends MetricBuilderSupervised<

public MetricBuilderBinomialUplift( String[] domain, double[] thresholds) {
super(2,domain);
assert thresholds != null: "Thresholds should not be null for metric creation.";
if(thresholds == null){
_auuc = null;
}
_auuc = new AUUC.AUUCBuilder(thresholds);

}

@Override public double[] perRow(double[] ds, float[] yact, Model m) {
Expand All @@ -156,7 +159,7 @@ public MetricBuilderBinomialUplift( String[] domain, double[] thresholds) {

@Override
public double[] perRow(double[] ds, float[] yact, double weight, double offset, Model m) {
assert _auuc == null || yact.length == 2 : "Treatment must be included in `yact` when calculating AUUC";
assert yact.length == 2 : "Treatment must be included in `yact` when calculating AUUC";
if(Float .isNaN(yact[0])) return ds; // No errors if actual is missing
if(ArrayUtils.hasNaNs(ds)) return ds; // No errors if prediction has missing values (can happen for GLM)
if(weight == 0 || Double.isNaN(weight)) return ds;
Expand Down
4 changes: 3 additions & 1 deletion h2o-genmodel/src/main/java/hex/genmodel/GenModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ public String[] getOutputNames() {
case CoxPH:
outputNames = new String[]{"lp"};
break;

case BinomialUplift:
outputNames = new String[]{"uplift_predict", "p_y1_ct1", "p_y1_ct0"};
break;
default:
throw new UnsupportedOperationException("Getting output column names for model category '" +
category + "' is not supported.");
Expand Down
Loading

0 comments on commit 0ced502

Please sign in to comment.