Skip to content

Commit

Permalink
Updated JPMML-Converter dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 7, 2019
1 parent 0beda30 commit bff3e4f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 39 deletions.
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-converter</artifactId>
<version>1.3.9</version>
<version>1.3.10</version>
</dependency>

<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-xgboost</artifactId>
<version>1.3.10</version>
<version>1.3.11</version>
</dependency>

<dependency>
Expand Down
9 changes: 3 additions & 6 deletions src/main/java/org/jpmml/h2o/ImputerUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.ModelEncoder;
Expand All @@ -36,12 +37,8 @@ public Feature encodeFeature(Feature feature, Object replacementValue, MissingVa

Field<?> field = feature.getField();

if(field instanceof DataField){
MissingValueDecorator missingValueDecorator = new MissingValueDecorator()
.setMissingValueReplacement(replacementValue)
.setMissingValueTreatment(missingValueTreatmentMethod);

encoder.addDecorator(feature.getName(), missingValueDecorator);
if((field instanceof DataField) || (field instanceof DerivedOutputField)){
encoder.addDecorator(field.getName(), new MissingValueDecorator(missingValueTreatmentMethod, replacementValue));

return feature;
} else
Expand Down
57 changes: 26 additions & 31 deletions src/main/java/org/jpmml/h2o/StackedEnsembleMojoModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;

public class StackedEnsembleMojoModelConverter extends Converter<StackedEnsembleMojoModel> {

Expand All @@ -48,20 +47,18 @@ public StackedEnsembleMojoModelConverter(StackedEnsembleMojoModel model){
}

@Override
public Model encodeModel(Schema schema){
public Schema encodeSchema(H2OEncoder encoder){
StackedEnsembleMojoModel model = getModel();

ConverterFactory converterFactory = ConverterFactory.newConverterFactory();

Schema schema = super.encodeSchema(encoder);

Label label = schema.getLabel();
List<Feature> features = new ArrayList<>();

List<Model> models = new ArrayList<>();

Schema segmentSchema = schema.toAnonymousSchema();

H2OEncoder encoder = new H2OEncoder();

Object[] baseModels = getBaseModels(model);
for(int i = 0; i < baseModels.length; i++){
Object baseModel = baseModels[i];
Expand All @@ -83,15 +80,15 @@ public Model encodeModel(Schema schema){

Model segmentModel = converter.encodeModel(baseModelSchema);

List<OutputField> outputFields = new ArrayList<>();

if(model._nclasses == 1){
ContinuousLabel continuousLabel = (ContinuousLabel)label;

OutputField predictedField = ModelUtil.createPredictedField(FieldName.create("stack(" + i + ")"), OpType.CONTINUOUS, DataType.DOUBLE)
OutputField predictedOutputField = ModelUtil.createPredictedField(FieldName.create("stack(" + i + ")"), OpType.CONTINUOUS, DataType.DOUBLE)
.setFinalResult(false);

outputFields.add(predictedField);
DerivedOutputField predictedField = encoder.createDerivedField(segmentModel, predictedOutputField, false);

features.add(new ContinuousFeature(encoder, predictedField));
} else

{
Expand All @@ -106,39 +103,37 @@ public Model encodeModel(Schema schema){
}

for(Object value : values){
OutputField probabilityField = ModelUtil.createProbabilityField(FieldName.create("stack(" + i +", " + value + ")"), DataType.DOUBLE, value)
OutputField probabilityOutputField = ModelUtil.createProbabilityField(FieldName.create("stack(" + i +", " + value + ")"), DataType.DOUBLE, value)
.setFinalResult(false);

outputFields.add(probabilityField);
DerivedOutputField probabilityField = encoder.createDerivedField(segmentModel, probabilityOutputField, false);

features.add(new ContinuousFeature(encoder, probabilityField));
}
}

Output segmentOutput = ModelUtil.ensureOutput(segmentModel);

for(OutputField outputField : outputFields){
segmentOutput.addOutputFields(outputField);
encoder.addTransformer(segmentModel);
}

features.add(new ContinuousFeature(encoder, outputField));
return new Schema(label, features);
}

// XXX
encoder.createDataField(outputField.getName(), null);
}
@Override
public Model encodeModel(Schema schema){
StackedEnsembleMojoModel model = getModel();

models.add(segmentModel);
}
ConverterFactory converterFactory = ConverterFactory.newConverterFactory();

MojoModel metaLearner = getMetaLearner(model);
if(metaLearner != null){
Converter<?> converter = converterFactory.newConverter(metaLearner);

Schema metaLearnerSchema = converter.toMojoModelSchema(new Schema(label, features));
if(metaLearner == null){
throw new IllegalArgumentException();
}

Model segmentModel = converter.encodeModel(metaLearnerSchema);
Converter<?> converter = converterFactory.newConverter(metaLearner);

models.add(segmentModel);
}
Schema metaLearnerSchema = converter.toMojoModelSchema(schema);

return MiningModelUtil.createModelChain(models);
return converter.encodeModel(metaLearnerSchema);
}

static
Expand Down

0 comments on commit bff3e4f

Please sign in to comment.