Skip to content

Commit

Permalink
normalize shares per reference group correctly, added some facet dash…
Browse files Browse the repository at this point in the history
…boards (WIP)
  • Loading branch information
rakow committed Jun 11, 2024
1 parent 9066011 commit a60bc3e
Show file tree
Hide file tree
Showing 5 changed files with 1,299 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class TripAnalysis implements MATSimAppCommand {

private static final Logger log = LogManager.getLogger(TripAnalysis.class);
/**
* Person attribute that contains the reference modes of a person.
* Person attribute that contains the reference modes of a person. Multiple modes are delimited by "-".
*/
public static String ATTR_REF_MODES = "ref_modes";
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,19 @@ void analyzeModeShare(Table trips, List<String> dists, Function<String, Path> ou
Comparator<Row> cmp = Comparator.comparingInt(row -> dists.indexOf(row.getString("dist_group")));
aggr = aggr.sortOn(cmp.thenComparing(row -> row.getString("main_mode")));

// TODO: norm by category and dist_group
// probably need two separate files as well (with and without dist)
// not normed is more useful for now
// Norm each group to 1
String norm = group.columns.get(0);
if (group.columns.size() > 1)
throw new UnsupportedOperationException("Multiple columns not supported yet");

for (String label : aggr.stringColumn(norm).asSet()) {
DoubleColumn dist_group = aggr.doubleColumn("sim_share");
Selection sel = aggr.stringColumn(norm).isEqualTo(label);

double total = dist_group.where(sel).sum();
if (total > 0)
dist_group.set(sel, dist_group.divide(total));
}

Table joined = new DataFrameJoiner(group.data, join).leftOuter(aggr);
joined.column("share").setName("ref_share");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ public TripChoiceAnalysis(Table persons, Table trips, List<String> modeOrder) {
}

private static double round(double d) {
if (Double.isNaN(d))
return Double.NaN;

return BigDecimal.valueOf(d).setScale(5, RoundingMode.HALF_UP).doubleValue();
}

Expand Down Expand Up @@ -162,9 +165,9 @@ public void writeChoiceEvaluation(Path path) throws IOException {
total = c.total;
}

OptionalDouble precision = counts.values().stream().mapToDouble(TripChoiceAnalysis::precision).average();
OptionalDouble recall = counts.values().stream().mapToDouble(TripChoiceAnalysis::recall).average();
OptionalDouble f1 = counts.values().stream().mapToDouble(TripChoiceAnalysis::f1).average();
OptionalDouble precision = counts.values().stream().mapToDouble(TripChoiceAnalysis::precision).filter(Double::isFinite).average();
OptionalDouble recall = counts.values().stream().mapToDouble(TripChoiceAnalysis::recall).filter(Double::isFinite).average();
OptionalDouble f1 = counts.values().stream().mapToDouble(TripChoiceAnalysis::f1).filter(Double::isFinite).average();

try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
package org.matsim.simwrapper.dashboard;

import org.matsim.application.analysis.population.TripAnalysis;
import org.matsim.application.options.CsvOptions;
import org.matsim.core.utils.io.IOUtils;
import org.matsim.simwrapper.Dashboard;
import org.matsim.simwrapper.Header;
import org.matsim.simwrapper.Layout;
import org.matsim.simwrapper.viz.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.tablesaw.plotly.components.Axis;
import tech.tablesaw.plotly.traces.BarTrace;

import javax.annotation.Nullable;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
* Shows trip information, optionally against reference data.
*/
public class TripDashboard implements Dashboard {

private static final Logger log = LoggerFactory.getLogger(TripDashboard.class);

@Nullable
private final String modeShareRefCsv;
@Nullable
Expand Down Expand Up @@ -57,8 +68,18 @@ public TripDashboard(@Nullable String modeShareRefCsv, @Nullable String modeShar
}

private static String[] detectCategories(String groupedRefCsv) {
// TODO: Implement
return new String[0];
try {
Character c = CsvOptions.detectDelimiter(groupedRefCsv);
try (BufferedReader reader = IOUtils.getBufferedReader(groupedRefCsv)) {
String header = reader.readLine();
return Arrays.stream(header.split(String.valueOf(c)))
.filter(s -> !s.equals("main_mode") && !s.equals("share") && !s.equals("dist_group"))
.toArray(String[]::new);
}

} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

/**
Expand All @@ -71,6 +92,7 @@ public TripDashboard withGroupedRefData(String groupedRefCsv, String... categori
this.groupedRefCsv = groupedRefCsv;
if (categories.length == 0) {
categories = detectCategories(groupedRefCsv);
log.info("Detected categories from reference data: {}", Arrays.toString(categories));
}
this.categories = categories;
return this;
Expand Down Expand Up @@ -340,30 +362,35 @@ private void createChoiceTab(Layout layout, String[] args) {

private void createGroupedTab(Layout layout, String[] args) {

// age,economic_status,dist_group,main_mode,share
layout.row("facets", "By Groups").el(Plotly.class, (viz, data) -> {
for (String cat : Objects.requireNonNull(categories, "Categories not set")) {

viz.title = "FACETS";
viz.description = "by hour and purpose";
viz.layout = tech.tablesaw.plotly.components.Layout.builder()
.xAxis(Axis.builder().title("dist_group").build())
.yAxis(Axis.builder().title("sim_share").build())
.barMode(tech.tablesaw.plotly.components.Layout.BarMode.STACK)
.build();
layout.row("category_" + cat, "By Groups").el(Plotly.class, (viz, data) -> {

// TODO: Still in testing
viz.addTrace(BarTrace.builder(Plotly.OBJ_INPUT, Plotly.INPUT).build(),
viz.addDataset(data.computeWithPlaceholder(TripAnalysis.class, "mode_share_per_%s.csv", "age")).mapping()
.facetCol("age")
.name("main_mode", ColorScheme.Spectral)
.x("dist_group")
.y("sim_share")
);
viz.title = "Mode share";
viz.description = "by " + cat;
viz.layout = tech.tablesaw.plotly.components.Layout.builder()
.xAxis(Axis.builder().title("dist_group").build())
.yAxis(Axis.builder().title("sim_share").build())
.barMode(tech.tablesaw.plotly.components.Layout.BarMode.STACK)
.build();

});
// TODO: Still in testing
Plotly.DataMapping ds = viz.addDataset(data.computeWithPlaceholder(TripAnalysis.class, "mode_share_per_%s.csv", cat))
.pivot(List.of("main_mode"), "source", "share")
.aggregate(List.of("main_mode", "source", cat), "share", Plotly.AggrFunc.SUM)
.mapping()
.facetCol(cat)
.name("main_mode")
.x("share")
.y("source");

// TODO create the additional tab

viz.addTrace(BarTrace.builder(Plotly.OBJ_INPUT, Plotly.INPUT)
.orientation(BarTrace.Orientation.HORIZONTAL)
.build(), ds);
});

}
}

}
Loading

0 comments on commit a60bc3e

Please sign in to comment.