Skip to content

Commit

Permalink
Extracted SharedTree interface
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 9, 2023
1 parent ac2de2e commit e1ccc9a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 40 deletions.
33 changes: 33 additions & 0 deletions pmml-h2o/src/main/java/org/jpmml/h2o/SharedTree.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023 Villu Ruusmann
*
* This file is part of JPMML-H2O
*
* JPMML-H2O is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-H2O is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-H2O. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.h2o;

import hex.genmodel.algos.tree.SharedTreeMojoModel;
import org.dmg.pmml.tree.Node;

public interface SharedTree {

byte[] getCompressedTree();

byte[] getCompressedTreeAux();

SharedTreeMojoModel.AuxInfo getAuxInfo(int id);

void encodeAuxInfo(Node node, double score, double recordCount);
}
106 changes: 66 additions & 40 deletions pmml-h2o/src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,31 +77,84 @@ public List<TreeModel> encodeTreeModels(Schema schema){
byte[] compressedTree = compressedTrees[i];
byte[] compressedTreeAux = compressedTreesAux[i];

TreeModel treeModel = encodeTreeModel(compressedTree, compressedTreeAux, predicateManager, schema);
Map<Integer, SharedTreeMojoModel.AuxInfo> auxInfos = SharedTreeMojoModel.readAuxInfos(compressedTreeAux);

SharedTree sharedTree = new SharedTree(){

@Override
public byte[] getCompressedTree(){
return compressedTree;
}

@Override
public byte[] getCompressedTreeAux(){
return compressedTreeAux;
}

@Override
public SharedTreeMojoModel.AuxInfo getAuxInfo(int id){
return auxInfos.get(id);
}

@Override
public void encodeAuxInfo(Node node, double score, double recordCount){
ensureScore(node, score);
ensureRecordCount(node, recordCount);
}
};

TreeModel treeModel = encodeTreeModel(sharedTree, predicateManager, schema);

result.add(treeModel);
}

return result;
}

public TreeModel encodeTreeModel(byte[] compressedTree, byte[] compressedTreeAux, PredicateManager predicateManager, Schema schema){
Label label = new ContinuousLabel(DataType.DOUBLE);
protected void ensureScore(Node node, double score){

ByteBufferWrapper byteBuffer = new ByteBufferWrapper(compressedTree);
if(node.hasScore()){

if(!Objects.equals(node.getScore(), score)){
throw new IllegalArgumentException();
}
} else

{
node.setScore(score);
}
}

protected void ensureRecordCount(Node node, double recordCount){

if(node.getRecordCount() != null){
throw new IllegalArgumentException();
}

node.setRecordCount(ValueUtil.narrow(recordCount));
}

Map<Integer, SharedTreeMojoModel.AuxInfo> auxInfos = SharedTreeMojoModel.readAuxInfos(compressedTreeAux);
static
public TreeModel encodeTreeModel(SharedTree sharedTree, PredicateManager predicateManager, Schema schema){
Label label = new ContinuousLabel(DataType.DOUBLE);

Node root = encodeNode(byteBuffer, 0, True.INSTANCE, compressedTree, auxInfos, new CategoryManager(), predicateManager, schema);
Node root = encodeNode(sharedTree, null, 0, True.INSTANCE, new CategoryManager(), predicateManager, schema);

TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root)
.setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);

return treeModel;
}

public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predicate, byte[] compressedTree, Map<Integer, SharedTreeMojoModel.AuxInfo> auxInfos, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema){
SharedTreeMojoModel.AuxInfo auxInfo = auxInfos.get(id);
static
public Node encodeNode(SharedTree sharedTree, ByteBufferWrapper byteBuffer, Integer id, Predicate predicate, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema){
byte[] compressedTree = sharedTree.getCompressedTree();

if(byteBuffer == null){
byteBuffer = new ByteBufferWrapper(compressedTree);
}

SharedTreeMojoModel.AuxInfo auxInfo = sharedTree.getAuxInfo(id);
if(auxInfo == null){
throw new IllegalArgumentException();
}
Expand Down Expand Up @@ -219,7 +272,7 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi
} else

{
leftChild = encodeNode(leftByteBuffer, leftId, leftPredicate, compressedTree, auxInfos, leftCategoryManager, predicateManager, schema);
leftChild = encodeNode(sharedTree, leftByteBuffer, leftId, leftPredicate, leftCategoryManager, predicateManager, schema);
}

Node rightChild;
Expand Down Expand Up @@ -257,14 +310,11 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi
} else

{
rightChild = encodeNode(rightByteBuffer, rightId, rightPredicate, compressedTree, auxInfos, rightCategoryManager, predicateManager, schema);
rightChild = encodeNode(sharedTree, rightByteBuffer, rightId, rightPredicate, rightCategoryManager, predicateManager, schema);
}

ensureScore(leftChild, auxInfo.predL);
ensureScore(rightChild, auxInfo.predR);

ensureRecordCount(leftChild, auxInfo.weightL);
ensureRecordCount(rightChild, auxInfo.weightR);
sharedTree.encodeAuxInfo(leftChild, auxInfo.predL, auxInfo.weightL);
sharedTree.encodeAuxInfo(rightChild, auxInfo.predR, auxInfo.weightR);

Node result = new CountingBranchNode(null, predicate)
.setId(id)
Expand All @@ -274,36 +324,12 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi
if(id == 0){
float weight = (auxInfo.weightL + auxInfo.weightR);

ensureScore(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight);
ensureRecordCount(result, weight);
sharedTree.encodeAuxInfo(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight, weight);
}

return result;
}

protected void ensureScore(Node node, double score){

if(node.hasScore()){

if(!Objects.equals(node.getScore(), score)){
throw new IllegalArgumentException();
}
} else

{
node.setScore(score);
}
}

protected void ensureRecordCount(Node node, double recordCount){

if(node.getRecordCount() != null){
throw new IllegalArgumentException();
}

node.setRecordCount(ValueUtil.narrow(recordCount));
}

static
public Model encodeTreeEnsemble(List<TreeModel> treeModels, Function<List<TreeModel>, MiningModel> ensembleFunction){

Expand Down

0 comments on commit e1ccc9a

Please sign in to comment.