Skip to content

Commit

Permalink
Improved support for categorical integer features
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 29, 2024
1 parent 4ad3b70 commit 5bda4e9
Show file tree
Hide file tree
Showing 7 changed files with 18,542 additions and 14,549 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,20 @@ public JFieldVar initNumbersMap(String name, Map<?, Number> map){

JDefinedClass owner = context.getOwner();

JFieldVar constant = createMapConstant(name, context.ref(Object.class), context.ref(Number.class), context);
Set<?> keys = map.keySet();
Collection<Number> values = map.values();

Class<?> keyClazz = getValueClass(keys);
Class<?> valueClazz = getValueClass(values);

JFieldVar constant = createMapConstant(name, context.ref(keyClazz), context.ref(valueClazz), context);

String keyReadMethod;
String valueReadMethod;

try(OutputStream os = binaryFile.getDataStore()){
DataOutput dataOutput = new DataOutputStream(os);

Set<?> keys = map.keySet();

Class<?> keyClazz = getValueClass(keys);

if(Objects.equals(keyClazz, String.class)){
ResourceUtil.writeStrings(dataOutput, keys.toArray(new String[keys.size()]));
keyReadMethod = "readStrings";
Expand All @@ -387,11 +389,12 @@ public JFieldVar initNumbersMap(String name, Map<?, Number> map){

{
throw new IllegalArgumentException();
}
} // End if

Collection<Number> values = map.values();

Class<?> valueClazz = getValueClass(values);
if(Objects.equals(valueClazz, Integer.class)){
ResourceUtil.writeIntegers(dataOutput, values.toArray(new Integer[values.size()]));
valueReadMethod = "readIntegers";
} else

if(Objects.equals(valueClazz, Float.class)){
ResourceUtil.writeFloats(dataOutput, values.toArray(new Float[values.size()]));
Expand All @@ -410,17 +413,17 @@ public JFieldVar initNumbersMap(String name, Map<?, Number> map){
throw new RuntimeException(ioe);
}

JClass objectArrayClass = context.ref(Object[].class);
JClass numberArrayClass = context.ref(Number[].class);
JClass keysArrayClazz = (context.ref(keyClazz)).array();
JClass valuesArrayClazz = (context.ref(valueClazz)).array();

JMethod putAllMethod = owner.getMethod("putAll", new JType[]{constant.type(), objectArrayClass, numberArrayClass});
JMethod putAllMethod = owner.getMethod("putAll", new JType[]{constant.type(), keysArrayClazz, valuesArrayClazz});
if(putAllMethod == null){
putAllMethod = owner.method(Modifiers.PRIVATE_STATIC_FINAL, void.class, "putAll");

JVar mapParam = putAllMethod.param(constant.type(), "map");

JVar keysParam = putAllMethod.param(objectArrayClass, "keys");
JVar valuesParam = putAllMethod.param(numberArrayClass, "values");
JVar keysParam = putAllMethod.param(keysArrayClazz, "keys");
JVar valuesParam = putAllMethod.param(valuesArrayClazz, "values");

JBlock block = putAllMethod.body();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void evaluateSelectFirstAudit() throws Exception {

@Test
public void evaluateXGBoostAudit() throws Exception {
evaluate(XGBOOST, AUDIT, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(32 + 16));
evaluate(XGBOOST, AUDIT, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(32 + 48));
}

@Test
Expand Down
Loading

0 comments on commit 5bda4e9

Please sign in to comment.