Skip to content

Commit

Permalink
fix make metrics bug
Browse files Browse the repository at this point in the history
  • Loading branch information
maurever committed Sep 20, 2023
1 parent 68dcad9 commit 0e5f59f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
19 changes: 10 additions & 9 deletions h2o-core/src/main/java/hex/AUUC.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,21 @@ public double[] upliftRandomByType(AUUCType type){
int idx = getIndexByAUUCType(type);
return idx < 0 ? null : _upliftRandom[idx];
}

public AUUC(int nBins, Vec probs, Vec y, Vec uplift, AUUCType auucType) {
this(new AUUCImpl(calculateQuantileThresholds(nBins, probs)).doAll(probs, y, uplift)._bldr, auucType);
}

public AUUC(double[] customThresholds, Vec probs, Vec y, Vec uplift, AUUCType auucType) {
this(new AUUCImpl(customThresholds).doAll(probs, y, uplift)._bldr, auucType);
public AUUC(Vec probs, Vec y, Vec uplift, AUUCType auucType, int nbins) {
this(new AUUCImpl(calculateQuantileThresholds(nbins, probs)).doAll(probs, y, uplift)._bldr, auucType);
}

public AUUC(AUUCBuilder bldr, AUUCType auucType) {
this(bldr, true, auucType);
}


public AUUC(double[] customThresholds, Vec probs, Vec y, Vec uplift, AUUCType auucType) {
this(new AUUCImpl(customThresholds).doAll(probs, y, uplift)._bldr, auucType);
}

private AUUC(AUUCBuilder bldr, boolean trueProbabilities, AUUCType auucType) {
public AUUC(AUUCBuilder bldr, boolean trueProbabilities, AUUCType auucType) {
_auucType = auucType;
_auucTypeIndx = getIndexByAUUCType(_auucType);
_nBins = bldr._nBins;
Expand Down Expand Up @@ -316,11 +317,11 @@ public double auucRandom(int idx){

public double auucNormalized(){ return auucNormalized(_auucTypeIndx); }

private static class AUUCImpl extends MRTask<AUUCImpl> {
public static class AUUCImpl extends MRTask<AUUCImpl> {
final double[] _thresholds;
AUUCBuilder _bldr;

AUUCImpl(double[] thresholds) {
public AUUCImpl(double[] thresholds) {
_thresholds = thresholds;
}

Expand Down
20 changes: 10 additions & 10 deletions h2o-core/src/main/java/hex/ModelMetricsBinomialUplift.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.*;
import water.util.ArrayUtils;
import water.util.Log;

import java.util.Arrays;
Expand Down Expand Up @@ -163,8 +162,7 @@ static public ModelMetricsBinomialUplift make(Vec predictedProbs, Vec actualLabe
mb = new UpliftBinomialMetrics(labels.domain(), customAuucThresholds).doAll(fr)._mb;
}
labels.remove();
ModelMetricsBinomialUplift mm = (ModelMetricsBinomialUplift) mb.makeModelMetrics(null, fr, new Frame(predictedProbs),
fr.vec("labels"), fr.vec("treatment"), auucType, auucNbins); // use the Vecs from the frame (to make sure the ESPC is identical)
ModelMetricsBinomialUplift mm = (ModelMetricsBinomialUplift) mb.makeModelMetrics(null, fr, auucType);
mm._description = "Computed on user-given predictions and labels.";
return mm;
} finally {
Expand Down Expand Up @@ -274,25 +272,27 @@ public double[] perRow(double[] ds, float[] yact, double weight, double offset,
treatment = frameWithExtraColumns.vec(m._parms._treatment_column);
}
}
int auucNbins = m==null || m._parms._auuc_nbins == -1?
int auucNbins = m==null || m._parms._auuc_nbins == -1?
AUUC.NBINS : m._parms._auuc_nbins;
return makeModelMetrics(m, f, preds, resp, treatment, auucType, auucNbins);
}

private ModelMetrics makeModelMetrics(final Model m, final Frame f, final Frame preds,
final Vec resp, final Vec treatment, AUUC.AUUCType auucType, int nbins) {
AUUC auuc = null;
if (preds != null && resp != null && treatment != null) {
if (_auuc == null || _auuc._nBins > 0) {
auuc = new AUUC(nbins, preds.vec(0), resp, treatment, auucType);
} else {
auuc = new AUUC(_auuc._thresholds, preds.vec(0), resp, treatment, auucType);
if (preds != null) {
if (resp != null) {
auuc = new AUUC(preds.vec(0), resp, treatment, auucType, nbins);
}
}
return makeModelMetrics(m, f, auuc);
}

private ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) {
private ModelMetrics makeModelMetrics(final Model m, final Frame f, AUUC.AUUCType auucType) {
return makeModelMetrics(m, f, new AUUC(_auuc, auucType));
}

public ModelMetrics makeModelMetrics(Model m, Frame f, AUUC auuc) {
double sigma = Double.NaN;
double ate = Double.NaN;
double atc = Double.NaN;
Expand Down
2 changes: 1 addition & 1 deletion h2o-core/src/test/java/hex/AUUCTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private static AUUC doAUUC(int nbins, double[] probs, double[] y, double[] treat
}
Frame fr = ArrayUtils.frame(new String[]{"probs", "y", "treatment"}, rows);
fr.vec("treatment").setDomain(new String[]{"0", "1"});
AUUC auuc = new AUUC(nbins, fr.vec("probs"),fr.vec("y"), fr.vec("treatment"), type);
AUUC auuc = new AUUC(fr.vec("probs"),fr.vec("y"), fr.vec("treatment"), type, nbins);
fr.remove();
return auuc;
}
Expand Down

0 comments on commit 0e5f59f

Please sign in to comment.