diff --git a/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTree.java b/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTree.java index cafbbcb..836f230 100644 --- a/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTree.java +++ b/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTree.java @@ -27,6 +27,8 @@ public interface SharedTree { byte[] getCompressedTreeAux(); + Integer nextId(); + SharedTreeMojoModel.AuxInfo getAuxInfo(int id); void encodeAuxInfo(Node node, double score, double recordCount); diff --git a/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java b/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java index 69be459..0ac415f 100644 --- a/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java +++ b/pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import com.google.common.collect.Iterables; @@ -81,6 +82,9 @@ public List encodeTreeModels(Schema schema){ SharedTree sharedTree = new SharedTree(){ + private AtomicInteger idSequence = new AtomicInteger(0); + + @Override public byte[] getCompressedTree(){ return compressedTree; @@ -91,6 +95,11 @@ public byte[] getCompressedTreeAux(){ return compressedTreeAux; } + @Override + public Integer nextId(){ + return this.idSequence.getAndIncrement(); + } + @Override public SharedTreeMojoModel.AuxInfo getAuxInfo(int id){ return auxInfos.get(id); @@ -138,7 +147,7 @@ protected void ensureRecordCount(Node node, double recordCount){ public TreeModel encodeTreeModel(SharedTree sharedTree, PredicateManager predicateManager, Schema schema){ Label label = new ContinuousLabel(DataType.DOUBLE); - Node root = encodeNode(sharedTree, null, 0, True.INSTANCE, new CategoryManager(), predicateManager, schema); + Node root = encodeNode(sharedTree, null, sharedTree.nextId(), True.INSTANCE, new CategoryManager(), predicateManager, schema); TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root) .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD); @@ -155,9 +164,6 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte } SharedTreeMojoModel.AuxInfo auxInfo = sharedTree.getAuxInfo(id); - if(auxInfo == null){ - throw new IllegalArgumentException(); - } int nodeType = byteBuffer.get1U(); @@ -171,7 +177,7 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte double score = byteBuffer.get4f(); Node result = new CountingLeafNode(score, predicate) - .setId(id); + .setId(toNodeId(auxInfo != null, id)); return result; } @@ -255,7 +261,7 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte Node leftChild; - Integer leftId = auxInfo.nidL; + Integer leftId = (auxInfo != null ? auxInfo.nidL : sharedTree.nextId()); ByteBufferWrapper leftByteBuffer = new ByteBufferWrapper(compressedTree); leftByteBuffer.skip(byteBuffer.position()); @@ -268,7 +274,7 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte double score = leftByteBuffer.get4f(); leftChild = new CountingLeafNode(score, leftPredicate) - .setId(leftId); + .setId(toNodeId(auxInfo != null, leftId)); } else { @@ -277,7 +283,7 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte Node rightChild; - Integer rightId = auxInfo.nidR; + Integer rightId = (auxInfo != null ? auxInfo.nidR : sharedTree.nextId()); ByteBufferWrapper rightByteBuffer = new ByteBufferWrapper(compressedTree); rightByteBuffer.skip(byteBuffer.position()); @@ -306,25 +312,30 @@ public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Inte double score = rightByteBuffer.get4f(); rightChild = new CountingLeafNode(score, rightPredicate) - .setId(rightId); + .setId(toNodeId(auxInfo != null, rightId)); } else { rightChild = encodeNode(sharedTree, rightByteBuffer, rightId, rightPredicate, rightCategoryManager, predicateManager, schema); - } + } // End if - sharedTree.encodeAuxInfo(leftChild, auxInfo.predL, auxInfo.weightL); - sharedTree.encodeAuxInfo(rightChild, auxInfo.predR, auxInfo.weightR); + if(auxInfo != null){ + sharedTree.encodeAuxInfo(leftChild, auxInfo.predL, auxInfo.weightL); + sharedTree.encodeAuxInfo(rightChild, auxInfo.predR, auxInfo.weightR); + } Node result = new CountingBranchNode(null, predicate) - .setId(id) + .setId(toNodeId(auxInfo != null, id)) .setDefaultChild(leftward ? leftChild.getId() : rightChild.getId()) .addNodes(leftChild, rightChild); - if(id == 0){ - float weight = (auxInfo.weightL + auxInfo.weightR); + if(auxInfo != null){ - sharedTree.encodeAuxInfo(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight, weight); + if(id == 0){ + float weight = (auxInfo.weightL + auxInfo.weightR); + + sharedTree.encodeAuxInfo(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight, weight); + } } return result; @@ -360,6 +371,11 @@ public int getNTreesPerGroup(SharedTreeMojoModel model){ return (int)getFieldValue(FIELD_NTREESPERGROUP, model); } + static + private Integer toNodeId(boolean hasAux, Integer id){ + return (hasAux ? id : id + 1); + } + private static final Field FIELD_COMPRESSEDTREES; private static final Field FIELD_COMPRESSEDTREESAUX; private static final Field FIELD_NTREEGROUPS;