From 49ce1fcb125c8229355ecd1401a7f8bead4c29f4 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 29 Sep 2024 12:11:53 +0300 Subject: [PATCH] Refactored the translation of regression-type tree ensemble models --- .../java/org/jpmml/translator/JExprUtil.java | 40 -------- .../org/jpmml/translator/ModelTranslator.java | 7 +- .../mining/ModelChainTranslator.java | 2 +- .../mining/TreeModelBoosterTranslator.java | 95 ++----------------- 4 files changed, 14 insertions(+), 130 deletions(-) delete mode 100644 pmml-transpiler/src/main/java/org/jpmml/translator/JExprUtil.java diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/JExprUtil.java b/pmml-transpiler/src/main/java/org/jpmml/translator/JExprUtil.java deleted file mode 100644 index cdb7d1e..0000000 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/JExprUtil.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021 Villu Ruusmann - * - * This file is part of JPMML-Transpiler - * - * JPMML-Transpiler is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * JPMML-Transpiler is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with JPMML-Transpiler. If not, see . - */ -package org.jpmml.translator; - -import com.sun.codemodel.JExpression; -import com.sun.codemodel.JExpressionImpl; -import com.sun.codemodel.JFormatter; - -public class JExprUtil { - - private JExprUtil(){ - } - - static - public JExpression directNoPara(String string){ - return new JExpressionImpl(){ - - @Override - public void generate(JFormatter formatter){ - formatter.p(string); - } - }; - } -} \ No newline at end of file diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/ModelTranslator.java b/pmml-transpiler/src/main/java/org/jpmml/translator/ModelTranslator.java index be891fa..ddd9427 100644 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/ModelTranslator.java +++ b/pmml-transpiler/src/main/java/org/jpmml/translator/ModelTranslator.java @@ -176,7 +176,7 @@ public JMethod createEvaluateRegressionMethod(JMethod evaluateMethod, Translatio Target target = targetField.getTarget(); if(target != null){ - translateRegressorTarget(target, valueBuilder); + translateRegressorTarget(model, target, valueBuilder); // XXX model.setTargets(null); @@ -361,9 +361,10 @@ public ValueFactory getValueFactory(Model model){ } static - public void translateRegressorTarget(Target target, ValueBuilder valueBuilder){ - Number rescaleFactor = target.getRescaleFactor(); + public void translateRegressorTarget(Model model, Target target, ValueBuilder valueBuilder){ + MathContext mathContext = model.getMathContext(); + Number rescaleFactor = target.getRescaleFactor(); if(rescaleFactor != null && rescaleFactor.doubleValue() != 1d){ valueBuilder.update("multiply", rescaleFactor); } diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/ModelChainTranslator.java b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/ModelChainTranslator.java index 1f76a30..4194739 100644 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/ModelChainTranslator.java +++ b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/ModelChainTranslator.java @@ -223,7 +223,7 @@ private void translateSegmentation(Segmentation segmentation, TranslationContext Target target = targetField.getTarget(); if(target != null){ - translateRegressorTarget(target, valueBuilder); + translateRegressorTarget(model, target, valueBuilder); } pullUpDerivedFields(miningModel, model); diff --git a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java index 3ba07b5..c24deb3 100644 --- a/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java +++ b/pmml-transpiler/src/main/java/org/jpmml/translator/mining/TreeModelBoosterTranslator.java @@ -18,9 +18,6 @@ */ package org.jpmml.translator.mining; -import java.math.RoundingMode; -import java.text.DecimalFormat; -import java.text.DecimalFormatSymbols; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -56,6 +53,7 @@ import org.dmg.pmml.tree.ComplexNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; +import org.jpmml.converter.ValueUtil; import org.jpmml.evaluator.Value; import org.jpmml.model.PMMLObjectKey; import org.jpmml.model.UnsupportedAttributeException; @@ -64,7 +62,6 @@ import org.jpmml.model.visitors.NodeFilterer; import org.jpmml.translator.FieldInfoMap; import org.jpmml.translator.JCompoundAssignment; -import org.jpmml.translator.JExprUtil; import org.jpmml.translator.MethodScope; import org.jpmml.translator.ModelTranslator; import org.jpmml.translator.TranslationContext; @@ -176,8 +173,7 @@ public JMethod translateRegressor(TranslationContext context){ switch(mathContext){ case FLOAT: - // Use a double accumulator (instead of a float one) for improved numerical stability - resultVar = context.declare(double.class, "result", JExprUtil.directNoPara(toFloatString(floatAsDouble(0f)) + "D")); + resultVar = context.declare(float.class, "result", JExpr.lit(0f)); break; case DOUBLE: resultVar = context.declare(double.class, "result", JExpr.lit(0d)); @@ -220,9 +216,9 @@ private JStatement createCompoundAssignment(JVar resultVar, Number value){ switch(mathContext){ case FLOAT: { - double floatAsDoubleValue = floatAsDouble(value); + float floatValue = value.floatValue(); - return new JCompoundAssignment(resultVar, JExprUtil.directNoPara(toFloatString(Math.abs(floatAsDoubleValue)) + "D"), floatAsDoubleValue >= 0d ? "+=" : "-="); + return new JCompoundAssignment(resultVar, JExpr.lit(Math.abs(floatValue)), floatValue >= 0f ? "+=" : "-="); } case DOUBLE: { @@ -241,7 +237,7 @@ private JStatement createCompoundAssignment(JVar resultVar, Number value){ ValueBuilder valueBuilder = new ValueBuilder(context) .declare("resultValue", context.getValueFactoryVariable().newValue(resultVar)); - translateRegressorTarget(target, valueBuilder); + translateRegressorTarget(treeModel, target, valueBuilder); context._return(valueBuilder.getVariable()); } finally { @@ -288,7 +284,7 @@ private TreeModel transformSegmentation(MiningModel miningModel){ switch(mathContext){ case FLOAT: - zero = floatAsDouble(0f); + zero = 0f; break; case DOUBLE: zero = 0d; @@ -350,7 +346,7 @@ public VisitorAction visit(TreeModel treeModel){ Number score = (Number)root.requireScore(); - target.setRescaleConstant(add(mathContext, target.getRescaleConstant(), score)); + target.setRescaleConstant(ValueUtil.add(mathContext, target.getRescaleConstant(), score)); updateNodeScores(root, score); @@ -360,7 +356,7 @@ public VisitorAction visit(TreeModel treeModel){ private void updateNodeScores(Node node, Number adjustment){ Number score = (Number)node.requireScore(); - node.setScore(subtract(mathContext, score, adjustment)); + node.setScore(ValueUtil.subtract(mathContext, score, adjustment)); if(node.hasNodes()){ List children = node.getNodes(); @@ -544,7 +540,7 @@ private void merge(List leftNodes, List rightNodes){ throw new IllegalArgumentException(); } - leftNode.setScore(add(mathContext, (Number)leftNode.requireScore(), (Number)rightNode.requireScore())); + leftNode.setScore(ValueUtil.add(mathContext, (Number)leftNode.requireScore(), (Number)rightNode.requireScore())); } } }; @@ -553,78 +549,5 @@ private void merge(List leftNodes, List rightNodes){ return treeModel; } - static - private Number add(MathContext mathContext, Number left, Number right){ - - switch(mathContext){ - case FLOAT: - return (floatAsDouble(left) + floatAsDouble(right)); - case DOUBLE: - return (left.doubleValue() + right.doubleValue()); - default: - throw new IllegalArgumentException(); - } - } - - static - private Number subtract(MathContext mathContext, Number left, Number right){ - - switch(mathContext){ - case FLOAT: - return (floatAsDouble(left) - floatAsDouble(right)); - case DOUBLE: - return (left.doubleValue() - right.doubleValue()); - default: - throw new IllegalArgumentException(); - } - } - - static - private double floatAsDouble(Number value){ - - if(value instanceof Float){ - Float floatValue = (Float)value; - - return Double.parseDouble(floatValue.toString()); - } else - - if(value instanceof Double){ - Double doubleValue = (Double)value; - - return doubleValue.doubleValue(); - } else - - { - throw new IllegalArgumentException(); - } - } - - static - private String toFloatString(Number value){ - DecimalFormat formatter = TreeModelBoosterTranslator.FORMAT_DECIMAL32; - - synchronized(formatter){ - return formatter.format(value); - } - } - - /** - * @see java.math.MathContext#DECIMAL32 - */ - private static final DecimalFormat FORMAT_DECIMAL32; - - static { - // Add one extra decimal place - String pattern = "0.#######" + "E0"; - - DecimalFormatSymbols symbols = new DecimalFormatSymbols(); - symbols.setDecimalSeparator('.'); - symbols.setMinusSign('-'); - symbols.setExponentSeparator("E"); - - FORMAT_DECIMAL32 = new DecimalFormat(pattern, symbols); - FORMAT_DECIMAL32.setRoundingMode(RoundingMode.HALF_EVEN); - } - public static final int NODE_COUNT_LIMIT = Integer.getInteger(TreeModelBoosterTranslator.class.getName() + "#" + "NODE_COUNT_LIMIT", 1000); } \ No newline at end of file