Skip to content

Commit

Permalink
Updated JPMML-Converter dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Feb 10, 2024
1 parent 6440d0d commit 70e2e86
Show file tree
Hide file tree
Showing 19 changed files with 118 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.sparkml.model.HasPredictionModelOptions;
Expand Down Expand Up @@ -86,7 +86,7 @@ public List<OutputField> registerOutputFields(Label label, Model pmmlModel, Spar

DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, keepPredictionCol);

MapValues mapValues = PMMLUtil.createMapValues(pmmlPredictedField.getName(), categoricalLabel.getValues(), categories)
MapValues mapValues = ExpressionUtil.createMapValues(pmmlPredictedField.getName(), categoricalLabel.getValues(), categories)
.setDataType(DataType.DOUBLE);

OutputField predictedOutputField = new OutputField(predictionCol, OpType.CONTINUOUS, DataType.DOUBLE)
Expand Down
107 changes: 46 additions & 61 deletions pmml-sparkml/src/main/java/org/jpmml/sparkml/ExpressionTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.IfElseBuilder;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.ExpressionCompactor;
Expand Down Expand Up @@ -182,7 +182,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
throw new IllegalArgumentException(formatMessage(binaryMathExpression));
}

return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
return org.jpmml.converter.ExpressionUtil.createApply(function, translateInternal(left), translateInternal(right));
} else

if(expression instanceof BinaryOperator){
Expand Down Expand Up @@ -258,7 +258,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
throw new IllegalArgumentException(formatMessage(binaryOperator));
}

return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
return org.jpmml.converter.ExpressionUtil.createApply(function, translateInternal(left), translateInternal(right));
} else

if(expression instanceof CaseWhen){
Expand All @@ -268,41 +268,26 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){

Option<Expression> elseValue = caseWhen.elseValue();

Apply apply = null;
IfElseBuilder applyBuilder = new IfElseBuilder();

Iterator<Tuple2<Expression, Expression>> branchIt = branches.iterator();

Apply prevBranchApply = null;

do {
Tuple2<Expression, Expression> branch = branchIt.next();

Expression predicate = branch._1();
Expression value = branch._2();

Apply branchApply = PMMLUtil.createApply(PMMLFunctions.IF,
translateInternal(predicate),
translateInternal(value)
);

if(apply == null){
apply = branchApply;
} // End if

if(prevBranchApply != null){
prevBranchApply.addExpressions(branchApply);
}

prevBranchApply = branchApply;
applyBuilder.add(translateInternal(predicate), translateInternal(value));
} while(branchIt.hasNext());

if(elseValue.isDefined()){
Expression value = elseValue.get();

prevBranchApply.addExpressions(translateInternal(value));
applyBuilder.terminate(translateInternal(value));
}

return apply;
return applyBuilder.build();
} else

if(expression instanceof Cast){
Expand Down Expand Up @@ -350,7 +335,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){

List<Expression> children = JavaConversions.seqAsJavaList(concat.children());

Apply apply = PMMLUtil.createApply(PMMLFunctions.CONCAT);
Apply apply = org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.CONCAT);

for(Expression child : children){
apply.addExpressions(translateInternal(child));
Expand All @@ -364,7 +349,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){

List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());

Apply apply = PMMLUtil.createApply(PMMLFunctions.MAX);
Apply apply = org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.MAX);

for(Expression child : children){
apply.addExpressions(translateInternal(child));
Expand All @@ -381,7 +366,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
Expression trueValue = _if.trueValue();
Expression falseValue = _if.falseValue();

return PMMLUtil.createApply(PMMLFunctions.IF,
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.IF,
translateInternal(predicate),
translateInternal(trueValue),
translateInternal(falseValue)
Expand All @@ -395,7 +380,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){

List<Expression> elements = JavaConversions.seqAsJavaList(in.list());

Apply apply = PMMLUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));
Apply apply = org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));

for(Expression element : elements){
apply.addExpressions(translateInternal(element));
Expand All @@ -409,7 +394,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){

List<Expression> children = JavaConversions.seqAsJavaList(least.children());

Apply apply = PMMLUtil.createApply(PMMLFunctions.MIN);
Apply apply = org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.MIN);

for(Expression child : children){
apply.addExpressions(translateInternal(child));
Expand All @@ -423,15 +408,15 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){

Expression child = length.child();

return PMMLUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
} else

if(expression instanceof Literal){
Literal literal = (Literal)expression;

Object value = literal.value();
if(value == null){
return PMMLUtil.createMissingConstant();
return org.jpmml.converter.ExpressionUtil.createMissingConstant();
}

DataType dataType;
Expand All @@ -451,7 +436,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
value = toSimpleObject(value);
}

return PMMLUtil.createConstant(value, dataType);
return org.jpmml.converter.ExpressionUtil.createConstant(dataType, value);
} else

if(expression instanceof RegExpReplace){
Expand All @@ -461,7 +446,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
Expression regexp = regexpReplace.regexp();
Expression rep = regexpReplace.rep();

return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
} else

