diff --git a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java index ee9186e8b3b..02c3fd676d8 100644 --- a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java +++ b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripAnalysis.java @@ -38,13 +38,23 @@ @CommandLine.Command(name = "trips", description = "Calculates various trip related metrics.") @CommandSpec( requires = {"trips.csv", "persons.csv"}, - produces = {"mode_share.csv", "mode_share_per_dist.csv", "mode_users.csv", "trip_stats.csv", - "mode_share_per_%s.csv", "population_trip_stats.csv", "trip_purposes_by_hour.csv", "mode_share_per_age.csv"} + produces = { + "mode_share.csv", "mode_share_per_dist.csv", "mode_users.csv", "trip_stats.csv", + "mode_share_per_%s.csv", "population_trip_stats.csv", "trip_purposes_by_hour.csv", + "mode_choices.csv", "mode_choice_evaluation.csv", "mode_choice_evaluation_per_mode.csv" + } ) public class TripAnalysis implements MATSimAppCommand { private static final Logger log = LogManager.getLogger(TripAnalysis.class); - + /** + * Person attribute that contains the reference modes of a person. + */ + public static String ATTR_REF_MODES = "ref_modes"; + /** + * Person attribute containing its weight for analysis purposes. + */ + public static String ATTR_REF_WEIGHT = "ref_weight"; @CommandLine.Mixin private InputOptions input = InputOptions.ofCommand(TripAnalysis.class); @CommandLine.Mixin @@ -209,6 +219,18 @@ public Integer call() throws Exception { groups.analyzeModeShare(joined, labels, (g) -> output.getPath("mode_share_per_%s.csv", g)); } + if (persons.containsColumn(ATTR_REF_MODES)) { + try { + TripChoiceAnalysis choices = new TripChoiceAnalysis(persons, trips, modeOrder); + + choices.writeChoices(output.getPath("mode_choices.csv")); + choices.writeChoiceEvaluation(output.getPath("mode_choice_evaluation.csv")); + choices.writeChoiceEvaluationPerMode(output.getPath("mode_choice_evaluation_per_mode.csv")); + } catch (RuntimeException e) { + log.error("Error while analyzing mode choices", e); + } + } + writePopulationStats(persons, joined); writeTripStats(joined); @@ -358,7 +380,6 @@ private void writePopulationStats(Table persons, Table trips) throws IOException table.write().csv(output.getPath("mode_users.csv").toFile()); try (CSVPrinter printer = new CSVPrinter(Files.newBufferedWriter(output.getPath("population_trip_stats.csv")), CSVFormat.DEFAULT)) { - printer.printRecord("Info", "Value"); printer.printRecord("Persons", tripsPerPerson.size()); printer.printRecord("Mobile persons [%]", new BigDecimal(100 * totalMobile / tripsPerPerson.size()).setScale(2, RoundingMode.HALF_UP)); diff --git a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java index 62d95d0a152..59f669ad159 100644 --- a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java +++ b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripByGroupAnalysis.java @@ -57,7 +57,7 @@ final class TripByGroupAnalysis { groups.add(g); } - log.info("Detect groups: {}", groups); + log.info("Detected groups: {}", groups); this.groups = new ArrayList<>(); diff --git a/contribs/application/src/main/java/org/matsim/application/analysis/population/TripChoiceAnalysis.java b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripChoiceAnalysis.java new file mode 100644 index 00000000000..ce4fa6727ec --- /dev/null +++ b/contribs/application/src/main/java/org/matsim/application/analysis/population/TripChoiceAnalysis.java @@ -0,0 +1,100 @@ +package org.matsim.application.analysis.population; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVPrinter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import tech.tablesaw.api.Row; +import tech.tablesaw.api.Table; +import tech.tablesaw.joining.DataFrameJoiner; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +/** + * Helper class to analyze trip choices from persons against reference data. + * Metrics for binary classification + * Evaluation multi-class classifiers + * more + * more + * ... + */ +final class TripChoiceAnalysis { + + private static final Logger log = LogManager.getLogger(TripChoiceAnalysis.class); + + private final List modeOrder; + + /** + * Contains trip data with true and predicated (simulated) modes. + */ + private final List data = new ArrayList<>(); + + public TripChoiceAnalysis(Table persons, Table trips, List modeOrder) { + persons = persons.where(persons.stringColumn("ref_modes").isNotEqualTo("")); + trips = new DataFrameJoiner(trips, "person").inner(persons);; + this.modeOrder = modeOrder; + + log.info("Analyzing mode choices for {} persons", persons.rowCount()); + + for (Row trip : trips) { + + String person = trip.getText("person"); + int n = trip.getInt("trip_number") - 1; + double weight = trip.getDouble(TripAnalysis.ATTR_REF_WEIGHT); + + String predMode = trip.getString("main_mode"); + String[] split = trip.getString(TripAnalysis.ATTR_REF_MODES).split("-"); + + if (n < split.length) { + String trueMode = split[n]; + data.add(new Entry(person, weight, n, trueMode, predMode)); + } else + log.warn("Person {} trip {} does not match ref data ({})", person, n, split.length); + } + } + + /** + * Writes all choices to csv. + */ + public void writeChoices(Path path) throws IOException { + try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) { + csv.printRecord("person", "weight", "true_mode", "pred_mode"); + for (Entry e : data) { + csv.printRecord(e.person, e.weight, e.trueMode, e.predMode); + } + } + } + + /** + * Writes aggregated choices metrics. + */ + public void writeChoiceEvaluation(Path path) throws IOException { + + try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) { + + csv.printRecord("Info", "Value"); + + csv.printRecord("Accuracy", "TODO"); + + + } + + + // TODO: accuracy + // macro and micro averaged precision, recall, f1 + } + + /** + * Writes metrics per mode. + */ + public void writeChoiceEvaluationPerMode(Path path) { + + } + + private record Entry(String person, double weight, int n, String trueMode, String predMode) { + } +}