Skip to content

Commit

Permalink
Merge branch 'master' into custom-constraints-update
Browse files Browse the repository at this point in the history
  • Loading branch information
nkuehnel authored Jun 25, 2024
2 parents a9949d5 + abbab85 commit 38d54d6
Show file tree
Hide file tree
Showing 31 changed files with 2,639 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package org.matsim.application.analysis.population;

import org.matsim.core.config.ReflectiveConfigGroup;
import org.matsim.utils.objectattributes.attributable.Attributes;

import java.util.*;
import java.util.regex.Pattern;

/**
* Helper class to categorize values into groups.
*/
public final class Category {

private static final Set<String> TRUE = Set.of("true", "yes", "1", "on", "y", "j", "ja");
private static final Set<String> FALSE = Set.of("false", "no", "0", "off", "n", "nein");

/**
* Unique values of the category.
*/
private final Set<String> values;

/**
* Groups of values that have been subsumed under a single category.
* These are values separated by ,
*/
private final Map<String, String> grouped;


/**
* Regular expressions for each category.
*/
private final Map<String, Pattern> regex;

/**
* Range categories.
*/
private final List<Range> ranges;

public Category(Set<String> values) {
this.values = values;
this.grouped = new HashMap<>();
this.regex = new HashMap<>();
for (String v : values) {
if (v.contains(",")) {
String[] grouped = v.split(",");
for (String g : grouped) {
this.grouped.put(g, v);
}
}

if (v.startsWith("/") && v.endsWith("/")) {
this.regex.put(v, Pattern.compile(v.substring(1, v.length() - 1), Pattern.CASE_INSENSITIVE));
}
}

boolean range = this.values.stream().allMatch(v -> v.contains("-") || v.contains("+"));
if (range) {
ranges = new ArrayList<>();
for (String value : this.values) {
if (value.contains("-")) {
String[] parts = value.split("-");
ranges.add(new Range(Double.parseDouble(parts[0]), Double.parseDouble(parts[1]), value));
} else if (value.contains("+")) {
ranges.add(new Range(Double.parseDouble(value.replace("+", "")), Double.POSITIVE_INFINITY, value));
}
}

ranges.sort(Comparator.comparingDouble(r -> r.left));
} else
ranges = null;


// Check if all values are boolean
if (values.stream().allMatch(v -> TRUE.contains(v.toLowerCase()) || FALSE.contains(v.toLowerCase()))) {
for (String value : values) {
Set<String> group = TRUE.contains(value.toLowerCase()) ? TRUE : FALSE;
for (String g : group) {
this.grouped.put(g, value);
}
}
}
}

/**
* Create categories from config parameters.
*/
public static Map<String, Category> fromConfigParams(Collection<? extends ReflectiveConfigGroup> params) {

Map<String, Set<String>> categories = new HashMap<>();

// Collect all values
for (ReflectiveConfigGroup parameter : params) {
for (Map.Entry<String, String> kv : parameter.getParams().entrySet()) {
categories.computeIfAbsent(kv.getKey(), k -> new HashSet<>()).add(kv.getValue());
}
}

return categories.entrySet().stream()
.collect(HashMap::new, (m, e) -> m.put(e.getKey(), new Category(e.getValue())), HashMap::putAll);
}

/**
* Match attributes from an object with parameters defined in config.
*/
public static boolean matchAttributesWithConfig(Attributes attr, ReflectiveConfigGroup config, Map<String, Category> categories) {

for (Map.Entry<String, String> e : config.getParams().entrySet()) {
// might be null if not defined
Object objValue = attr.getAttribute(e.getKey());
String category = categories.get(e.getKey()).categorize(objValue);

// compare as string
if (!Objects.toString(category).equals(e.getValue()))
return false;
}

return true;
}

/**
* Categorize a single value.
*/
public String categorize(Object value) {

if (value == null)
return null;

if (value instanceof Boolean) {
// Booleans and synonyms are in the group map
return categorize(((Boolean) value).toString().toLowerCase());
} else if (value instanceof Number) {
return categorizeNumber((Number) value);
} else {
String v = value.toString();
if (values.contains(v))
return v;
else if (grouped.containsKey(v))
return grouped.get(v);
else {
for (Map.Entry<String, Pattern> kv : regex.entrySet()) {
if (kv.getValue().matcher(v).matches())
return kv.getKey();
}
}

try {
double d = Double.parseDouble(v);
return categorizeNumber(d);
} catch (NumberFormatException e) {
return null;
}
}
}

private String categorizeNumber(Number value) {

if (ranges != null) {
for (Range r : ranges) {
if (value.doubleValue() >= r.left && value.doubleValue() < r.right)
return r.label;
}
}

// Match string representation
String v = value.toString();
if (values.contains(v))
return v;
else if (grouped.containsKey(v))
return grouped.get(v);


// Convert the number to a whole number, which will have a different string representation
if (value instanceof Float || value instanceof Double) {
return categorizeNumber(value.longValue());
}

return null;
}

@Override
public String toString() {
return "Category{" +
"values=" + values +
(grouped != null && !grouped.isEmpty() ? ", grouped=" + grouped : "") +
(regex != null && !regex.isEmpty() ? ", regex=" + regex : "") +
(ranges != null && !ranges.isEmpty() ? ", ranges=" + ranges : "") +
'}';
}

/**
* Number range.
*
* @param left Left bound of the range.
* @param right Right bound of the range. (exclusive)
* @param label Label of this group.
*/
private record Range(double left, double right, String label) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,49 @@
import tech.tablesaw.joining.DataFrameJoiner;
import tech.tablesaw.selection.Selection;

import java.io.*;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.file.Files;
import java.util.*;
import java.util.zip.GZIPInputStream;

import static tech.tablesaw.aggregate.AggregateFunctions.count;

@CommandLine.Command(name = "trips", description = "Calculates various trip related metrics.")
@CommandSpec(
requires = {"trips.csv", "persons.csv"},
produces = {"mode_share.csv", "mode_share_per_dist.csv", "mode_users.csv", "trip_stats.csv", "population_trip_stats.csv", "trip_purposes_by_hour.csv"}
produces = {
"mode_share.csv", "mode_share_per_dist.csv", "mode_users.csv", "trip_stats.csv",
"mode_share_per_%s.csv", "population_trip_stats.csv", "trip_purposes_by_hour.csv",
"mode_choices.csv", "mode_choice_evaluation.csv", "mode_choice_evaluation_per_mode.csv",
"mode_confusion_matrix.csv", "mode_prediction_error.csv"
}
)
public class TripAnalysis implements MATSimAppCommand {

private static final Logger log = LogManager.getLogger(TripAnalysis.class);

/**
* Attributes which relates this person to a reference person.
*/
public static String ATTR_REF_ID = "ref_id";
/**
* Person attribute that contains the reference modes of a person. Multiple modes are delimited by "-".
*/
public static String ATTR_REF_MODES = "ref_modes";
/**
* Person attribute containing its weight for analysis purposes.
*/
public static String ATTR_REF_WEIGHT = "ref_weight";

@CommandLine.Mixin
private InputOptions input = InputOptions.ofCommand(TripAnalysis.class);
@CommandLine.Mixin
private OutputOptions output = OutputOptions.ofCommand(TripAnalysis.class);

@CommandLine.Option(names = "--input-ref-data", description = "Optional path to reference data", required = false)
private String refData;

@CommandLine.Option(names = "--match-id", description = "Pattern to filter agents by id")
private String matchId;

Expand Down Expand Up @@ -95,7 +115,7 @@ public Integer call() throws Exception {
Table persons = Table.read().csv(CsvReadOptions.builder(IOUtils.getBufferedReader(input.getPath("persons.csv")))
.columnTypesPartial(Map.of("person", ColumnType.TEXT))
.sample(false)
.separator(new CsvOptions().detectDelimiter(input.getPath("persons.csv"))).build());
.separator(CsvOptions.detectDelimiter(input.getPath("persons.csv"))).build());

int total = persons.rowCount();

Expand Down Expand Up @@ -132,6 +152,7 @@ public Integer call() throws Exception {

// Map.of only has 10 argument max
columnTypes.put("traveled_distance", ColumnType.LONG);
columnTypes.put("euclidean_distance", ColumnType.LONG);

Table trips = Table.read().csv(CsvReadOptions.builder(IOUtils.getBufferedReader(input.getPath("trips.csv")))
.columnTypesPartial(columnTypes)
Expand Down Expand Up @@ -172,6 +193,12 @@ public Integer call() throws Exception {
trips = trips.where(Selection.with(idx.toIntArray()));
}

TripByGroupAnalysis groups = null;
if (refData != null) {
groups = new TripByGroupAnalysis(refData);
groups.groupPersons(persons);
}

// Use longest_distance_mode where main_mode is not present
trips.stringColumn("main_mode")
.set(trips.stringColumn("main_mode").isMissing(),
Expand All @@ -196,6 +223,24 @@ public Integer call() throws Exception {

writeModeShare(joined, labels);

if (groups != null) {
groups.analyzeModeShare(joined, labels, modeOrder, (g) -> output.getPath("mode_share_per_%s.csv", g));
}

if (persons.containsColumn(ATTR_REF_MODES)) {
try {
TripChoiceAnalysis choices = new TripChoiceAnalysis(persons, trips, modeOrder);

choices.writeChoices(output.getPath("mode_choices.csv"));
choices.writeChoiceEvaluation(output.getPath("mode_choice_evaluation.csv"));
choices.writeChoiceEvaluationPerMode(output.getPath("mode_choice_evaluation_per_mode.csv"));
choices.writeConfusionMatrix(output.getPath("mode_confusion_matrix.csv"));
choices.writeModePredictionError(output.getPath("mode_prediction_error.csv"));
} catch (RuntimeException e) {
log.error("Error while analyzing mode choices", e);
}
}

writePopulationStats(persons, joined);

writeTripStats(joined);
Expand Down Expand Up @@ -345,7 +390,6 @@ private void writePopulationStats(Table persons, Table trips) throws IOException
table.write().csv(output.getPath("mode_users.csv").toFile());

try (CSVPrinter printer = new CSVPrinter(Files.newBufferedWriter(output.getPath("population_trip_stats.csv")), CSVFormat.DEFAULT)) {

printer.printRecord("Info", "Value");
printer.printRecord("Persons", tripsPerPerson.size());
printer.printRecord("Mobile persons [%]", new BigDecimal(100 * totalMobile / tripsPerPerson.size()).setScale(2, RoundingMode.HALF_UP));
Expand Down Expand Up @@ -386,8 +430,7 @@ private void writeTripPurposes(Table trips) {
TextColumn purpose = trips.textColumn("end_activity_type");

// Remove suffix durations like _345
Selection withDuration = purpose.matchesRegex("^.+_[0-9]+$");
purpose.set(withDuration, purpose.where(withDuration).replaceAll("_[0-9]+$", ""));
purpose.set(Selection.withRange(0, purpose.size()), purpose.replaceAll("_[0-9]{2,}$", ""));

Table tArrival = trips.summarize("trip_id", count).by("end_activity_type", "arrival_h");

Expand Down
Loading

0 comments on commit 38d54d6

Please sign in to comment.