Skip to content

Commit

Permalink
Added method SharedTreeMojoModelConverter#encodeTreeModels(Schema)
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Oct 6, 2019
1 parent db89a47 commit 826cce7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 48 deletions.
14 changes: 1 addition & 13 deletions src/main/java/org/jpmml/h2o/DrfMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import hex.genmodel.algos.drf.DrfMojoModel;
import org.dmg.pmml.DataType;
Expand All @@ -40,7 +38,6 @@
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;

Expand All @@ -55,21 +52,12 @@ public MiningModel encodeModel(Schema schema){
DrfMojoModel model = getModel();

boolean binomialDoubleTrees = getBinomialDoubleTrees(model);
byte[][] compressedTrees = getCompressedTrees(model);
int ntreeGroups = getNTreeGroups(model);
int ntreesPerGroup = getNTreesPerGroup(model);

if(model._mojo_version < 1.2d){
throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
}

Label label = schema.getLabel();

PredicateManager predicateManager = new PredicateManager();

List<TreeModel> treeModels = Stream.of(compressedTrees)
.map(compressedTree -> encodeTreeModel(compressedTree, predicateManager, schema))
.collect(Collectors.toList());
List<TreeModel> treeModels = encodeTreeModels(schema);

if(model._nclasses == 1){
ContinuousLabel continuousLabel = (ContinuousLabel)label;
Expand Down
14 changes: 1 addition & 13 deletions src/main/java/org/jpmml/h2o/GbmMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.utils.DistributionFamily;
Expand All @@ -40,7 +38,6 @@
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;

Expand All @@ -54,21 +51,12 @@ public GbmMojoModelConverter(GbmMojoModel model){
public MiningModel encodeModel(Schema schema){
GbmMojoModel model = getModel();

byte[][] compressedTrees = getCompressedTrees(model);
int ntreeGroups = getNTreeGroups(model);
int ntreesPerGroup = getNTreesPerGroup(model);

if(model._mojo_version < 1.2d){
throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
}

Label label = schema.getLabel();

PredicateManager predicateManager = new PredicateManager();

List<TreeModel> treeModels = Stream.of(compressedTrees)
.map(compressedTree -> encodeTreeModel(compressedTree, predicateManager, schema))
.collect(Collectors.toList());
List<TreeModel> treeModels = encodeTreeModels(schema);

if((DistributionFamily.gaussian).equals(model._family)){
ContinuousLabel continuousLabel = (ContinuousLabel)label;
Expand Down
28 changes: 6 additions & 22 deletions src/main/java/org/jpmml/h2o/IsolationForestMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import java.lang.reflect.Field;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import hex.genmodel.algos.isofor.IsolationForestMojoModel;
import org.dmg.pmml.DataType;
Expand All @@ -37,7 +35,6 @@
import org.jpmml.converter.AbstractTransformation;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
Expand All @@ -52,30 +49,17 @@ public IsolationForestMojoModelConverter(IsolationForestMojoModel model){
public MiningModel encodeModel(Schema schema){
IsolationForestMojoModel model = getModel();

byte[][] compressedTrees = getCompressedTrees(model);
int minPathLength = getMinPathLength(model);
int maxPathLength = getMaxPathLength(model);

if(model._mojo_version < 1.2d){
throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
if(minPathLength >= maxPathLength){
throw new IllegalArgumentException();
}

PredicateManager predicateManager = new PredicateManager();

List<TreeModel> treeModels = Stream.of(compressedTrees)
.map(compressedTree -> encodeTreeModel(compressedTree, predicateManager, schema))
.collect(Collectors.toList());
List<TreeModel> treeModels = encodeTreeModels(schema);

Transformation anomalyScore = new AbstractTransformation(){

private int minPathLength = getMinPathLength(model);
private int maxPathLength = getMaxPathLength(model);


{
if(this.minPathLength >= this.maxPathLength){
throw new IllegalArgumentException();
}
}

@Override
public FieldName getName(FieldName name){
return FieldName.create("anomalyScore");
Expand All @@ -88,7 +72,7 @@ public boolean isFinalResult(){

@Override
public Expression createExpression(FieldRef fieldRef){
return PMMLUtil.createApply(PMMLFunctions.DIVIDE, PMMLUtil.createApply(PMMLFunctions.SUBTRACT, PMMLUtil.createConstant(this.maxPathLength / (double)treeModels.size()), fieldRef), PMMLUtil.createConstant((this.maxPathLength - this.minPathLength) / (double)treeModels.size()));
return PMMLUtil.createApply(PMMLFunctions.DIVIDE, PMMLUtil.createApply(PMMLFunctions.SUBTRACT, PMMLUtil.createConstant(maxPathLength / (double)treeModels.size()), fieldRef), PMMLUtil.createConstant((maxPathLength - minPathLength) / (double)treeModels.size()));
}
};

Expand Down
20 changes: 20 additions & 0 deletions src/main/java/org/jpmml/h2o/SharedTreeMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
Expand Down Expand Up @@ -54,6 +56,24 @@ public SharedTreeMojoModelConverter(M model){
super(model);
}

public List<TreeModel> encodeTreeModels(Schema schema){
SharedTreeMojoModel model = getModel();

if(model._mojo_version < 1.2d){
throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
}

byte[][] compressedTrees = getCompressedTrees(model);

PredicateManager predicateManager = new PredicateManager();

List<TreeModel> result = Stream.of(compressedTrees)
.map(compressedTree -> encodeTreeModel(compressedTree, predicateManager, schema))
.collect(Collectors.toList());

return result;
}

static
public TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema){
Label label = new ContinuousLabel(null, DataType.DOUBLE);
Expand Down

0 comments on commit 826cce7

Please sign in to comment.