diff --git a/pmml-h2o/src/main/java/org/jpmml/h2o/GbmMojoModelConverter.java b/pmml-h2o/src/main/java/org/jpmml/h2o/GbmMojoModelConverter.java index 2b9506f..a5ecb5a 100644 --- a/pmml-h2o/src/main/java/org/jpmml/h2o/GbmMojoModelConverter.java +++ b/pmml-h2o/src/main/java/org/jpmml/h2o/GbmMojoModelConverter.java @@ -30,6 +30,7 @@ import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segmentation; import org.dmg.pmml.regression.RegressionModel; +import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.CMatrixUtil; import org.jpmml.converter.CategoricalLabel; @@ -131,4 +132,9 @@ public Model encodeModel(Schema schema){ throw new IllegalArgumentException("Distribution family " + model._family + " is not supported"); } } + + @Override + protected void ensureScore(Node node, double score){ + return; + } } \ No newline at end of file diff --git a/pmml-h2o/src/main/java/org/jpmml/h2o/IsolationForestMojoModelConverter.java b/pmml-h2o/src/main/java/org/jpmml/h2o/IsolationForestMojoModelConverter.java index 1b5926b..fb64890 100644 --- a/pmml-h2o/src/main/java/org/jpmml/h2o/IsolationForestMojoModelConverter.java +++ b/pmml-h2o/src/main/java/org/jpmml/h2o/IsolationForestMojoModelConverter.java @@ -30,6 +30,7 @@ import org.dmg.pmml.PMMLFunctions; import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segmentation; +import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.PMMLUtil; @@ -82,6 +83,11 @@ public Expression createExpression(FieldRef fieldRef){ return miningModel; } + @Override + protected void ensureScore(Node node, double score){ + return; + } + static public int getMaxPathLength(IsolationForestMojoModel model){ return (int)getFieldValue(FIELD_MAX_PATH_LENGTH, model); diff --git a/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java b/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java index ccecd92..c3d7993 100644 --- a/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java +++ b/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Function; import com.google.common.collect.Iterables; @@ -30,7 +31,6 @@ import hex.genmodel.utils.ByteBufferWrapper; import hex.genmodel.utils.GenmodelBitSet; import org.dmg.pmml.DataType; -import org.dmg.pmml.HasRecordCount; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.Model; import org.dmg.pmml.Predicate; @@ -85,17 +85,6 @@ public List encodeTreeModels(Schema schema){ return result; } - static - public Model encodeTreeEnsemble(List treeModels, Function, MiningModel> ensembleFunction){ - - if(treeModels.size() == 1){ - return Iterables.getOnlyElement(treeModels); - } - - return ensembleFunction.apply(treeModels); - } - - static public TreeModel encodeTreeModel(byte[] compressedTree, byte[] compressedTreeAux, PredicateManager predicateManager, Schema schema){ Label label = new ContinuousLabel(DataType.DOUBLE); @@ -111,7 +100,6 @@ public TreeModel encodeTreeModel(byte[] compressedTree, byte[] compressedTreeAux return treeModel; } - static public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predicate, byte[] compressedTree, Map auxInfos, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema){ SharedTreeMojoModel.AuxInfo auxInfo = auxInfos.get(id); if(auxInfo == null){ @@ -272,6 +260,9 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi rightChild = encodeNode(rightByteBuffer, rightId, rightPredicate, compressedTree, auxInfos, rightCategoryManager, predicateManager, schema); } + ensureScore(leftChild, auxInfo.predL); + ensureScore(rightChild, auxInfo.predR); + ensureRecordCount(leftChild, auxInfo.weightL); ensureRecordCount(rightChild, auxInfo.weightR); @@ -281,21 +272,46 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi .addNodes(leftChild, rightChild); if(id == 0){ - ensureRecordCount(result, auxInfo.weightL + auxInfo.weightR); + float weight = (auxInfo.weightL + auxInfo.weightR); + + ensureScore(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight); + ensureRecordCount(result, weight); } return result; } - static - private void ensureRecordCount(Node node, double recordCount){ - HasRecordCount hasRecordCount = (HasRecordCount)node; + protected void ensureScore(Node node, double score){ + + if(node.hasScore()){ + + if(!Objects.equals(node.getScore(), score)){ + throw new IllegalArgumentException(); + } + } else - if(hasRecordCount.getRecordCount() != null){ + { + node.setScore(score); + } + } + + protected void ensureRecordCount(Node node, double recordCount){ + + if(node.getRecordCount() != null){ throw new IllegalArgumentException(); } - hasRecordCount.setRecordCount(ValueUtil.narrow(recordCount)); + node.setRecordCount(ValueUtil.narrow(recordCount)); + } + + static + public Model encodeTreeEnsemble(List treeModels, Function, MiningModel> ensembleFunction){ + + if(treeModels.size() == 1){ + return Iterables.getOnlyElement(treeModels); + } + + return ensembleFunction.apply(treeModels); } static