Skip to content

Commit

Permalink
add cohen kappa to evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Jun 12, 2024
1 parent 6a48a05 commit e6000b6
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.IntStream;

/**
* Helper class to analyze trip choices from persons against reference data.
Expand All @@ -42,6 +43,7 @@ final class TripChoiceAnalysis {
* Contains predication result for each mode.
*/
private final Map<String, Counts> counts = new HashMap<>();
private final Object2DoubleMap<Pair> pairs = new Object2DoubleOpenHashMap<>();

/**
* Contains confusion matrix for each mode.
Expand Down Expand Up @@ -73,26 +75,19 @@ public TripChoiceAnalysis(Table persons, Table trips, List<String> modeOrder) {
log.warn("Person {} trip {} does not match ref data ({})", person, n, split.length);
}

for (Entry e : data) {
pairs.mergeDouble(new Pair(e.trueMode(), e.predMode()), e.weight(), Double::sum);
}

for (String mode : modeOrder) {
counts.put(mode, countPredictions(mode, data));
DoubleArrayList preds = new DoubleArrayList();

for (String predMode : modeOrder) {
double c = 0;
for (Entry e : data) {
if (!e.trueMode.equals(mode))
continue;
if (e.predMode.equals(predMode))
c += e.weight;
}
double c = pairs.getDouble(new Pair(mode, predMode));
preds.add(c);
}

double sum = preds.doubleStream().sum();
for (int i = 0; i < preds.size(); i++) {
preds.set(i, preds.getDouble(i) / sum);
}

confusionMatrix.add(preds);
}
}
Expand All @@ -116,6 +111,35 @@ private static double f1(Counts c) {
return 2 * c.tp / (2 * c.tp + c.fp + c.fn);
}

/**
* Implemented as in sklearn.metrics.cohen_kappa_score.
* See <a href="https://en.wikipedia.org/wiki/Cohen%27s_kappa">Wikipedia</a>
*
* @param cm confusion matrix
*/
static double computeCohenKappa(List<DoubleList> cm) {
DoubleList sumRows = cm.stream().mapToDouble(l -> l.doubleStream().sum()).collect(DoubleArrayList::new, DoubleList::add, DoubleList::addAll);
DoubleList sumCols = IntStream.range(0, cm.size()).mapToDouble(i -> cm.stream().mapToDouble(l -> l.getDouble(i)).sum()).collect(DoubleArrayList::new, DoubleList::add, DoubleList::addAll);
double sumTotal = sumRows.doubleStream().sum();
double expected = 0;

for (int i = 0; i < cm.size(); i++) {
for (int j = 0; j < cm.size(); j++) {
if (i != j)
expected += sumRows.getDouble(i) * sumCols.getDouble(j) / sumTotal;
}
}

double k = 0;
for (int i = 0; i < cm.size(); i++) {
for (int j = 0; j < cm.size(); j++) {
if (i != j)
k += cm.get(i).getDouble(j) / expected;
}
}
return 1 - k;
}

private Counts countPredictions(String mode, List<Entry> data) {
double tp = 0, fp = 0, fn = 0, tn = 0;
double total = 0;
Expand Down Expand Up @@ -180,9 +204,8 @@ public void writeChoiceEvaluation(Path path) throws IOException {
csv.printRecord("Recall (macro avg.)", round(recall.orElse(0)));
csv.printRecord("F1 Score (micro avg.)", round(2 * tp / (tpfp + tpfn)));
csv.printRecord("F1 Score (macro avg.)", round(f1.orElse(0)));
csv.printRecord("Cohen’s Kappa", round(computeCohenKappa(confusionMatrix)));
}

// TODO Cohen’s Kappa, Mathews Correlation Coefficient (MCC)
}

