Skip to content

Commit

Permalink
Optimized the encoding of Node elements
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 18, 2019
1 parent 07689c0 commit 04510d0
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ComplexNode;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
Expand All @@ -58,15 +59,11 @@ public SharedTreeMojoModelConverter(M model){
public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema){
Label label = new ContinuousLabel(null, DataType.DOUBLE);

AtomicInteger id = new AtomicInteger(1);

Node root = new ComplexNode()
.setId(Integer.toString(id.getAndIncrement()))
.setPredicate(new True());
AtomicInteger idSequence = new AtomicInteger(1);

ByteBufferWrapper buffer = new ByteBufferWrapper(compressedTree);

encodeNode(root, id, compressedTree, buffer, predicateManager, new CategoryManager(), schema);
Node root = encodeNode(new True(), idSequence, compressedTree, buffer, predicateManager, new CategoryManager(), schema);

TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root)
.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
Expand All @@ -75,7 +72,9 @@ public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predica
}

static
public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteBufferWrapper byteBuffer, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema){
public Node encodeNode(Predicate predicate, AtomicInteger idSequence, byte[] compressedTree, ByteBufferWrapper byteBuffer, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema){
String id = nextId(idSequence);

int nodeType = byteBuffer.get1U();

int lmask = (nodeType & 51);
Expand Down Expand Up @@ -167,9 +166,7 @@ public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteB
}
}

Node leftChild = new ComplexNode()
.setId(String.valueOf(id.getAndIncrement()))
.setPredicate(leftPredicate);
Node leftChild;

ByteBufferWrapper leftByteBuffer = new ByteBufferWrapper(compressedTree);
leftByteBuffer.skip(byteBuffer.position());
Expand All @@ -181,16 +178,17 @@ public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteB
if((lmask & 16) != 0){
double score = leftByteBuffer.get4f();

leftChild.setScore(ValueUtil.formatValue(score));
leftChild = new LeafNode()
.setId(nextId(idSequence))
.setScore(score)
.setPredicate(leftPredicate);
} else

{
encodeNode(leftChild, id, compressedTree, leftByteBuffer, predicateManager, leftCategoryManager, schema);
leftChild = encodeNode(leftPredicate, idSequence, compressedTree, leftByteBuffer, predicateManager, leftCategoryManager, schema);
}

Node rightChild = new ComplexNode()
.setId(String.valueOf(id.getAndIncrement()))
.setPredicate(rightPredicate);
Node rightChild;

ByteBufferWrapper rightByteBuffer = new ByteBufferWrapper(compressedTree);
rightByteBuffer.skip(byteBuffer.position());
Expand Down Expand Up @@ -218,16 +216,23 @@ public void encodeNode(Node node, AtomicInteger id, byte[] compressedTree, ByteB
if((lmask2 & 16) != 0){
double score = rightByteBuffer.get4f();

rightChild.setScore(ValueUtil.formatValue(score));
rightChild = new LeafNode()
.setId(nextId(idSequence))
.setScore(score)
.setPredicate(rightPredicate);
} else

{
encodeNode(rightChild, id, compressedTree, rightByteBuffer, predicateManager, rightCategoryManager, schema);
rightChild = encodeNode(rightPredicate, idSequence, compressedTree, rightByteBuffer, predicateManager, rightCategoryManager, schema);
}

node
.addNodes(leftChild, rightChild)
.setDefaultChild(leftward ? leftChild.getId() : rightChild.getId());
Node result = new BranchNode()
.setId(id)
.setDefaultChild(leftward ? leftChild.getId() : rightChild.getId())
.setPredicate(predicate)
.addNodes(leftChild, rightChild);

return result;
}

static
Expand All @@ -245,6 +250,11 @@ public int getNTreesPerGroup(SharedTreeMojoModel model){
return (int)getFieldValue(FIELD_NTREESPERGROUP, model);
}

static
private String nextId(AtomicInteger id){
return String.valueOf(id.getAndIncrement());
}

private static final Field FIELD_COMPRESSEDTREES;
private static final Field FIELD_NTREEGROUPS;
private static final Field FIELD_NTREESPERGROUP;
Expand Down

0 comments on commit 04510d0

Please sign in to comment.