Skip to content

Commit

Permalink
refactor categories
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Mar 22, 2024
1 parent 5fa3cb2 commit cdf15b3
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 111 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package org.matsim.application.analysis.population;

import java.util.*;

/**
* Categorize values into groups.
*/
public final class Category {

private static final Set<String> TRUE = Set.of("true", "yes", "1", "on", "y", "j", "ja");
private static final Set<String> FALSE = Set.of("false", "no", "0", "off", "n", "nein");

/**
* Unique values of the category.
*/
private final Set<String> values;

/**
* Groups of values that have been subsumed under a single category.
* These are values separated by ,
*/
private final Map<String, String> grouped;

/**
* Range categories.
*/
private final List<Range> ranges;

public Category(Set<String> values) {
this.values = values;
this.grouped = new HashMap<>();
for (String v : values) {
if (v.contains(",")) {
String[] grouped = v.split(",");
for (String g : grouped) {
this.grouped.put(g, v);
}
}
}

boolean range = this.values.stream().allMatch(v -> v.contains("-") || v.contains("+"));
if (range) {
ranges = new ArrayList<>();
for (String value : this.values) {
if (value.contains("-")) {
String[] parts = value.split("-");
ranges.add(new Range(Double.parseDouble(parts[0]), Double.parseDouble(parts[1]), value));
} else if (value.contains("+")) {
ranges.add(new Range(Double.parseDouble(value.replace("+", "")), Double.POSITIVE_INFINITY, value));
}
}

ranges.sort(Comparator.comparingDouble(r -> r.left));
} else
ranges = null;


// Check if all values are boolean
if (values.stream().allMatch(v -> TRUE.contains(v.toLowerCase()) || FALSE.contains(v.toLowerCase()))) {
for (String value : values) {
Set<String> group = TRUE.contains(value.toLowerCase()) ? TRUE : FALSE;
for (String g : group) {
this.grouped.put(g, value);
}
}
}
}

/**
* Categorize a single value.
*/
public String categorize(Object value) {

if (value == null)
return null;

if (value instanceof Boolean) {
// Booleans and synonyms are in the group map
return categorize(((Boolean) value).toString().toLowerCase());
} else if (value instanceof Number) {
return categorizeNumber((Number) value);
} else {
String v = value.toString();
if (values.contains(v))
return v;
else if (grouped.containsKey(v))
return grouped.get(v);

try {
double d = Double.parseDouble(v);
return categorizeNumber(d);
} catch (NumberFormatException e) {
return null;
}
}
}

private String categorizeNumber(Number value) {

if (ranges != null) {
for (Range r : ranges) {
if (value.doubleValue() >= r.left && value.doubleValue() < r.right)
return r.label;
}
}

// Match string representation
String v = value.toString();
if (values.contains(v))
return v;
else if (grouped.containsKey(v))
return grouped.get(v);


// Convert the number to a whole number, which will have a different string representation
if (value instanceof Float || value instanceof Double) {
return categorizeNumber(value.longValue());
}

return null;
}

/**
* @param left Left bound of the range.
* @param right Right bound of the range. (exclusive)
* @param label Label of this group.
*/
private record Range(double left, double right, String label) {


}

}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ final class TripByGroupAnalysis {
for (List<String> group : groups) {
for (String g : group) {
if (!this.categories.containsKey(g)) {
this.categories.put(g, new Category(ref.column(g)));
this.categories.put(g, new Category(ref.column(g).asStringColumn().removeMissing().asSet()));
}
}
}
Expand Down Expand Up @@ -143,114 +143,4 @@ void groupPersons(Table persons) {
/**
 * Pairs a list of column names with a table.
 * NOTE(review): presumably {@code columns} are the attributes forming one group and {@code data} is the
 * table belonging to that group - the usages are outside of this view, confirm before relying on it.
 */
private record Group(List<String> columns, Table data) {
}

/**
 * Categorizes values according to the unique entries of a reference column.
 * Entries separated by "," are treated as groups, and if all entries are numeric ranges
 * like "1-2" or "4+", numbers are matched against these ranges.
 */
private static final class Category {

	/**
	 * Unique values of the category.
	 */
	private final Set<String> values;

	/**
	 * Groups of values that have been subsumed under a single category.
	 * These are values separated by ,
	 */
	private final Map<String, String> grouped;

	/**
	 * Range categories sorted by left bound, or {@code null} if the values are not numeric ranges.
	 */
	private final List<Range> ranges;

	/**
	 * @param data column whose unique, non-missing string values define the categories
	 */
	public Category(Column<?> data) {
		this.values = data.asStringColumn().unique()
			.removeMissing()
			.asSet();

		this.grouped = new HashMap<>();
		for (String label : values) {
			if (label.contains(",")) {
				for (String member : label.split(",")) {
					this.grouped.put(member, label);
				}
			}
		}

		this.ranges = parseRanges(values);
	}

	/**
	 * Parses all labels as ranges.
	 *
	 * @return ranges sorted by left bound, or {@code null} if at least one label is not a numeric range
	 */
	private static List<Range> parseRanges(Set<String> labels) {
		List<Range> result = new ArrayList<>();
		for (String label : labels) {
			Range range = parseRange(label);
			// Labels like "part-time" contain a separator, but are no ranges. Treat the category as plain labels then.
			if (range == null)
				return null;

			result.add(range);
		}

		result.sort(Comparator.comparingDouble(r -> r.left));
		return result;
	}

	/**
	 * Parses a single label of the form "left-right" or "left+".
	 *
	 * @return the range, or {@code null} if the label does not have this form
	 */
	private static Range parseRange(String label) {
		try {
			if (label.contains("-")) {
				String[] parts = label.split("-");
				if (parts.length < 2)
					return null;

				return new Range(Double.parseDouble(parts[0]), Double.parseDouble(parts[1]), label);
			}

			if (label.contains("+")) {
				return new Range(Double.parseDouble(label.replace("+", "")), Double.POSITIVE_INFINITY, label);
			}
		} catch (NumberFormatException e) {
			// Not numeric, reported to the caller as "no range" below
		}

		return null;
	}

	/**
	 * Categorize a single value.
	 *
	 * @return the matching category, or {@code null} if the value is null or matches none
	 */
	public String categorize(Object value) {

		if (value == null)
			return null;

		// TODO: handle booleans

		if (value instanceof Number) {
			return categorizeNumber((Number) value);
		}

		String text = value.toString();
		String label = lookup(text);
		if (label != null)
			return label;

		// The text might be a number that falls into a range
		try {
			return categorizeNumber(Double.parseDouble(text));
		} catch (NumberFormatException e) {
			return null;
		}
	}

	private String categorizeNumber(Number value) {

		double d = value.doubleValue();

		if (ranges != null) {
			for (Range r : ranges) {
				if (d >= r.left && d < r.right)
					return r.label;
			}
		}

		// Match string representation
		String label = lookup(value.toString());
		if (label != null)
			return label;

		// Whole floating point numbers are written as "1.0" while the label may be "1".
		// Only whole values within the long range are converted, so fractions and NaN are never truncated onto a label.
		boolean whole = d >= -0x1p63 && d < 0x1p63 && d == Math.rint(d);
		if ((value instanceof Float || value instanceof Double) && whole) {
			return lookup(Long.toString((long) d));
		}

		return null;
	}

	/**
	 * Looks up a string among the values and the grouped members.
	 *
	 * @return the matching category or {@code null}
	 */
	private String lookup(String key) {
		if (values.contains(key))
			return key;

		return grouped.get(key);
	}

}

/**
 * Numeric interval belonging to one category label. Values are matched with
 * {@code left <= value < right}; open-ended labels such as "4+" use positive infinity as right bound.
 *
 * @param left  Left bound of the range. (inclusive)
 * @param right Right bound of the range. (exclusive)
 * @param label Label of this group.
 */
private record Range(double left, double right, String label) {


}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package org.matsim.application.analysis.population;

import org.junit.jupiter.api.Test;

import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Tests for {@link Category}: plain labels, numeric ranges, comma separated groups and boolean synonyms.
 */
class CategoryTest {

	/**
	 * Plain labels map onto themselves, unknown values yield null.
	 */
	@Test
	void standard() {

		Category category = new Category(Set.of("a", "b", "c"));

		for (String label : new String[]{"a", "b", "c"}) {
			assertThat(category.categorize(label)).isEqualTo(label);
		}

		assertThat(category.categorize("d")).isNull();
	}

	/**
	 * Numbers are matched against ranges regardless of being given as string, int or double.
	 */
	@Test
	void ranges() {

		Category category = new Category(Set.of("1-2", "2-4", "4+"));

		for (Object input : new Object[]{"1", 1, 1.0}) {
			assertThat(category.categorize(input)).isEqualTo("1-2");
		}

		// The right bound is exclusive, so 2 already belongs to the next range
		for (Object input : new Object[]{"2", "3"}) {
			assertThat(category.categorize(input)).isEqualTo("2-4");
		}

		for (Object input : new Object[]{"5", 5, 5.0}) {
			assertThat(category.categorize(input)).isEqualTo("4+");
		}
	}

	/**
	 * Every member of a comma separated label maps to the whole label.
	 */
	@Test
	void grouped() {

		Category category = new Category(Set.of("a,b", "101,102"));

		for (Object input : new Object[]{"a", "b"}) {
			assertThat(category.categorize(input)).isEqualTo("a,b");
		}

		for (Object input : new Object[]{101, 102}) {
			assertThat(category.categorize(input)).isEqualTo("101,102");
		}
	}

	/**
	 * Boolean synonyms and real booleans map to the boolean labels of the category.
	 */
	@Test
	void bool() {

		Category category = new Category(Set.of("y", "n"));

		for (Object input : new Object[]{"y", "yes", "1", true}) {
			assertThat(category.categorize(input)).isEqualTo("y");
		}

		assertThat(category.categorize(false)).isEqualTo("n");
	}
}

0 comments on commit cdf15b3

Please sign in to comment.