diff --git a/pom.xml b/pom.xml index 849adde..f118620 100644 --- a/pom.xml +++ b/pom.xml @@ -72,13 +72,13 @@ org.jpmml jpmml-converter - 1.3.9 + 1.3.10 org.jpmml jpmml-xgboost - 1.3.10 + 1.3.11 diff --git a/src/main/java/org/jpmml/h2o/ImputerUtil.java b/src/main/java/org/jpmml/h2o/ImputerUtil.java index 2c3586a..61907a5 100644 --- a/src/main/java/org/jpmml/h2o/ImputerUtil.java +++ b/src/main/java/org/jpmml/h2o/ImputerUtil.java @@ -21,6 +21,7 @@ import org.dmg.pmml.DataField; import org.dmg.pmml.Field; import org.dmg.pmml.MissingValueTreatmentMethod; +import org.jpmml.converter.DerivedOutputField; import org.jpmml.converter.Feature; import org.jpmml.converter.MissingValueDecorator; import org.jpmml.converter.ModelEncoder; @@ -36,12 +37,8 @@ public Feature encodeFeature(Feature feature, Object replacementValue, MissingVa Field field = feature.getField(); - if(field instanceof DataField){ - MissingValueDecorator missingValueDecorator = new MissingValueDecorator() - .setMissingValueReplacement(replacementValue) - .setMissingValueTreatment(missingValueTreatmentMethod); - - encoder.addDecorator(feature.getName(), missingValueDecorator); + if((field instanceof DataField) || (field instanceof DerivedOutputField)){ + encoder.addDecorator(field.getName(), new MissingValueDecorator(missingValueTreatmentMethod, replacementValue)); return feature; } else diff --git a/src/main/java/org/jpmml/h2o/StackedEnsembleMojoModelConverter.java b/src/main/java/org/jpmml/h2o/StackedEnsembleMojoModelConverter.java index aa463ef..dafc936 100644 --- a/src/main/java/org/jpmml/h2o/StackedEnsembleMojoModelConverter.java +++ b/src/main/java/org/jpmml/h2o/StackedEnsembleMojoModelConverter.java @@ -29,17 +29,16 @@ import org.dmg.pmml.FieldName; import org.dmg.pmml.Model; import org.dmg.pmml.OpType; -import org.dmg.pmml.Output; import org.dmg.pmml.OutputField; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.ContinuousLabel; +import org.jpmml.converter.DerivedOutputField; import org.jpmml.converter.Feature; import org.jpmml.converter.Label; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.Schema; import org.jpmml.converter.SchemaUtil; -import org.jpmml.converter.mining.MiningModelUtil; public class StackedEnsembleMojoModelConverter extends Converter { @@ -48,20 +47,18 @@ public StackedEnsembleMojoModelConverter(StackedEnsembleMojoModel model){ } @Override - public Model encodeModel(Schema schema){ + public Schema encodeSchema(H2OEncoder encoder){ StackedEnsembleMojoModel model = getModel(); ConverterFactory converterFactory = ConverterFactory.newConverterFactory(); + Schema schema = super.encodeSchema(encoder); + Label label = schema.getLabel(); List features = new ArrayList<>(); - List models = new ArrayList<>(); - Schema segmentSchema = schema.toAnonymousSchema(); - H2OEncoder encoder = new H2OEncoder(); - Object[] baseModels = getBaseModels(model); for(int i = 0; i < baseModels.length; i++){ Object baseModel = baseModels[i]; @@ -83,15 +80,15 @@ public Model encodeModel(Schema schema){ Model segmentModel = converter.encodeModel(baseModelSchema); - List outputFields = new ArrayList<>(); - if(model._nclasses == 1){ ContinuousLabel continuousLabel = (ContinuousLabel)label; - OutputField predictedField = ModelUtil.createPredictedField(FieldName.create("stack(" + i + ")"), OpType.CONTINUOUS, DataType.DOUBLE) + OutputField predictedOutputField = ModelUtil.createPredictedField(FieldName.create("stack(" + i + ")"), OpType.CONTINUOUS, DataType.DOUBLE) .setFinalResult(false); - outputFields.add(predictedField); + DerivedOutputField predictedField = encoder.createDerivedField(segmentModel, predictedOutputField, false); + + features.add(new ContinuousFeature(encoder, predictedField)); } else { @@ -106,39 +103,37 @@ public Model encodeModel(Schema schema){ } for(Object value : values){ - OutputField probabilityField = ModelUtil.createProbabilityField(FieldName.create("stack(" + i +", " + value + ")"), DataType.DOUBLE, value) + OutputField probabilityOutputField = ModelUtil.createProbabilityField(FieldName.create("stack(" + i +", " + value + ")"), DataType.DOUBLE, value) .setFinalResult(false); - outputFields.add(probabilityField); + DerivedOutputField probabilityField = encoder.createDerivedField(segmentModel, probabilityOutputField, false); + + features.add(new ContinuousFeature(encoder, probabilityField)); } } - Output segmentOutput = ModelUtil.ensureOutput(segmentModel); - - for(OutputField outputField : outputFields){ - segmentOutput.addOutputFields(outputField); + encoder.addTransformer(segmentModel); + } - features.add(new ContinuousFeature(encoder, outputField)); + return new Schema(label, features); + } - // XXX - encoder.createDataField(outputField.getName(), null); - } + @Override + public Model encodeModel(Schema schema){ + StackedEnsembleMojoModel model = getModel(); - models.add(segmentModel); - } + ConverterFactory converterFactory = ConverterFactory.newConverterFactory(); MojoModel metaLearner = getMetaLearner(model); - if(metaLearner != null){ - Converter converter = converterFactory.newConverter(metaLearner); - - Schema metaLearnerSchema = converter.toMojoModelSchema(new Schema(label, features)); + if(metaLearner == null){ + throw new IllegalArgumentException(); + } - Model segmentModel = converter.encodeModel(metaLearnerSchema); + Converter converter = converterFactory.newConverter(metaLearner); - models.add(segmentModel); - } + Schema metaLearnerSchema = converter.toMojoModelSchema(schema); - return MiningModelUtil.createModelChain(models); + return converter.encodeModel(metaLearnerSchema); } static