diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/MiningModelTranslator.java b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/MiningModelTranslator.java index c79f7da..e93b762 100644 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/MiningModelTranslator.java +++ b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/MiningModelTranslator.java @@ -27,6 +27,7 @@ import org.dmg.pmml.Model; import org.dmg.pmml.Output; import org.dmg.pmml.PMML; +import org.dmg.pmml.Target; import org.dmg.pmml.Targets; import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segment; @@ -66,6 +67,21 @@ public ModelTranslator newModelTranslator(Model model){ return modelTranslatorFactory.newModelTranslator(pmml, model); } + static + public Number extractIntercept(Target target){ + Number rescaleFactor = target.getRescaleFactor(); + Number rescaleConstant = target.getRescaleConstant(); + + if(rescaleFactor.doubleValue() == 1d && rescaleConstant.doubleValue() != 0d){ + // XXX + target.setRescaleConstant(null); + + return rescaleConstant; + } + + return null; + } + static public void checkMiningSchema(Model model){ MiningSchema miningSchema = model.requireMiningSchema(); diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelAggregatorTranslator.java b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelAggregatorTranslator.java index 2f8491c..f5e6487 100644 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelAggregatorTranslator.java +++ b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelAggregatorTranslator.java @@ -40,6 +40,7 @@ import org.dmg.pmml.MiningFunction; import org.dmg.pmml.PMML; import org.dmg.pmml.PMMLObject; +import org.dmg.pmml.Target; import org.dmg.pmml.True; import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segment; @@ -49,6 +50,7 @@ import org.jpmml.evaluator.Classification; import org.jpmml.evaluator.ProbabilityAggregator; import org.jpmml.evaluator.ProbabilityDistribution; +import org.jpmml.evaluator.TargetField; import org.jpmml.evaluator.Value; import org.jpmml.evaluator.ValueAggregator; import org.jpmml.evaluator.ValueFactory; @@ -59,14 +61,12 @@ import org.jpmml.translator.IdentifierUtil; import org.jpmml.translator.JBinaryFileInitializer; import org.jpmml.translator.JDirectInitializer; -import org.jpmml.translator.JVarBuilder; import org.jpmml.translator.MethodScope; import org.jpmml.translator.ModelTranslator; import org.jpmml.translator.Modifiers; import org.jpmml.translator.PMMLObjectUtil; import org.jpmml.translator.Scope; import org.jpmml.translator.TranslationContext; -import org.jpmml.translator.ValueBuilder; import org.jpmml.translator.ValueFactoryRef; import org.jpmml.translator.tree.NodeScoreDistributionManager; import org.jpmml.translator.tree.NodeScoreManager; @@ -295,6 +295,25 @@ private void translateValueAggregatorSegmentation(Segmentation segmentation, Tra JFieldVar methodsVar = codeInitializer.initLambdas(IdentifierUtil.create("methods", segmentation), modelFuncInterface, methods); + switch(multipleModelMethod){ + case SUM: + { + TargetField targetField = getTargetField(); + + Target target = targetField.getTarget(); + if(target != null){ + Number intercept = extractIntercept(target); + + if(intercept != null){ + aggregatorBuilder.update("add", intercept); + } + } + } + break; + default: + break; + } + JBlock block = context.block(); try { @@ -361,10 +380,7 @@ private void translateValueAggregatorSegmentation(Segmentation segmentation, Tra throw new UnsupportedAttributeException(segmentation, multipleModelMethod); } - JVarBuilder resultBuilder = new ValueBuilder(context) - .declare(context.getValueType(), "result", valueInit); - - context._return(resultBuilder.getVariable()); + context._return(valueInit); } private void translateProbabilityAggregatorSegmentation(Segmentation segmentation, TranslationContext context){ diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java index c24deb3..3c23896 100644 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java +++ b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java @@ -158,6 +158,11 @@ public JMethod translateRegressor(TranslationContext context){ Target target = Iterables.getOnlyElement(targets); + Number intercept = extractIntercept(target); + if(intercept == null){ + intercept = 0; + } + ModelTranslator modelTranslator = new TreeModelTranslator(pmml, treeModel); Node root = treeModel.getNode(); @@ -173,10 +178,10 @@ public JMethod translateRegressor(TranslationContext context){ switch(mathContext){ case FLOAT: - resultVar = context.declare(float.class, "result", JExpr.lit(0f)); + resultVar = context.declare(float.class, "result", JExpr.lit(intercept.floatValue())); break; case DOUBLE: - resultVar = context.declare(double.class, "result", JExpr.lit(0d)); + resultVar = context.declare(double.class, "result", JExpr.lit(intercept.doubleValue())); break; default: throw new UnsupportedAttributeException(miningModel, mathContext); diff --git a/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/ClassificationTest.java b/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/ClassificationTest.java index f81103c..ac05fd9 100644 --- a/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/ClassificationTest.java +++ b/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/ClassificationTest.java @@ -65,12 +65,12 @@ public void evaluateSelectFirstAudit() throws Exception { @Test public void evaluateXGBoostAudit() throws Exception { - evaluate(XGBOOST, AUDIT, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(32 + 48)); + evaluate(XGBOOST, AUDIT, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(8 + 4)); } @Test public void evaluateXGBoostAuditNA() throws Exception { - evaluate(XGBOOST, AUDIT_NA, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(32 + 48)); + evaluate(XGBOOST, AUDIT_NA, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(8 + 4)); } @Test @@ -95,7 +95,7 @@ public void evaluateRandomForestSentiment() throws Exception { @Test public void evaluateXGBoostSentiment() throws Exception { - evaluate(XGBOOST, SENTIMENT, excludeFields(SENTIMENT_PROBABILITY_FALSE), new FloatEquivalence(24)); + evaluate(XGBOOST, SENTIMENT, excludeFields(SENTIMENT_PROBABILITY_FALSE), new FloatEquivalence(8 + 4)); } @Test @@ -125,6 +125,6 @@ public void evaluateRandomForestIris() throws Exception { @Test public void evaluateXGBoostIris() throws Exception { - evaluate(XGBOOST, IRIS, new FloatEquivalence(10)); + evaluate(XGBOOST, IRIS, new FloatEquivalence(8 + 2)); } } \ No newline at end of file diff --git a/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/RegressionTest.java b/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/RegressionTest.java index 51693d6..0a03050 100644 --- a/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/RegressionTest.java +++ b/pmml-transpiler/src/test/java/org/jpmml/transpiler/testing/RegressionTest.java @@ -80,11 +80,11 @@ public void evaluateVotingEnsembleAuto() throws Exception { @Test public void evaluateXGBoostAuto() throws Exception { - evaluate(XGBOOST, AUTO, new FloatEquivalence(8)); + evaluate(XGBOOST, AUTO, new FloatEquivalence(2)); } @Test public void evaluateXGBoostAutoNA() throws Exception { - evaluate(XGBOOST, AUTO_NA, new FloatEquivalence(8 + 4)); + evaluate(XGBOOST, AUTO_NA, new FloatEquivalence(2)); } } \ No newline at end of file