From 844097760f1e93905e1ad40a4fe241b5e5dddf5d Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Thu, 19 Nov 2020 10:04:05 +0200 Subject: [PATCH] Refactored the initial state of model arguments --- .../org/jpmml/translator/ArgumentsRef.java | 96 +++++++++++++++---- .../java/org/jpmml/translator/Encoder.java | 3 + .../jpmml/translator/FpPrimitiveEncoder.java | 24 ++++- .../org/jpmml/translator/OrdinalEncoder.java | 6 ++ 4 files changed, 109 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/jpmml/translator/ArgumentsRef.java b/src/main/java/org/jpmml/translator/ArgumentsRef.java index 285933d..df5e5be 100644 --- a/src/main/java/org/jpmml/translator/ArgumentsRef.java +++ b/src/main/java/org/jpmml/translator/ArgumentsRef.java @@ -72,7 +72,7 @@ public JMethod getMethod(FieldInfo fieldInfo, TranslationContext context){ try { context.pushOwner(argumentsClazz); - encoderMethod = createEncoderMethod(field, context); + encoderMethod = ArgumentsRef.createEncoderMethod(field, context); } finally { context.popOwner(); } @@ -84,30 +84,25 @@ public JMethod getMethod(FieldInfo fieldInfo, TranslationContext context){ JBlock block = method.body(); - JBlock initializerBlock; + JExpression valueExpr = JExpr.invoke(encoderMethod).arg(context.constantFieldName(name)); Integer count = fieldInfo.getCount(); if(count != null && count > 1){ - JFieldVar fieldFlagVar = argumentsClazz.field(JMod.PRIVATE, boolean.class, "_" + stringName, JExpr.FALSE); - JFieldVar fieldVar = argumentsClazz.field(JMod.PRIVATE, type, stringName); - - JBlock thenBlock = block._if(JExpr.refthis(fieldFlagVar.name()).not())._then(); + JExpression initExpr; - thenBlock.assign(JExpr.refthis(fieldFlagVar.name()), JExpr.TRUE); + if(encoder != null){ + initExpr = encoder.createInitExpression(field, context); + } else - initializerBlock = thenBlock; - } else - - { - initializerBlock = block; - } + { + initExpr = ArgumentsRef.createInitExpression(field, context); + } - JExpression valueExpr = JExpr.invoke(encoderMethod).arg(context.constantFieldName(name)); + JFieldVar fieldVar = argumentsClazz.field(JMod.PRIVATE, type, stringName, initExpr); - if(count != null && count > 1){ - JFieldVar fieldVar = (argumentsClazz.fields()).get(stringName); + JBlock thenBlock = block._if(JExpr.refthis(fieldVar.name()).eq(initExpr))._then(); - initializerBlock.assign(JExpr.refthis(fieldVar.name()), valueExpr); + thenBlock.assign(JExpr.refthis(fieldVar.name()), valueExpr); block._return(JExpr.refthis(fieldVar.name())); } else @@ -119,7 +114,8 @@ public JMethod getMethod(FieldInfo fieldInfo, TranslationContext context){ return method; } - public JMethod createEncoderMethod(Field field, TranslationContext context){ + static + private JMethod createEncoderMethod(Field field, TranslationContext context){ JDefinedClass owner = context.getOwner(); JType fieldNameClazz = context.ref(FieldName.class); @@ -177,4 +173,68 @@ public JMethod createEncoderMethod(Field field, TranslationContext context){ return method; } + + static + private JExpression createInitExpression(Field field, TranslationContext context){ + JDefinedClass owner = context.getOwner(); + + DataType dataType = field.getDataType(); + + String name; + + switch(dataType){ + case STRING: + name = "INIT_STRING_VALUE"; + break; + case INTEGER: + name = "INIT_INTEGER_VALUE"; + break; + case FLOAT: + name = "INIT_FLOAT_VALUE"; + break; + case DOUBLE: + name = "INIT_DOUBLE_VALUE"; + break; + case BOOLEAN: + name = "INIT_BOOLEAN_VALUE"; + break; + default: + throw new IllegalArgumentException(dataType.toString()); + } + + JFieldVar constantVar = (owner.fields()).get(name); + if(constantVar == null){ + JType type; + JExpression initExpr; + + switch(dataType){ + case STRING: + type = context.ref(String.class); + initExpr = JExpr._new(type); + break; + case INTEGER: + type = context.ref(Integer.class); + initExpr = JExpr._new(type).arg(JExpr.lit(-999)); + break; + case FLOAT: + type = context.ref(Float.class); + initExpr = JExpr._new(type).arg(JExpr.lit(-999f)); + break; + case DOUBLE: + type = context.ref(Double.class); + initExpr = JExpr._new(type).arg(JExpr.lit(-999d)); + break; + case BOOLEAN: + type = context.ref(Boolean.class); + initExpr = JExpr._new(type).arg(JExpr.lit(false)); + break; + default: + throw new IllegalArgumentException(dataType.toString()); + } + + constantVar = owner.field((JMod.PRIVATE | JMod.FINAL | JMod.STATIC), type, name, initExpr); + } + + return owner.staticRef(constantVar); + } } \ No newline at end of file diff --git a/src/main/java/org/jpmml/translator/Encoder.java b/src/main/java/org/jpmml/translator/Encoder.java index da98d77..ac0953b 100644 --- a/src/main/java/org/jpmml/translator/Encoder.java +++ b/src/main/java/org/jpmml/translator/Encoder.java @@ -18,6 +18,7 @@ */ package org.jpmml.translator; +import com.sun.codemodel.JExpression; import com.sun.codemodel.JMethod; import com.sun.codemodel.JVar; import org.dmg.pmml.Field; @@ -31,4 +32,6 @@ public interface Encoder { OperableRef ref(JVar variable); JMethod createEncoderMethod(Field field, TranslationContext context); + + JExpression createInitExpression(Field field, TranslationContext context); } \ No newline at end of file diff --git a/src/main/java/org/jpmml/translator/FpPrimitiveEncoder.java b/src/main/java/org/jpmml/translator/FpPrimitiveEncoder.java index 1ba2d47..7574c0a 100644 --- a/src/main/java/org/jpmml/translator/FpPrimitiveEncoder.java +++ b/src/main/java/org/jpmml/translator/FpPrimitiveEncoder.java @@ -104,10 +104,10 @@ public JMethod createEncoderMethod(Field field, TranslationContext context){ switch(dataType){ case FLOAT: - nanExpr = JExpr.lit(Float.NaN); + nanExpr = FpPrimitiveEncoder.NAN_VALUE_FLOAT; break; case DOUBLE: - nanExpr = JExpr.lit(Double.NaN); + nanExpr = FpPrimitiveEncoder.NAN_VALUE_DOUBLE; break; default: throw new IllegalArgumentException(dataType.toString()); @@ -120,4 +120,24 @@ public JMethod createEncoderMethod(Field field, TranslationContext context){ return method; } + + @Override + public JExpression createInitExpression(Field field, TranslationContext context){ + DataType dataType = field.getDataType(); + + switch(dataType){ + case FLOAT: + return FpPrimitiveEncoder.INIT_VALUE_FLOAT; + case DOUBLE: + return FpPrimitiveEncoder.INIT_VALUE_DOUBLE; + default: + throw new IllegalArgumentException(dataType.toString()); + } + } + + public static final JExpression INIT_VALUE_FLOAT = JExpr.lit(-999f); + public static final JExpression INIT_VALUE_DOUBLE = JExpr.lit(-999d); + + public static final JExpression NAN_VALUE_FLOAT = JExpr.lit(Float.NaN); + public static final JExpression NAN_VALUE_DOUBLE = JExpr.lit(Double.NaN); } \ No newline at end of file diff --git a/src/main/java/org/jpmml/translator/OrdinalEncoder.java b/src/main/java/org/jpmml/translator/OrdinalEncoder.java index 2d47a55..0fb8ad1 100644 --- a/src/main/java/org/jpmml/translator/OrdinalEncoder.java +++ b/src/main/java/org/jpmml/translator/OrdinalEncoder.java @@ -91,6 +91,11 @@ public JMethod createEncoderMethod(Field field, TranslationContext context){ return method; } + @Override + public JExpression createInitExpression(Field field, TranslationContext context){ + return OrdinalEncoder.INIT_VALUE; + } + public JMethod ensureIsSetMethod(TranslationContext context){ if(this.isSetMethod == null){ @@ -128,5 +133,6 @@ private JMethod getOrCreateIsSetMethod(TranslationContext context){ return isSetMethod; } + public static final JExpression INIT_VALUE = JExpr.lit(-999); public static final JExpression MISSING_VALUE = JExpr.lit(-1); } \ No newline at end of file