Skip to content

Commit

Permalink
Added support for branch node scores
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jun 15, 2023
1 parent 3d4f61f commit 0c3095c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
Expand Down Expand Up @@ -131,4 +132,9 @@ public Model encodeModel(Schema schema){
throw new IllegalArgumentException("Distribution family " + model._family + " is not supported");
}
}

@Override
protected void ensureScore(Node node, double score){
return;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
Expand Down Expand Up @@ -82,6 +83,11 @@ public Expression createExpression(FieldRef fieldRef){
return miningModel;
}

@Override
protected void ensureScore(Node node, double score){
return;
}

static
public int getMaxPathLength(IsolationForestMojoModel model){
return (int)getFieldValue(FIELD_MAX_PATH_LENGTH, model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import com.google.common.collect.Iterables;
Expand All @@ -30,7 +31,6 @@
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import org.dmg.pmml.DataType;
import org.dmg.pmml.HasRecordCount;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
Expand Down Expand Up @@ -85,17 +85,6 @@ public List<TreeModel> encodeTreeModels(Schema schema){
return result;
}

static
public Model encodeTreeEnsemble(List<TreeModel> treeModels, Function<List<TreeModel>, MiningModel> ensembleFunction){

if(treeModels.size() == 1){
return Iterables.getOnlyElement(treeModels);
}

return ensembleFunction.apply(treeModels);
}

static
public TreeModel encodeTreeModel(byte[] compressedTree, byte[] compressedTreeAux, PredicateManager predicateManager, Schema schema){
Label label = new ContinuousLabel(DataType.DOUBLE);

Expand All @@ -111,7 +100,6 @@ public TreeModel encodeTreeModel(byte[] compressedTree, byte[] compressedTreeAux
return treeModel;
}

static
public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predicate, byte[] compressedTree, Map<Integer, SharedTreeMojoModel.AuxInfo> auxInfos, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema){
SharedTreeMojoModel.AuxInfo auxInfo = auxInfos.get(id);
if(auxInfo == null){
Expand Down Expand Up @@ -272,6 +260,9 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi
rightChild = encodeNode(rightByteBuffer, rightId, rightPredicate, compressedTree, auxInfos, rightCategoryManager, predicateManager, schema);
}

ensureScore(leftChild, auxInfo.predL);
ensureScore(rightChild, auxInfo.predR);

ensureRecordCount(leftChild, auxInfo.weightL);
ensureRecordCount(rightChild, auxInfo.weightR);

Expand All @@ -281,21 +272,46 @@ public Node encodeNode(ByteBufferWrapper byteBuffer, Integer id, Predicate predi
.addNodes(leftChild, rightChild);

if(id == 0){
ensureRecordCount(result, auxInfo.weightL + auxInfo.weightR);
float weight = (auxInfo.weightL + auxInfo.weightR);

ensureScore(result, (auxInfo.predL * auxInfo.weightL + auxInfo.predR * auxInfo.weightR) / weight);
ensureRecordCount(result, weight);
}

return result;
}

static
private void ensureRecordCount(Node node, double recordCount){
HasRecordCount<?> hasRecordCount = (HasRecordCount<?>)node;
protected void ensureScore(Node node, double score){

if(node.hasScore()){

if(!Objects.equals(node.getScore(), score)){
throw new IllegalArgumentException();
}
} else

if(hasRecordCount.getRecordCount() != null){
{
node.setScore(score);
}
}

protected void ensureRecordCount(Node node, double recordCount){

if(node.getRecordCount() != null){
throw new IllegalArgumentException();
}

hasRecordCount.setRecordCount(ValueUtil.narrow(recordCount));
node.setRecordCount(ValueUtil.narrow(recordCount));
}

static
public Model encodeTreeEnsemble(List<TreeModel> treeModels, Function<List<TreeModel>, MiningModel> ensembleFunction){

if(treeModels.size() == 1){
return Iterables.getOnlyElement(treeModels);
}

return ensembleFunction.apply(treeModels);
}

static
Expand Down

0 comments on commit 0c3095c

Please sign in to comment.