From 04510d0ad446bee001ed7bf93a09d0a3c89ffd1c Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Fri, 18 Jan 2019 14:18:19 +0200 Subject: [PATCH] Optimized the encoding of Node elements --- .../h2o/SharedTreeMojoModelConverter.java | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java b/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java index 724c627..4d7982f 100644 --- a/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java +++ b/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java @@ -33,7 +33,8 @@ import org.dmg.pmml.Predicate; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; -import org.dmg.pmml.tree.ComplexNode; +import org.dmg.pmml.tree.BranchNode; +import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.CategoricalFeature; @@ -58,15 +59,11 @@ public SharedTreeMojoModelConverter(M model){ public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema){ Label label = new ContinuousLabel(null, DataType.DOUBLE); - AtomicInteger id = new AtomicInteger(1); - - Node root = new ComplexNode() - .setId(Integer.toString(id.getAndIncrement())) - .setPredicate(new True()); + AtomicInteger idSequence = new AtomicInteger(1); ByteBufferWrapper buffer = new ByteBufferWrapper(compressedTree); - encodeNode(root, id, compressedTree, buffer, predicateManager, new CategoryManager(), schema); + Node root = encodeNode(new True(), idSequence, compressedTree, buffer, predicateManager, new CategoryManager(), schema); TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root) .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD); @@ -75,7 +72,9 @@ public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predica } static - public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteBufferWrapper byteBuffer, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema){ + public Node encodeNode(Predicate predicate, AtomicInteger idSequence, byte[] compressedTree, ByteBufferWrapper byteBuffer, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema){ + String id = nextId(idSequence); + int nodeType = byteBuffer.get1U(); int lmask = (nodeType & 51); @@ -167,9 +166,7 @@ public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteB } } - Node leftChild = new ComplexNode() - .setId(String.valueOf(id.getAndIncrement())) - .setPredicate(leftPredicate); + Node leftChild; ByteBufferWrapper leftByteBuffer = new ByteBufferWrapper(compressedTree); leftByteBuffer.skip(byteBuffer.position()); @@ -181,16 +178,17 @@ public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteB if((lmask & 16) != 0){ double score = leftByteBuffer.get4f(); - leftChild.setScore(ValueUtil.formatValue(score)); + leftChild = new LeafNode() + .setId(nextId(idSequence)) + .setScore(score) + .setPredicate(leftPredicate); } else { - encodeNode(leftChild, id, compressedTree, leftByteBuffer, predicateManager, leftCategoryManager, schema); + leftChild = encodeNode(leftPredicate, idSequence, compressedTree, leftByteBuffer, predicateManager, leftCategoryManager, schema); } - Node rightChild = new ComplexNode() - .setId(String.valueOf(id.getAndIncrement())) - .setPredicate(rightPredicate); + Node rightChild; ByteBufferWrapper rightByteBuffer = new ByteBufferWrapper(compressedTree); rightByteBuffer.skip(byteBuffer.position()); @@ -218,16 +216,23 @@ public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteB if((lmask2 & 16) != 0){ double score = rightByteBuffer.get4f(); - rightChild.setScore(ValueUtil.formatValue(score)); + rightChild = new LeafNode() + .setId(nextId(idSequence)) + .setScore(score) + .setPredicate(rightPredicate); } else { - encodeNode(rightChild, id, compressedTree, rightByteBuffer, predicateManager, rightCategoryManager, schema); + rightChild = encodeNode(rightPredicate, idSequence, compressedTree, rightByteBuffer, predicateManager, rightCategoryManager, schema); } - node - .addNodes(leftChild, rightChild) - .setDefaultChild(leftward ? leftChild.getId() : rightChild.getId()); + Node result = new BranchNode() + .setId(id) + .setDefaultChild(leftward ? leftChild.getId() : rightChild.getId()) + .setPredicate(predicate) + .addNodes(leftChild, rightChild); + + return result; } static @@ -245,6 +250,11 @@ public int getNTreesPerGroup(SharedTreeMojoModel model){ return (int)getFieldValue(FIELD_NTREESPERGROUP, model); } + static + private String nextId(AtomicInteger id){ + return String.valueOf(id.getAndIncrement()); + } + private static final Field FIELD_COMPRESSEDTREES; private static final Field FIELD_NTREEGROUPS; private static final Field FIELD_NTREESPERGROUP;