Skip to content

Commit

Permalink
Improved support for DRF and GBM mojo types
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jun 3, 2022
1 parent a68bfad commit 488b33c
Show file tree
Hide file tree
Showing 16 changed files with 4,958 additions and 29 deletions.
37 changes: 24 additions & 13 deletions pmml-h2o/src/main/java/org/jpmml/h2o/DrfMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.mining.Segmentation.MultipleModelMethod;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
Expand All @@ -48,7 +47,7 @@ public DrfMojoModelConverter(DrfMojoModel model){
}

@Override
public MiningModel encodeModel(Schema schema){
public Model encodeModel(Schema schema){
DrfMojoModel model = getModel();

boolean binomialDoubleTrees = getBinomialDoubleTrees(model);
Expand All @@ -62,20 +61,27 @@ public MiningModel encodeModel(Schema schema){
if(model._nclasses == 1){
ContinuousLabel continuousLabel = (ContinuousLabel)label;

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));
return encodeTreeEnsemble(treeModels, (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, ensembleTreeModels));

return miningModel;
return miningModel;
});
} else

if(model._nclasses == 2 && !binomialDoubleTrees){
ContinuousLabel continuousLabel = new ContinuousLabel(DataType.DOUBLE);

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethod.AVERAGE, treeModels))
.setOutput(ModelUtil.createPredictedOutput("drfValue", OpType.CONTINUOUS, DataType.DOUBLE));
Model pmmlModel = encodeTreeEnsemble(treeModels, (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, ensembleTreeModels));

return MiningModelUtil.createBinaryLogisticClassification(miningModel, -1d, 1d, RegressionModel.NormalizationMethod.NONE, true, schema);
return miningModel;
});

pmmlModel.setOutput(ModelUtil.createPredictedOutput("drfValue", OpType.CONTINUOUS, DataType.DOUBLE));

return MiningModelUtil.createBinaryLogisticClassification(pmmlModel, -1d, 1d, RegressionModel.NormalizationMethod.NONE, true, schema);
} else

{
Expand All @@ -84,11 +90,16 @@ public MiningModel encodeModel(Schema schema){
List<Model> models = new ArrayList<>();

for(int i = 0; i < categoricalLabel.size(); i++){
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, CMatrixUtil.getRow(treeModels, ntreesPerGroup, ntreeGroups, i)))
.setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("drfValue", categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE));
Model pmmlModel = encodeTreeEnsemble(CMatrixUtil.getRow(treeModels, ntreesPerGroup, ntreeGroups, i), (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, ensembleTreeModels));

return miningModel;
});

pmmlModel.setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("drfValue", categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE));

models.add(miningModel);
models.add(pmmlModel);
}

return MiningModelUtil.createClassification(models, RegressionModel.NormalizationMethod.SIMPLEMAX, true, schema);
Expand Down
53 changes: 37 additions & 16 deletions pmml-h2o/src/main/java/org/jpmml/h2o/GbmMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.mining.Segmentation.MultipleModelMethod;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
Expand All @@ -48,7 +47,7 @@ public GbmMojoModelConverter(GbmMojoModel model){
}

@Override
public MiningModel encodeModel(Schema schema){
public Model encodeModel(Schema schema){
GbmMojoModel model = getModel();

int ntreeGroups = getNTreeGroups(model);
Expand All @@ -61,33 +60,50 @@ public MiningModel encodeModel(Schema schema){
if(model._family == DistributionFamily.gaussian){
ContinuousLabel continuousLabel = (ContinuousLabel)label;

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethod.SUM, treeModels))
.setTargets(ModelUtil.createRescaleTargets(null, model._init_f, continuousLabel));
Model pmmlModel = encodeTreeEnsemble(treeModels, (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, ensembleTreeModels));

return miningModel;
return miningModel;
});

pmmlModel.setTargets(ModelUtil.createRescaleTargets(null, model._init_f, continuousLabel));

