Skip to content

Commit

Permalink
Fixed the translation of regressors in model chain models
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 28, 2024
1 parent 8b48ff3 commit 732927c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ public <V extends Number> ValueFactory<V> getValueFactory(Model model){
}

static
private void translateRegressorTarget(Target target, ValueBuilder valueBuilder){
public void translateRegressorTarget(Target target, ValueBuilder valueBuilder){
Number rescaleFactor = target.getRescaleFactor();

if(rescaleFactor != null && rescaleFactor.doubleValue() != 1d){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
Expand All @@ -40,6 +41,7 @@
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.MissingElementException;
import org.jpmml.model.UnsupportedAttributeException;
Expand All @@ -50,6 +52,7 @@
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.PMMLObjectUtil;
import org.jpmml.translator.TranslationContext;
import org.jpmml.translator.ValueBuilder;
import org.jpmml.translator.ValueFactoryRef;
import org.jpmml.translator.ValueMapBuilder;
import org.jpmml.translator.regression.RegressionModelTranslator;
Expand Down Expand Up @@ -209,11 +212,19 @@ private void translateSegmentation(Segmentation segmentation, TranslationContext

ModelTranslator<?> modelTranslator = newModelTranslator(model);

TargetField targetField = modelTranslator.getTargetField();

JMethod evaluateMethod = modelTranslator.translateRegressor(context);

JInvocation methodInvocation = createEvaluatorMethodInvocation(evaluateMethod, context);

context.declare(context.getValueType(), IdentifierUtil.create("value", outputField.requireName()), methodInvocation);
ValueBuilder valueBuilder = new ValueBuilder(context)
.declare(IdentifierUtil.create("value", outputField.requireName()), methodInvocation);

Target target = targetField.getTarget();
if(target != null){
translateRegressorTarget(target, valueBuilder);
}

pullUpDerivedFields(miningModel, model);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
import org.jpmml.translator.FieldInfoMap;
import org.jpmml.translator.JCompoundAssignment;
import org.jpmml.translator.JExprUtil;
import org.jpmml.translator.JVarBuilder;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.TranslationContext;
Expand Down Expand Up @@ -239,33 +238,12 @@ private JStatement createCompoundAssignment(JVar resultVar, Number value){

TreeModelTranslator.translateNode(treeModel, root, scorer, fieldInfos, context);

JVarBuilder valueBuilder = new ValueBuilder(context)
ValueBuilder valueBuilder = new ValueBuilder(context)
.declare("resultValue", context.getValueFactoryVariable().newValue(resultVar));

Number intercept = target.getRescaleConstant();
translateRegressorTarget(target, valueBuilder);

switch(mathContext){
case FLOAT:
{
double floatAsDoubleValue = floatAsDouble(intercept);

valueBuilder.update("add", JExprUtil.directNoPara(toFloatString(floatAsDoubleValue) + "D"));
}
break;
case DOUBLE:
{
double doubleValue = intercept.doubleValue();

valueBuilder.update("add", JExpr.lit(doubleValue));
}
break;
default:
throw new UnsupportedAttributeException(miningModel, mathContext);
}

JVar resultValueVar = valueBuilder.getVariable();

context._return(resultValueVar);
context._return(valueBuilder.getVariable());
} finally {
context.popScope();
}
Expand Down

0 comments on commit 732927c

Please sign in to comment.