diff --git a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java
index ee9186e8b3b..02c3fd676d8 100644
--- a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java
+++ b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java
@@ -38,13 +38,23 @@
@CommandLine.Command(name = "trips", description = "Calculates various trip related metrics.")
@CommandSpec(
requires = {"trips.csv", "persons.csv"},
- produces = {"mode_share.csv", "mode_share_per_dist.csv", "mode_users.csv", "trip_stats.csv",
- "mode_share_per_%s.csv", "population_trip_stats.csv", "trip_purposes_by_hour.csv", "mode_share_per_age.csv"}
+ produces = {
+ "mode_share.csv", "mode_share_per_dist.csv", "mode_users.csv", "trip_stats.csv",
+ "mode_share_per_%s.csv", "population_trip_stats.csv", "trip_purposes_by_hour.csv",
+ "mode_choices.csv", "mode_choice_evaluation.csv", "mode_choice_evaluation_per_mode.csv"
+ }
)
public class TripAnalysis implements MATSimAppCommand {
private static final Logger log = LogManager.getLogger(TripAnalysis.class);
-
+ /**
+ * Person attribute that contains the reference modes of a person.
+ */
+ public static String ATTR_REF_MODES = "ref_modes";
+ /**
+ * Person attribute containing its weight for analysis purposes.
+ */
+ public static String ATTR_REF_WEIGHT = "ref_weight";
@CommandLine.Mixin
private InputOptions input = InputOptions.ofCommand(TripAnalysis.class);
@CommandLine.Mixin
@@ -209,6 +219,18 @@ public Integer call() throws Exception {
groups.analyzeModeShare(joined, labels, (g) -> output.getPath("mode_share_per_%s.csv", g));
}
+ if (persons.containsColumn(ATTR_REF_MODES)) {
+ try {
+ TripChoiceAnalysis choices = new TripChoiceAnalysis(persons, trips, modeOrder);
+
+ choices.writeChoices(output.getPath("mode_choices.csv"));
+ choices.writeChoiceEvaluation(output.getPath("mode_choice_evaluation.csv"));
+ choices.writeChoiceEvaluationPerMode(output.getPath("mode_choice_evaluation_per_mode.csv"));
+ } catch (RuntimeException e) {
+ log.error("Error while analyzing mode choices", e);
+ }
+ }
+
writePopulationStats(persons, joined);
writeTripStats(joined);
@@ -358,7 +380,6 @@ private void writePopulationStats(Table persons, Table trips) throws IOException
table.write().csv(output.getPath("mode_users.csv").toFile());
try (CSVPrinter printer = new CSVPrinter(Files.newBufferedWriter(output.getPath("population_trip_stats.csv")), CSVFormat.DEFAULT)) {
-
printer.printRecord("Info", "Value");
printer.printRecord("Persons", tripsPerPerson.size());
printer.printRecord("Mobile persons [%]", new BigDecimal(100 * totalMobile / tripsPerPerson.size()).setScale(2, RoundingMode.HALF_UP));
diff --git a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java
index 62d95d0a152..59f669ad159 100644
--- a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java
+++ b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java
@@ -57,7 +57,7 @@ final class TripByGroupAnalysis {
groups.add(g);
}
- log.info("Detect groups: {}", groups);
+ log.info("Detected groups: {}", groups);
this.groups = new ArrayList<>();
diff --git a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripChoiceAnalysis.java b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripChoiceAnalysis.java
new file mode 100644
index 00000000000..ce4fa6727ec
--- /dev/null
+++ b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripChoiceAnalysis.java
@@ -0,0 +1,100 @@
+package org.matsim.application.analysis.population;
+
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVPrinter;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import tech.tablesaw.api.Row;
+import tech.tablesaw.api.Table;
+import tech.tablesaw.joining.DataFrameJoiner;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Helper class to analyze trip choices from persons against reference data.
+ * Metrics for binary classification
+ * Evaluation multi-class classifiers
+ * more
+ * more
+ * ...
+ */
+final class TripChoiceAnalysis {
+
+ private static final Logger log = LogManager.getLogger(TripChoiceAnalysis.class);
+
+ private final List modeOrder;
+
+ /**
+ * Contains trip data with true and predicated (simulated) modes.
+ */
+ private final List data = new ArrayList<>();
+
+ public TripChoiceAnalysis(Table persons, Table trips, List modeOrder) {
+ persons = persons.where(persons.stringColumn("ref_modes").isNotEqualTo(""));
+ trips = new DataFrameJoiner(trips, "person").inner(persons);;
+ this.modeOrder = modeOrder;
+
+ log.info("Analyzing mode choices for {} persons", persons.rowCount());
+
+ for (Row trip : trips) {
+
+ String person = trip.getText("person");
+ int n = trip.getInt("trip_number") - 1;
+ double weight = trip.getDouble(TripAnalysis.ATTR_REF_WEIGHT);
+
+ String predMode = trip.getString("main_mode");
+ String[] split = trip.getString(TripAnalysis.ATTR_REF_MODES).split("-");
+
+ if (n < split.length) {
+ String trueMode = split[n];
+ data.add(new Entry(person, weight, n, trueMode, predMode));
+ } else
+ log.warn("Person {} trip {} does not match ref data ({})", person, n, split.length);
+ }
+ }
+
+ /**
+ * Writes all choices to csv.
+ */
+ public void writeChoices(Path path) throws IOException {
+ try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {
+ csv.printRecord("person", "weight", "true_mode", "pred_mode");
+ for (Entry e : data) {
+ csv.printRecord(e.person, e.weight, e.trueMode, e.predMode);
+ }
+ }
+ }
+
+ /**
+ * Writes aggregated choices metrics.
+ */
+ public void writeChoiceEvaluation(Path path) throws IOException {
+
+ try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {
+
+ csv.printRecord("Info", "Value");
+
+ csv.printRecord("Accuracy", "TODO");
+
+
+ }
+
+
+ // TODO: accuracy
+ // macro and micro averaged precision, recall, f1
+ }
+
+ /**
+ * Writes metrics per mode.
+ */
+ public void writeChoiceEvaluationPerMode(Path path) {
+
+ }
+
+ private record Entry(String person, double weight, int n, String trueMode, String predMode) {
+ }
+}