return pmmlModel;
} else

if((model._family == DistributionFamily.poisson) || (model._family == DistributionFamily.gamma) || (model._family == DistributionFamily.tweedie)){
ContinuousLabel continuousLabel = new ContinuousLabel(DataType.DOUBLE);

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethod.SUM, treeModels))
Model pmmlModel = encodeTreeEnsemble(treeModels, (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, ensembleTreeModels));

return miningModel;
});

pmmlModel
.setTargets(ModelUtil.createRescaleTargets(null, model._init_f, continuousLabel))
.setOutput(ModelUtil.createPredictedOutput("gbmValue", OpType.CONTINUOUS, DataType.DOUBLE));

return MiningModelUtil.createRegression(miningModel, RegressionModel.NormalizationMethod.EXP, schema);
return MiningModelUtil.createRegression(pmmlModel, RegressionModel.NormalizationMethod.EXP, schema);
} else

if(model._family == DistributionFamily.bernoulli){
ContinuousLabel continuousLabel = new ContinuousLabel(DataType.DOUBLE);

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, treeModels))
Model pmmlModel = encodeTreeEnsemble(treeModels, (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, ensembleTreeModels));

return miningModel;
});

pmmlModel
.setTargets(ModelUtil.createRescaleTargets(null, model._init_f, continuousLabel))
.setOutput(ModelUtil.createPredictedOutput("gbmValue", OpType.CONTINUOUS, DataType.DOUBLE));

return MiningModelUtil.createBinaryLogisticClassification(miningModel, 1d, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
return MiningModelUtil.createBinaryLogisticClassification(pmmlModel, 1d, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
} else

if(model._family == DistributionFamily.multinomial){
Expand All @@ -96,11 +112,16 @@ public MiningModel encodeModel(Schema schema){
List<Model> models = new ArrayList<>();

for(int i = 0; i < categoricalLabel.size(); i++){
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, CMatrixUtil.getRow(treeModels, ntreesPerGroup, ntreeGroups, i)))
.setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("gbmValue", categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE));
Model pmmlModel = encodeTreeEnsemble(CMatrixUtil.getRow(treeModels, ntreesPerGroup, ntreeGroups, i), (List<TreeModel> ensembleTreeModels) -> {
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null))
.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, ensembleTreeModels));

return miningModel;
});

pmmlModel.setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create("gbmValue", categoricalLabel.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE));

models.add(miningModel);
models.add(pmmlModel);
}

