Skip to content

Commit

Permalink
support list of doubles and ints as parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Feb 11, 2024
1 parent edabc86 commit 72d42d9
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,26 @@ void createParamSet() {
config, input.resolve("multiLevel.yml")
);

testGroup.addParam("values", "1, 2, 3");

assertThat(testGroup.values)
.containsExactly(1, 2, 3);

Collection<? extends ConfigGroup> params = testGroup.getParameterSets("params");

assertThat(params).hasSize(2);

Iterator<? extends ConfigGroup> it = params.iterator();
ConfigGroup next = it.next();
TestParamSet next = (TestParamSet) it.next();

assertThat(next.getParams().get("mode")).isEqualTo("car");
assertThat(next.getParams().get("values")).isEqualTo("-1, -2");
assertThat(next.getParams().get("values")).isEqualTo("-1.0, -2.0");
assertThat(next.values).containsExactly(-1d, -2d);

next = it.next();
next = (TestParamSet) it.next();

assertThat(next.getParams().get("mode")).isEqualTo("bike");
assertThat(next.getParams().get("values")).isEqualTo("3, 4");
assertThat(next.getParams().get("values")).isEqualTo("3.0, 4.0");
assertThat(next.getParams().get("extra")).isEqualTo("extra");
}

Expand All @@ -91,10 +97,10 @@ void multiLevel() {
ConfigGroup next = it.next();

// These parameters are recognized as lists correctly
assertThat(next.getParams().get("values")).isEqualTo("-1, -2");
assertThat(next.getParams().get("values")).isEqualTo("-1.0, -2.0");

next = it.next();
assertThat(next.getParams().get("values")).isEqualTo("3, 4");
assertThat(next.getParams().get("values")).isEqualTo("3.0, 4.0");
assertThat(next.getParams().get("extra")).isEqualTo("extra");

}
Expand Down Expand Up @@ -122,7 +128,7 @@ void ambiguous() {
public static final class TestConfigGroup extends ReflectiveConfigGroup {

@Parameter
private List<String> values;
private List<Integer> values;

public TestConfigGroup() {
super("test");
Expand All @@ -142,7 +148,7 @@ public ConfigGroup createParameterSet(String type) {
public static final class TestParamSet extends ReflectiveConfigGroup {

@Parameter
private List<String> values;
private List<Double> values;

public TestParamSet() {
super("params", true);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

test:
values: [1,2,3]
params:
- mode: car
subpopulation: person
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ private static boolean checkType(Type type) {
var rawType = pType.getRawType();
if (rawType.equals(List.class) || rawType.equals(Set.class)) {
var typeArgument = pType.getActualTypeArguments()[0];
return typeArgument.equals(String.class) || (typeArgument instanceof Class && ((Class<?>) typeArgument).isEnum());
return typeArgument.equals(String.class) ||
typeArgument.equals(Double.class) ||
typeArgument.equals(Integer.class) ||
(typeArgument instanceof Class && ((Class<?>) typeArgument).isEnum());
}

if (rawType.equals(Class.class))
Expand Down Expand Up @@ -412,6 +415,11 @@ private Object fromString(String value, Class<?> type, @Nullable Field paramFiel
List<? extends Enum<?>> enumConstants = getEnumConstants(paramField);
return stream.map(s -> stringToEnumValue(s, enumConstants)).collect(toImmutableSet());
}
if (paramField != null && isCollectionOfDoubleType(paramField))
return stream.map(Double::parseDouble).collect(toImmutableSet());
if (paramField != null && isCollectionOfIntegerType(paramField))
return stream.map(Integer::parseInt).collect(toImmutableSet());

return stream.collect(toImmutableSet());
} else if (type.equals(List.class)) {
if (value.isBlank()) {
Expand All @@ -422,6 +430,11 @@ private Object fromString(String value, Class<?> type, @Nullable Field paramFiel
List<? extends Enum<?>> enumConstants = getEnumConstants(paramField);
return stream.map(s -> stringToEnumValue(s, enumConstants)).toList();
}
if (paramField != null && isCollectionOfDoubleType(paramField))
return stream.map(Double::parseDouble).toList();
if (paramField != null && isCollectionOfIntegerType(paramField))
return stream.map(Integer::parseInt).toList();

return stream.toList();
} else if (type.equals(Class.class)) {
try {
Expand Down Expand Up @@ -488,7 +501,9 @@ private String getParamField(Field paramField) {
boolean accessible = enforceAccessible(paramField);
try {
var result = paramField.get(this);
if (result != null && isCollectionOfEnumsWithUniqueStringValues(paramField)) {
if (result != null && (isCollectionOfEnumsWithUniqueStringValues(paramField) ||
isCollectionOfDoubleType(paramField) ||
isCollectionOfIntegerType(paramField))) {
result = ((Collection<Object>) result).stream()
.map(Object::toString) // map enum values to string
.collect(Collectors.toList());
Expand Down Expand Up @@ -674,6 +689,30 @@ private static boolean isCollectionOfEnumsWithUniqueStringValues(Field paramFiel
return false;
}

private static boolean isCollectionOfIntegerType(Field paramField) {
var type = paramField.getGenericType();
if (type instanceof ParameterizedType pType) {
var rawType = pType.getRawType();
if (rawType.equals(List.class) || rawType.equals(Set.class)) {
var typeArgument = pType.getActualTypeArguments()[0];
return typeArgument.equals(Integer.class) || typeArgument.equals(Integer.TYPE);
}
}
return false;
}

private static boolean isCollectionOfDoubleType(Field paramField) {
var type = paramField.getGenericType();
if (type instanceof ParameterizedType pType) {
var rawType = pType.getRawType();
if (rawType.equals(List.class) || rawType.equals(Set.class)) {
var typeArgument = pType.getActualTypeArguments()[0];
return typeArgument.equals(Double.class) || typeArgument.equals(Double.TYPE);
}
}
return false;
}

private static <T> boolean enumStringsAreUnique(Class<T> enumClass) {
T[] enumConstants = enumClass.getEnumConstants();
long uniqueStringValues = Arrays.stream(enumConstants)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.matsim.api.core.v01.Coord;
import org.matsim.api.core.v01.Id;
import org.matsim.api.core.v01.network.Link;
import org.matsim.api.core.v01.population.Person;
import org.matsim.core.config.ReflectiveConfigGroup.InconsistentModuleException;
import org.matsim.testcases.MatsimTestUtils;

Expand Down Expand Up @@ -70,6 +71,7 @@ void testDumpAndRead() {
dumpedModule.enumSetField = Set.of(MyEnum.VALUE2);
dumpedModule.setField = ImmutableSet.of("a", "b", "c");
dumpedModule.listField = List.of("1", "2", "3");
dumpedModule.ints = List.of(1, 2, 3);
assertEqualAfterDumpAndRead(dumpedModule);
}

Expand Down Expand Up @@ -166,6 +168,7 @@ void testComments() {
expectedComments.put("enumListField", "list of enum");
expectedComments.put("enumSetField", "set of enum");
expectedComments.put("setField", "set");
expectedComments.put("ints", "list of ints");

assertThat(new MyModule().getComments()).isEqualTo(expectedComments);
}
Expand Down Expand Up @@ -324,7 +327,7 @@ void testFailUnsupportedType_StringCollections() {
void testFailUnsupportedType_NonStringList() {
assertThatThrownBy(() -> new ReflectiveConfigGroup("name") {
@Parameter("field")
private List<Double> stuff;
private List<Person> stuff;
}).isInstanceOf(InconsistentModuleException.class);
}

Expand Down Expand Up @@ -472,6 +475,10 @@ private static class MyModule extends ReflectiveConfigGroup {
@Parameter
private Set<MyEnum> enumSetField;

@Comment("list of ints")
@Parameter
private List<Integer> ints;

// Object fields:
// Id: string representation is toString
private Id<Link> idField;
Expand Down

0 comments on commit 72d42d9

Please sign in to comment.