Skip to content

Commit

Permalink
calculating some trip choice metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Jun 9, 2024
1 parent 67dd1ca commit e1874d9
Showing 1 changed file with 91 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.*;

/**
* Helper class to analyze trip choices from persons against reference data.
Expand All @@ -33,6 +32,11 @@ final class TripChoiceAnalysis {
*/
private final List<Entry> data = new ArrayList<>();

/**
* Contains predication result for each mode.
*/
private final Map<String, Counts> counts = new HashMap<>();

public TripChoiceAnalysis(Table persons, Table trips, List<String> modeOrder) {
persons = persons.where(persons.stringColumn("ref_modes").isNotEqualTo(""));
trips = new DataFrameJoiner(trips, "person").inner(persons);
Expand All @@ -57,6 +61,43 @@ public TripChoiceAnalysis(Table persons, Table trips, List<String> modeOrder) {
} else
log.warn("Person {} trip {} does not match ref data ({})", person, n, split.length);
}

for (String mode : modeOrder) {
counts.put(mode, countPredictions(mode, data));
}
}

private static double precision(Counts c) {
return c.tp / (c.tp + c.fp);
}

private static double recall(Counts c) {
return c.tp / (c.tp + c.fn);
}

private static double f1(Counts c) {
return 2 * c.tp / (2 * c.tp + c.fp + c.fn);
}

private Counts countPredictions(String mode, List<Entry> data) {
double tp = 0, fp = 0, fn = 0, tn = 0;
double total = 0;
for (Entry e : data) {
if (e.trueMode.equals(mode)) {
if (e.predMode.equals(mode))
tp += e.weight;
else
fn += e.weight;
} else {
if (e.predMode.equals(mode))
fp += e.weight;
else
tn += e.weight;
}
total += e.weight;
}

return new Counts(tp, fp, fn, tn, total);
}

/**
Expand All @@ -76,18 +117,32 @@ public void writeChoices(Path path) throws IOException {
*/
public void writeChoiceEvaluation(Path path) throws IOException {

try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {
double tp = 0;
double total = 0;
double tpfp = 0;
double tpfn = 0;
for (Counts c : counts.values()) {
tp += c.tp;
tpfp += c.tp + c.fp;
tpfn += c.tp + c.fn;
total = c.total;
}

csv.printRecord("Info", "Value");
OptionalDouble precision = counts.values().stream().mapToDouble(TripChoiceAnalysis::precision).average();
OptionalDouble recall = counts.values().stream().mapToDouble(TripChoiceAnalysis::recall).average();
OptionalDouble f1 = counts.values().stream().mapToDouble(TripChoiceAnalysis::f1).average();

csv.printRecord("Accuracy", "TODO");
csv.printRecord("F1 Score", "TODO");
// csv.printRecord("AUC-ROC", "");
csv.printRecord("Precision", "TODO");
csv.printRecord("Recall", "TODO");
try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {

// These can be micro and macro averaged
csv.printRecord("Info", "Value");

csv.printRecord("Accuracy", tp / total);
csv.printRecord("Precision (micro avg.)", tp / tpfp);
csv.printRecord("Precision (macro avg.)", precision.orElse(0));
csv.printRecord("Recall (micro avg.)", tp / tpfn);
csv.printRecord("Recall (macro avg.)", recall.orElse(0));
csv.printRecord("F1 Score (micro avg.)", 2 * tp / (tpfp + tpfn));
csv.printRecord("F1 Score (macro avg.)", f1.orElse(0));
}

// TODO Cohen’s Kappa, Cross-Entropy, Mathews Correlation Coefficient (MCC)
Expand All @@ -98,14 +153,23 @@ public void writeChoiceEvaluation(Path path) throws IOException {
*/
public void writeChoiceEvaluationPerMode(Path path) throws IOException {

// Precision in multi-class classification is the fraction of instances correctly classified as belonging to a specific class out of all instances the model predicted to belong to that class.

// Recall in multi-class classification is the fraction of instances in a class that the model correctly classified out of all instances in that class.

try (CSVPrinter csv = new CSVPrinter(Files.newBufferedWriter(path), CSVFormat.DEFAULT)) {

csv.printRecord("Accuracy", "TODO");
csv.printRecord("Precision", "TODO");
csv.printRecord("Recall", "TODO");
csv.printRecord("F1 Score", "TODO");
csv.printRecord("Mode", "Precision", "Recall", "F1 Score");
for (String m : modeOrder) {
csv.print(m);

Counts c = counts.get(m);

csv.print(precision(c));
csv.print(recall(c));
csv.print(f1(c));
csv.println();
}
}
}

Expand All @@ -115,4 +179,17 @@ public void writeChoiceEvaluationPerMode(Path path) throws IOException {

private record Entry(String person, double weight, int n, String trueMode, String predMode) {
}

/**
* Contains true positive, false positive, false negative and true negative counts.
*
* @param tp correctly predicted this class
* @param fp incorrectly predicted this class
* @param fn incorrectly predicted different class
* @param tn correctly predicated different class
*/
private record Counts(double tp, double fp, double fn, double tn, double total) {

}

}

0 comments on commit e1874d9

Please sign in to comment.