return MiningModelUtil.createClassification(models, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.common.collect.Iterables;
import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
Expand Down Expand Up @@ -73,6 +77,16 @@ public List<TreeModel> encodeTreeModels(Schema schema){
return result;
}

/**
 * Turns a list of tree models into a single PMML model.
 *
 * <p>
 * A list that holds exactly one tree model is unwrapped, and that tree model is returned as is;
 * the ensemble function is not invoked in this case.
 * A list of any other size (including an empty list) is passed to the ensemble function unchanged.
 * </p>
 *
 * @param treeModels The tree models to combine.
 * @param ensembleFunction Builds the enclosing mining model when the list is not a singleton.
 *
 * @return The only tree model of a singleton list, or the result of the ensemble function otherwise.
 */
static
public Model encodeTreeEnsemble(List<TreeModel> treeModels, Function<List<TreeModel>, MiningModel> ensembleFunction){
	boolean singleton = (treeModels.size() == 1);

	// Anything but a lone tree goes through the caller-supplied ensemble builder
	if(!singleton){
		return ensembleFunction.apply(treeModels);
	}

	return Iterables.getOnlyElement(treeModels);
}

static
public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema){
Label label = new ContinuousLabel(DataType.DOUBLE);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2022 Villu Ruusmann
*
* This file is part of JPMML-H2O
*
* JPMML-H2O is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-H2O is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-H2O. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.h2o.testing;

import java.util.List;

import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.visitors.AbstractVisitor;

/**
 * A visitor that rejects aggregating segmentations that hold fewer than two segments.
 *
 * <p>
 * Only the voting, averaging, median and summing multiple model methods (plain and weighted) are checked;
 * segmentations that use any other multiple model method pass through unchecked.
 * </p>
 */
public class SegmentationInspector extends AbstractVisitor {

	/**
	 * Checks the segment count of an aggregating segmentation, then continues the traversal.
	 *
	 * @param segmentation The segmentation being visited.
	 *
	 * @return The action returned by the superclass visit method.
	 *
	 * @throws InvalidElementException If an aggregating segmentation holds zero or one segments.
	 */
	@Override
	public VisitorAction visit(Segmentation segmentation){

		if(isAggregating(segmentation.getMultipleModelMethod())){
			List<Segment> segments = segmentation.getSegments();

			if(segments.size() < 2){
				throw new InvalidElementException(segmentation);
			}
		}

		return super.visit(segmentation);
	}

	/**
	 * Tells whether the multiple model method is one of the eight methods that this inspector checks.
	 */
	static
	private boolean isAggregating(Segmentation.MultipleModelMethod multipleModelMethod){

		switch(multipleModelMethod){
			case MAJORITY_VOTE:
			case WEIGHTED_MAJORITY_VOTE:
			case AVERAGE:
			case WEIGHTED_AVERAGE:
			case MEDIAN:
			case WEIGHTED_MEDIAN:
			case SUM:
			case WEIGHTED_SUM:
				return true;
			default:
				return false;
		}
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2022 Villu Ruusmann
*
* This file is part of JPMML-H2O
*
* JPMML-H2O is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-H2O is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-H2O. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.h2o.testing;

import java.util.function.Predicate;

import com.google.common.base.Equivalence;
import org.jpmml.converter.testing.Datasets;
import org.jpmml.converter.testing.Fields;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.jpmml.model.visitors.VisitorBattery;
import org.junit.Test;

/**
 * Batch tests for the "DecisionTree" algorithm over the Audit, Auto and Iris datasets.
 *
 * <p>
 * Every batch created by this test additionally runs {@link SegmentationInspector} as a validator.
 * </p>
 */
public class TreeMojoModelConverterTest extends H2OEncoderBatchTest implements Datasets, Fields {

	/**
	 * Sets up the test with a {@link PMMLEquivalence} whose two arguments are both <code>1e-13</code>.
	 */
	public TreeMojoModelConverterTest(){
		super(new PMMLEquivalence(1e-13, 1e-13));
	}

	/**
	 * Creates a batch that is bound to this test, and whose validators include {@link SegmentationInspector}.
	 */
	@Override
	public H2OEncoderBatch createBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
		return new H2OEncoderBatch(algorithm, dataset, columnFilter, equivalence){

			@Override
			public TreeMojoModelConverterTest getArchiveBatchTest(){
				return TreeMojoModelConverterTest.this;
			}

			@Override
			public VisitorBattery getValidators(){
				// Start from the inherited validators, then append the segmentation check
				VisitorBattery validators = super.getValidators();

				validators.add(SegmentationInspector.class);

				return validators;
			}
		};
	}

	@Test
	public void evaluateAudit() throws Exception {
		evaluate("DecisionTree", AUDIT);
	}

	@Test
	public void evaluateAuditNA() throws Exception {
		// The AUDIT_ADJUSTED field is left out of the comparison for this dataset
		evaluate("DecisionTree", AUDIT_NA, excludeFields(AUDIT_ADJUSTED));
	}

	@Test
	public void evaluateAuto() throws Exception {
		evaluate("DecisionTree", AUTO);
	}

	@Test
	public void evaluateAutoNA() throws Exception {
		evaluate("DecisionTree", AUTO_NA);
	}

	@Test
	public void evaluateIris() throws Exception {
		evaluate("DecisionTree", IRIS);
	}
}
Loading

0 comments on commit 488b33c

Please sign in to comment.