Skip to content

Commit

Permalink
Refactored the initial state of model arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Nov 19, 2020
1 parent 7a7499c commit 8440977
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 20 deletions.
96 changes: 78 additions & 18 deletions src/main/java/org/jpmml/translator/ArgumentsRef.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public JMethod getMethod(FieldInfo fieldInfo, TranslationContext context){
try {
context.pushOwner(argumentsClazz);

encoderMethod = createEncoderMethod(field, context);
encoderMethod = ArgumentsRef.createEncoderMethod(field, context);
} finally {
context.popOwner();
}
Expand All @@ -84,30 +84,25 @@ public JMethod getMethod(FieldInfo fieldInfo, TranslationContext context){

JBlock block = method.body();

JBlock initializerBlock;
JExpression valueExpr = JExpr.invoke(encoderMethod).arg(context.constantFieldName(name));

Integer count = fieldInfo.getCount();
if(count != null && count > 1){
JFieldVar fieldFlagVar = argumentsClazz.field(JMod.PRIVATE, boolean.class, "_" + stringName, JExpr.FALSE);
JFieldVar fieldVar = argumentsClazz.field(JMod.PRIVATE, type, stringName);

JBlock thenBlock = block._if(JExpr.refthis(fieldFlagVar.name()).not())._then();
JExpression initExpr;

thenBlock.assign(JExpr.refthis(fieldFlagVar.name()), JExpr.TRUE);
if(encoder != null){
initExpr = encoder.createInitExpression(field, context);
} else

initializerBlock = thenBlock;
} else

{
initializerBlock = block;
}
{
initExpr = ArgumentsRef.createInitExpression(field, context);
}

JExpression valueExpr = JExpr.invoke(encoderMethod).arg(context.constantFieldName(name));
JFieldVar fieldVar = argumentsClazz.field(JMod.PRIVATE, type, stringName, initExpr);

if(count != null && count > 1){
JFieldVar fieldVar = (argumentsClazz.fields()).get(stringName);
JBlock thenBlock = block._if(JExpr.refthis(fieldVar.name()).eq(initExpr))._then();

initializerBlock.assign(JExpr.refthis(fieldVar.name()), valueExpr);
thenBlock.assign(JExpr.refthis(fieldVar.name()), valueExpr);

block._return(JExpr.refthis(fieldVar.name()));
} else
Expand All @@ -119,7 +114,8 @@ public JMethod getMethod(FieldInfo fieldInfo, TranslationContext context){
return method;
}

public JMethod createEncoderMethod(Field<?> field, TranslationContext context){
static
private JMethod createEncoderMethod(Field<?> field, TranslationContext context){
JDefinedClass owner = context.getOwner();

JType fieldNameClazz = context.ref(FieldName.class);
Expand Down Expand Up @@ -177,4 +173,68 @@ public JMethod createEncoderMethod(Field<?> field, TranslationContext context){

return method;
}

static
private JExpression createInitExpression(Field<?> field, TranslationContext context){
JDefinedClass owner = context.getOwner();

DataType dataType = field.getDataType();

String name;

switch(dataType){
case STRING:
name = "INIT_STRING_VALUE";
break;
case INTEGER:
name = "INIT_INTEGER_VALUE";
break;
case FLOAT:
name = "INIT_FLOAT_VALUE";
break;
case DOUBLE:
name = "INIT_DOUBLE_VALUE";
break;
case BOOLEAN:
name = "INIT_BOOLEAN_VALUE";
break;
default:
throw new IllegalArgumentException(dataType.toString());
}

JFieldVar constantVar = (owner.fields()).get(name);
if(constantVar == null){
JType type;
JExpression initExpr;

switch(dataType){
case STRING:
type = context.ref(String.class);
initExpr = JExpr._new(type);
break;
case INTEGER:
type = context.ref(Integer.class);
initExpr = JExpr._new(type).arg(JExpr.lit(-999));
break;
case FLOAT:
type = context.ref(Float.class);
initExpr = JExpr._new(type).arg(JExpr.lit(-999f));
break;
case DOUBLE:
type = context.ref(Double.class);
initExpr = JExpr._new(type).arg(JExpr.lit(-999d));
break;
case BOOLEAN:
type = context.ref(Boolean.class);
initExpr = JExpr._new(type).arg(JExpr.lit(false));
break;
default:
throw new IllegalArgumentException(dataType.toString());
}

constantVar = owner.field((JMod.PRIVATE | JMod.FINAL | JMod.STATIC), type, name, initExpr);
}

return owner.staticRef(constantVar);
}
}
3 changes: 3 additions & 0 deletions src/main/java/org/jpmml/translator/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.jpmml.translator;

import com.sun.codemodel.JExpression;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JVar;
import org.dmg.pmml.Field;
Expand All @@ -31,4 +32,6 @@ public interface Encoder {
OperableRef ref(JVar variable);

JMethod createEncoderMethod(Field<?> field, TranslationContext context);

JExpression createInitExpression(Field<?> field, TranslationContext context);
}
24 changes: 22 additions & 2 deletions src/main/java/org/jpmml/translator/FpPrimitiveEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ public JMethod createEncoderMethod(Field<?> field, TranslationContext context){

switch(dataType){
case FLOAT:
nanExpr = JExpr.lit(Float.NaN);
nanExpr = FpPrimitiveEncoder.NAN_VALUE_FLOAT;
break;
case DOUBLE:
nanExpr = JExpr.lit(Double.NaN);
nanExpr = FpPrimitiveEncoder.NAN_VALUE_DOUBLE;
break;
default:
throw new IllegalArgumentException(dataType.toString());
Expand All @@ -120,4 +120,24 @@ public JMethod createEncoderMethod(Field<?> field, TranslationContext context){

return method;
}

@Override
public JExpression createInitExpression(Field<?> field, TranslationContext context){
DataType dataType = field.getDataType();

switch(dataType){
case FLOAT:
return FpPrimitiveEncoder.INIT_VALUE_FLOAT;
case DOUBLE:
return FpPrimitiveEncoder.INIT_VALUE_DOUBLE;
default:
throw new IllegalArgumentException(dataType.toString());
}
}

public static final JExpression INIT_VALUE_FLOAT = JExpr.lit(-999f);
public static final JExpression INIT_VALUE_DOUBLE = JExpr.lit(-999d);

public static final JExpression NAN_VALUE_FLOAT = JExpr.lit(Float.NaN);
public static final JExpression NAN_VALUE_DOUBLE = JExpr.lit(Double.NaN);
}
6 changes: 6 additions & 0 deletions src/main/java/org/jpmml/translator/OrdinalEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ public JMethod createEncoderMethod(Field<?> field, TranslationContext context){
return method;
}

@Override
public JExpression createInitExpression(Field<?> field, TranslationContext context){
return OrdinalEncoder.INIT_VALUE;
}

public JMethod ensureIsSetMethod(TranslationContext context){

if(this.isSetMethod == null){
Expand Down Expand Up @@ -128,5 +133,6 @@ private JMethod getOrCreateIsSetMethod(TranslationContext context){
return isSetMethod;
}

public static final JExpression INIT_VALUE = JExpr.lit(-999);
public static final JExpression MISSING_VALUE = JExpr.lit(-1);
}

0 comments on commit 8440977

Please sign in to comment.