/**
Expand Down Expand Up @@ -211,7 +234,7 @@ public void writeChoiceEvaluationPerMode(Path path) throws IOException {
}

/**
* Write confusion matrix.
* Write confusion matrix. This normalizes per row.
*/
public void writeConfusionMatrix(Path path) throws IOException {
try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {
Expand All @@ -224,8 +247,9 @@ public void writeConfusionMatrix(Path path) throws IOException {
for (int i = 0; i < modeOrder.size(); i++) {
csv.print(modeOrder.get(i));
DoubleList row = confusionMatrix.get(i);
double sum = row.doubleStream().sum();
for (int j = 0; j < row.size(); j++) {
csv.print(row.getDouble(j));
csv.print(row.getDouble(j) / sum);
}
csv.println();
}
Expand All @@ -234,28 +258,22 @@ public void writeConfusionMatrix(Path path) throws IOException {

public void writeModePredictionError(Path path) throws IOException {

Object2DoubleMap<Pair> counts = new Object2DoubleOpenHashMap<>();
// inefficient, should not be used on large datasets
for (Entry e : data) {
counts.mergeDouble(new Pair(e.trueMode(), e.predMode()), e.weight(), Double::sum);
}

try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {
csv.printRecord("true_mode", "predicted_mode", "count");
for (String trueMode : modeOrder) {
for (String predMode : modeOrder) {
double c = counts.getDouble(new Pair(trueMode, predMode));
double c = pairs.getDouble(new Pair(trueMode, predMode));
if (c > 0)
csv.printRecord(trueMode, predMode, c);
}
}
}
}

private record Entry(String person, double weight, int n, String trueMode, String predMode) {
record Entry(String person, double weight, int n, String trueMode, String predMode) {
}

private record Pair(String trueMode, String predMode) {
record Pair(String trueMode, String predMode) {
}

/**
Expand All @@ -266,7 +284,7 @@ private record Pair(String trueMode, String predMode) {
* @param fn incorrectly predicted different class
* @param tn correctly predicated different class
*/
private record Counts(double tp, double fp, double fn, double tn, double total) {
record Counts(double tp, double fp, double fn, double tn, double total) {

}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package org.matsim.application.analysis.population;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.objects.Object2DoubleMap;
import it.unimi.dsi.fastutil.objects.Object2DoubleOpenHashMap;
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

class TripChoiceAnalysisTest {

/**
* Create confusion matrix.
*/
public static List<DoubleList> cm(String... entries) {
List<DoubleList> rows = new ArrayList<>();
List<String> distinct = Arrays.stream(entries).distinct().toList();
Object2DoubleMap<TripChoiceAnalysis.Pair> pairs = new Object2DoubleOpenHashMap<>();

for (int i = 0; i < entries.length; i += 2) {
TripChoiceAnalysis.Pair pair = new TripChoiceAnalysis.Pair(entries[i], entries[i + 1]);
pairs.mergeDouble(pair, 1, Double::sum);
}

for (String d1 : distinct) {
DoubleArrayList row = new DoubleArrayList();
for (String d2 : distinct) {
row.add(pairs.getDouble(new TripChoiceAnalysis.Pair(d1, d2)));
}
rows.add(row);
}

return rows;
}

@Test
void cohenKappa() {

double ck = TripChoiceAnalysis.computeCohenKappa(List.of());
assertThat(ck).isEqualTo(1);

ck = TripChoiceAnalysis.computeCohenKappa(cm(
"a", "a",
"b", "b",
"b", "b",
"c", "c")
);

assertThat(ck).isEqualTo(1.0);
ck = TripChoiceAnalysis.computeCohenKappa(cm(
"a", "c",
"d", "e",
"a", "b",
"b", "d"
));

assertThat(ck).isLessThan(0.0);

// These have been verified with sklearn
ck = TripChoiceAnalysis.computeCohenKappa(cm(
"negative", "negative",
"positive", "positive",
"negative", "negative",
"neutral", "neutral",
"positive", "negative"
));

assertThat(ck).isEqualTo(0.6875);

ck = TripChoiceAnalysis.computeCohenKappa(cm(
"negative", "positive",
"positive", "neutral",
"negative", "negative",
"neutral", "neutral",
"positive", "negative"
));

assertThat(ck).isEqualTo( 0.11764705882352955, Offset.offset(1e-5));

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -308,20 +308,25 @@ public void configure(Header header, Layout layout) {
private void createChoiceTab(Layout layout, String[] args) {

layout.row("choice-intro", "Mode Choice").el(TextBlock.class, (viz, data) -> {
viz.title = "Introduction";
viz.title = "Information";
viz.content = """
Information regarding the metrics used:
Precision is the fraction of instances correctly classified as belonging to a specific class out of all instances the model predicted to belong to that class.
Recall is the fraction of instances in a class that the model correctly classified out of all instances in that class.
The macro-average computes the metric independently for each class and then take the average (hence treating all classes equally).
The micro-average will aggregate the contributions of all classes to compute the average metric.""";
Note that these metrics are based on a single run and may have limited interpretability. For a more robust evaluation, consider running multiple simulations with different seeds and use metrics that consider probabilities as well.
(log-likelihood, Brier score, etc.)
For policy cases, these metrics do not have any meaning. They are solely for the base-case.
- Precision is the fraction of instances correctly classified as belonging to a specific class out of all instances the model predicted to belong to that class.
- Recall is the fraction of instances in a class that the model correctly classified out of all instances in that class.
- The macro-average computes the metric independently for each class and then take the average (hence treating all classes equally).
- The micro-average will aggregate the contributions of all classes to compute the average metric.
- Cohen's Kappa is a measure of agreement between two raters that corrects for chance agreement. 1.0 indicates perfect agreement, 0.0 or less indicates agreement by chance.
""";
});

layout.row("choice", "Mode Choice").el(Table.class, (viz, data) -> {
viz.title = "Choice Evaluation";
viz.description = "Metrics for mode choice.";
viz.showAllRows = true;
viz.height = 6d;
viz.dataset = data.compute(TripAnalysis.class, "mode_choice_evaluation.csv", args);
});

Expand All @@ -334,7 +339,7 @@ private void createChoiceTab(Layout layout, String[] args) {

layout.row("choice-plots", "Mode Choice").el(Heatmap.class, (viz, data) -> {
viz.title = "Confusion Matrix";
viz.description = "Confusion matrix for mode choice.";
viz.description = "Share of (mis)classified modes.";
viz.xAxisTitle = "Predicted";
viz.yAxisTitle = "True";
viz.dataset = data.compute(TripAnalysis.class, "mode_confusion_matrix.csv", args);
Expand Down Expand Up @@ -371,7 +376,6 @@ private void createGroupedTab(Layout layout, String[] args) {
viz.description = "by " + cat;
viz.height = 6d;
viz.layout = tech.tablesaw.plotly.components.Layout.builder()
.xAxis(Axis.builder().title("share").build())
.barMode(tech.tablesaw.plotly.components.Layout.BarMode.STACK)
.build();

Expand Down Expand Up @@ -410,7 +414,8 @@ private void createGroupedTab(Layout layout, String[] args) {
.rename("sim_share", "Sim")
.rename("ref_share", "Ref")
.mapping()
.name(cat)
.name("main_mode")
.facetCol(cat)
.x("dist_group")
.y("share");

Expand All @@ -420,7 +425,6 @@ private void createGroupedTab(Layout layout, String[] args) {
.orientation(BarTrace.Orientation.VERTICAL)
.build(), ds);


});

}
Expand Down

0 comments on commit e6000b6

Please sign in to comment.