Skip to content

Commit

Permalink
Refactored the translation of regression-type tree ensemble models
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 29, 2024
1 parent 5bda4e9 commit 49ce1fc
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 130 deletions.
40 changes: 0 additions & 40 deletions pmml-transpiler/src/main/java/org/jpmml/translator/JExprUtil.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public JMethod createEvaluateRegressionMethod(JMethod evaluateMethod, Translatio

Target target = targetField.getTarget();
if(target != null){
translateRegressorTarget(target, valueBuilder);
translateRegressorTarget(model, target, valueBuilder);

// XXX
model.setTargets(null);
Expand Down Expand Up @@ -361,9 +361,10 @@ public <V extends Number> ValueFactory<V> getValueFactory(Model model){
}

static
public void translateRegressorTarget(Target target, ValueBuilder valueBuilder){
Number rescaleFactor = target.getRescaleFactor();
public void translateRegressorTarget(Model model, Target target, ValueBuilder valueBuilder){
MathContext mathContext = model.getMathContext();

Number rescaleFactor = target.getRescaleFactor();
if(rescaleFactor != null && rescaleFactor.doubleValue() != 1d){
valueBuilder.update("multiply", rescaleFactor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ private void translateSegmentation(Segmentation segmentation, TranslationContext

Target target = targetField.getTarget();
if(target != null){
translateRegressorTarget(target, valueBuilder);
translateRegressorTarget(model, target, valueBuilder);
}

pullUpDerivedFields(miningModel, model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
*/
package org.jpmml.translator.mining;

import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -56,6 +53,7 @@
import org.dmg.pmml.tree.ComplexNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ValueUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.model.PMMLObjectKey;
import org.jpmml.model.UnsupportedAttributeException;
Expand All @@ -64,7 +62,6 @@
import org.jpmml.model.visitors.NodeFilterer;
import org.jpmml.translator.FieldInfoMap;
import org.jpmml.translator.JCompoundAssignment;
import org.jpmml.translator.JExprUtil;
import org.jpmml.translator.MethodScope;
import org.jpmml.translator.ModelTranslator;
import org.jpmml.translator.TranslationContext;
Expand Down Expand Up @@ -176,8 +173,7 @@ public JMethod translateRegressor(TranslationContext context){

switch(mathContext){
case FLOAT:
// Use a double accumulator (instead of a float one) for improved numerical stability
resultVar = context.declare(double.class, "result", JExprUtil.directNoPara(toFloatString(floatAsDouble(0f)) + "D"));
resultVar = context.declare(float.class, "result", JExpr.lit(0f));
break;
case DOUBLE:
resultVar = context.declare(double.class, "result", JExpr.lit(0d));
Expand Down Expand Up @@ -220,9 +216,9 @@ private JStatement createCompoundAssignment(JVar resultVar, Number value){
switch(mathContext){
case FLOAT:
{
double floatAsDoubleValue = floatAsDouble(value);
float floatValue = value.floatValue();

return new JCompoundAssignment(resultVar, JExprUtil.directNoPara(toFloatString(Math.abs(floatAsDoubleValue)) + "D"), floatAsDoubleValue >= 0d ? "+=" : "-=");
return new JCompoundAssignment(resultVar, JExpr.lit(Math.abs(floatValue)), floatValue >= 0f ? "+=" : "-=");
}
case DOUBLE:
{
Expand All @@ -241,7 +237,7 @@ private JStatement createCompoundAssignment(JVar resultVar, Number value){
ValueBuilder valueBuilder = new ValueBuilder(context)
.declare("resultValue", context.getValueFactoryVariable().newValue(resultVar));

translateRegressorTarget(target, valueBuilder);
translateRegressorTarget(treeModel, target, valueBuilder);

context._return(valueBuilder.getVariable());
} finally {
Expand Down Expand Up @@ -288,7 +284,7 @@ private TreeModel transformSegmentation(MiningModel miningModel){

switch(mathContext){
case FLOAT:
zero = floatAsDouble(0f);
zero = 0f;
break;
case DOUBLE:
zero = 0d;
Expand Down Expand Up @@ -350,7 +346,7 @@ public VisitorAction visit(TreeModel treeModel){

Number score = (Number)root.requireScore();

target.setRescaleConstant(add(mathContext, target.getRescaleConstant(), score));
target.setRescaleConstant(ValueUtil.add(mathContext, target.getRescaleConstant(), score));

updateNodeScores(root, score);

Expand All @@ -360,7 +356,7 @@ public VisitorAction visit(TreeModel treeModel){
private void updateNodeScores(Node node, Number adjustment){
Number score = (Number)node.requireScore();

node.setScore(subtract(mathContext, score, adjustment));
node.setScore(ValueUtil.subtract(mathContext, score, adjustment));

if(node.hasNodes()){
List<Node> children = node.getNodes();
Expand Down Expand Up @@ -544,7 +540,7 @@ private void merge(List<Node> leftNodes, List<Node> rightNodes){
throw new IllegalArgumentException();
}

leftNode.setScore(add(mathContext, (Number)leftNode.requireScore(), (Number)rightNode.requireScore()));
leftNode.setScore(ValueUtil.add(mathContext, (Number)leftNode.requireScore(), (Number)rightNode.requireScore()));
}
}
};
Expand All @@ -553,78 +549,5 @@ private void merge(List<Node> leftNodes, List<Node> rightNodes){
return treeModel;
}

static
private Number add(MathContext mathContext, Number left, Number right){

switch(mathContext){
case FLOAT:
return (floatAsDouble(left) + floatAsDouble(right));
case DOUBLE:
return (left.doubleValue() + right.doubleValue());
default:
throw new IllegalArgumentException();
}
}

static
private Number subtract(MathContext mathContext, Number left, Number right){

switch(mathContext){
case FLOAT:
return (floatAsDouble(left) - floatAsDouble(right));
case DOUBLE:
return (left.doubleValue() - right.doubleValue());
default:
throw new IllegalArgumentException();
}
}

static
private double floatAsDouble(Number value){

if(value instanceof Float){
Float floatValue = (Float)value;

return Double.parseDouble(floatValue.toString());
} else

if(value instanceof Double){
Double doubleValue = (Double)value;

return doubleValue.doubleValue();
} else

{
throw new IllegalArgumentException();
}
}

static
private String toFloatString(Number value){
DecimalFormat formatter = TreeModelBoosterTranslator.FORMAT_DECIMAL32;

synchronized(formatter){
return formatter.format(value);
}
}

/**
* @see java.math.MathContext#DECIMAL32
*/
private static final DecimalFormat FORMAT_DECIMAL32;

static {
// Add one extra decimal place
String pattern = "0.#######" + "E0";

DecimalFormatSymbols symbols = new DecimalFormatSymbols();
symbols.setDecimalSeparator('.');
symbols.setMinusSign('-');
symbols.setExponentSeparator("E");

FORMAT_DECIMAL32 = new DecimalFormat(pattern, symbols);
FORMAT_DECIMAL32.setRoundingMode(RoundingMode.HALF_EVEN);
}

public static final int NODE_COUNT_LIMIT = Integer.getInteger(TreeModelBoosterTranslator.class.getName() + "#" + "NODE_COUNT_LIMIT", 1000);
}

0 comments on commit 49ce1fc

Please sign in to comment.