if(expression instanceof RLike){
Expand All @@ -470,7 +455,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
Expression left = rlike.left();
Expression right = rlike.right();

return PMMLUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
} else

if(expression instanceof StringReplace){
Expand All @@ -480,7 +465,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
Expression searchExpr = stringReplace.searchExpr();
Expression replaceExpr = stringReplace.replaceExpr();

return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(srcExpr), transformString(translateInternal(searchExpr), ExpressionTranslator::escapeSearchString), transformString(translateInternal(replaceExpr), ExpressionTranslator::escapeReplacementString));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.REPLACE, translateInternal(srcExpr), transformString(translateInternal(searchExpr), ExpressionTranslator::escapeSearchString), transformString(translateInternal(replaceExpr), ExpressionTranslator::escapeReplacementString));
} else

if(expression instanceof StringTrim){
Expand All @@ -492,7 +477,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
throw new IllegalArgumentException();
}

return PMMLUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
} else

if(expression instanceof Substring){
Expand All @@ -512,7 +497,7 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
// XXX
lenValue = Math.min(lenValue, MAX_STRING_LENGTH);

return PMMLUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), PMMLUtil.createConstant(posValue), PMMLUtil.createConstant(lenValue));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), org.jpmml.converter.ExpressionUtil.createConstant(posValue), org.jpmml.converter.ExpressionUtil.createConstant(lenValue));
} else

if(expression instanceof UnaryExpression){
Expand All @@ -521,112 +506,112 @@ private org.dmg.pmml.Expression translateInternal(Expression expression){
Expression child = unaryExpression.child();

if(expression instanceof Abs){
return PMMLUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
} else

if(expression instanceof Acos){
return PMMLUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
} else

if(expression instanceof Asin){
return PMMLUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
} else

if(expression instanceof Atan){
return PMMLUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
} else

if(expression instanceof Ceil){
return PMMLUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
} else

if(expression instanceof Cos){
return PMMLUtil.createApply(PMMLFunctions.COS, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.COS, translateInternal(child));
} else

if(expression instanceof Cosh){
return PMMLUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
} else

if(expression instanceof Exp){
return PMMLUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
} else

if(expression instanceof Expm1){
return PMMLUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
} else

if(expression instanceof Floor){
return PMMLUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
} else

if(expression instanceof Log){
return PMMLUtil.createApply(PMMLFunctions.LN, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.LN, translateInternal(child));
} else

if(expression instanceof Log10){
return PMMLUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
} else

if(expression instanceof Log1p){
return PMMLUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
} else

if(expression instanceof Lower){
return PMMLUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
} else

if(expression instanceof IsNaN){
// XXX
return PMMLUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
} else

if(expression instanceof IsNotNull){
return PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
} else

if(expression instanceof IsNull){
return PMMLUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
} else

if(expression instanceof Not){
return PMMLUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
} else

if(expression instanceof Rint){
return PMMLUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
} else

if(expression instanceof Sin){
return PMMLUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
} else

if(expression instanceof Sinh){
return PMMLUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
} else

if(expression instanceof Sqrt){
return PMMLUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
} else

if(expression instanceof Tan){
return PMMLUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
} else

if(expression instanceof Tanh){
return PMMLUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
} else

if(expression instanceof UnaryMinus){
return PMMLUtil.toNegative(translateInternal(child));
return org.jpmml.converter.ExpressionUtil.toNegative(translateInternal(child));
} else

if(expression instanceof UnaryPositive){
return translateInternal(child);
} else

if(expression instanceof Upper){
return PMMLUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
return org.jpmml.converter.ExpressionUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
} else

{
Expand Down
8 changes: 4 additions & 4 deletions pmml-sparkml/src/main/java/org/jpmml/sparkml/TermFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.ParameterField;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.model.ToStringHelper;

public class TermFeature extends Feature {
Expand Down Expand Up @@ -74,7 +74,7 @@ public WeightedTermFeature toWeightedTermFeature(Number weight){
List<ParameterField> weightedParameterFields = new ArrayList<>(defineFunction.getParameterFields());
weightedParameterFields.add(weightField);

Apply apply = PMMLUtil.createApply(PMMLFunctions.MULTIPLY, defineFunction.requireExpression(), new FieldRef(weightField));
Apply apply = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, defineFunction.requireExpression(), new FieldRef(weightField));

weightedDefineFunction = new DefineFunction(name, OpType.CONTINUOUS, DataType.DOUBLE, weightedParameterFields, apply);

Expand All @@ -89,9 +89,9 @@ public Apply createApply(){
Feature feature = getFeature();
String value = getValue();

Constant constant = PMMLUtil.createConstant(value, DataType.STRING);
Constant constant = ExpressionUtil.createConstant(DataType.STRING, value);

return PMMLUtil.createApply(defineFunction, feature.ref(), constant);
return ExpressionUtil.createApply(defineFunction, feature.ref(), constant);
}

@Override
Expand Down
Loading

0 comments on commit 70e2e86

Please sign in to comment.