From d0201cb087dc01179c8d865973ad0e39ecb4d83c Mon Sep 17 00:00:00 2001 From: rakow Date: Fri, 22 Mar 2024 19:36:20 +0100 Subject: [PATCH] improve category matching in scoring --- .../scoring/AdvancedScoringConfigGroup.java | 15 +- .../java/org/matsim/run/scoring/Category.java | 133 ++++++++++++++++++ .../IndividualPersonScoringParameters.java | 30 +++- src/main/python/estimate_mixed_plan_choice.py | 8 +- 4 files changed, 167 insertions(+), 19 deletions(-) create mode 100644 src/main/java/org/matsim/run/scoring/Category.java diff --git a/src/main/java/org/matsim/run/scoring/AdvancedScoringConfigGroup.java b/src/main/java/org/matsim/run/scoring/AdvancedScoringConfigGroup.java index f4297724..f943a7aa 100644 --- a/src/main/java/org/matsim/run/scoring/AdvancedScoringConfigGroup.java +++ b/src/main/java/org/matsim/run/scoring/AdvancedScoringConfigGroup.java @@ -72,8 +72,6 @@ public static final class ScoringParameters extends ReflectiveConfigGroup { */ private final Map modeParams = new HashMap<>(); - // TODO: option to match as list - public ScoringParameters() { super(GROUP_NAME, true); } @@ -81,23 +79,16 @@ public ScoringParameters() { /** * Checks if the given attributes match the config. If true these parameters are applicable to tbe object. */ - public boolean matchObject(Attributes attr) { - - // TODO: special case int <-> double and numbers - // boolean values - // allow lists - - // TODO: matching is not yet correct - // TODO: add test + public boolean matchObject(Attributes attr, Map categories) { for (Map.Entry e : this.getParams().entrySet()) { // might be null if not defined Object objValue = attr.getAttribute(e.getKey()); + String category = categories.get(e.getKey()).categorize(objValue); // compare as string - if (!Objects.toString(objValue).equals(e.getValue())) + if (!Objects.toString(category).equals(e.getValue())) return false; - } return true; diff --git a/src/main/java/org/matsim/run/scoring/Category.java b/src/main/java/org/matsim/run/scoring/Category.java new file mode 100644 index 00000000..ac0af438 --- /dev/null +++ b/src/main/java/org/matsim/run/scoring/Category.java @@ -0,0 +1,133 @@ +package org.matsim.run.scoring; + +import java.util.*; + +/** + * Categorize values into groups. + */ +public final class Category { + + private static final Set TRUE = Set.of("true", "yes", "1", "on", "y", "j", "ja"); + private static final Set FALSE = Set.of("false", "no", "0", "off", "n", "nein"); + + /** + * Unique values of the category. + */ + private final Set values; + + /** + * Groups of values that have been subsumed under a single category. + * These are values separated by , + */ + private final Map grouped; + + /** + * Range categories. + */ + private final List ranges; + + public Category(Set values) { + this.values = values; + this.grouped = new HashMap<>(); + for (String v : values) { + if (v.contains(",")) { + String[] grouped = v.split(","); + for (String g : grouped) { + this.grouped.put(g, v); + } + } + } + + boolean range = this.values.stream().allMatch(v -> v.contains("-") || v.contains("+")); + if (range) { + ranges = new ArrayList<>(); + for (String value : this.values) { + if (value.contains("-")) { + String[] parts = value.split("-"); + ranges.add(new Range(Double.parseDouble(parts[0]), Double.parseDouble(parts[1]), value)); + } else if (value.contains("+")) { + ranges.add(new Range(Double.parseDouble(value.replace("+", "")), Double.POSITIVE_INFINITY, value)); + } + } + + ranges.sort(Comparator.comparingDouble(r -> r.left)); + } else + ranges = null; + + + // Check if all values are boolean + if (values.stream().allMatch(v -> TRUE.contains(v.toLowerCase()) || FALSE.contains(v.toLowerCase()))) { + for (String value : values) { + Set group = TRUE.contains(value.toLowerCase()) ? TRUE : FALSE; + for (String g : group) { + this.grouped.put(g, value); + } + } + } + } + + /** + * Categorize a single value. + */ + public String categorize(Object value) { + + if (value == null) + return null; + + if (value instanceof Boolean) { + // Booleans and synonyms are in the group map + return categorize(((Boolean) value).toString().toLowerCase()); + } else if (value instanceof Number) { + return categorizeNumber((Number) value); + } else { + String v = value.toString(); + if (values.contains(v)) + return v; + else if (grouped.containsKey(v)) + return grouped.get(v); + + try { + double d = Double.parseDouble(v); + return categorizeNumber(d); + } catch (NumberFormatException e) { + return null; + } + } + } + + private String categorizeNumber(Number value) { + + if (ranges != null) { + for (Range r : ranges) { + if (value.doubleValue() >= r.left && value.doubleValue() < r.right) + return r.label; + } + } + + // Match string representation + String v = value.toString(); + if (values.contains(v)) + return v; + else if (grouped.containsKey(v)) + return grouped.get(v); + + + // Convert the number to a whole number, which will have a different string representation + if (value instanceof Float || value instanceof Double) { + return categorizeNumber(value.longValue()); + } + + return null; + } + + /** + * @param left Left bound of the range. + * @param right Right bound of the range. (exclusive) + * @param label Label of this group. + */ + private record Range(double left, double right, String label) { + + + } + +} diff --git a/src/main/java/org/matsim/run/scoring/IndividualPersonScoringParameters.java b/src/main/java/org/matsim/run/scoring/IndividualPersonScoringParameters.java index c5b7c36f..68b54f40 100644 --- a/src/main/java/org/matsim/run/scoring/IndividualPersonScoringParameters.java +++ b/src/main/java/org/matsim/run/scoring/IndividualPersonScoringParameters.java @@ -55,10 +55,16 @@ public class IndividualPersonScoringParameters implements ScoringParametersForPe * Cache and reuse distance group arrays. */ private final Map distGroups = new ConcurrentHashMap<>(); + + /** + * Categories from config group. + */ + private final Map categories; + /** * Thread-local random number generator. */ - private final ThreadLocal rnd = ThreadLocal.withInitial(() -> new Context()); + private final ThreadLocal rnd = ThreadLocal.withInitial(Context::new); private final Scenario scenario; private final ScoringConfigGroup basicScoring; private final TransitConfigGroup transitConfig; @@ -75,9 +81,24 @@ public IndividualPersonScoringParameters(Scenario scenario) { this.scoring = ConfigUtils.addOrGetModule(scenario.getConfig(), AdvancedScoringConfigGroup.class); this.transitConfig = scenario.getConfig().transit(); this.globalAvgIncome = computeAvgIncome(scenario.getPopulation()); + this.categories = buildCategories(this.scoring); this.cache = new IdMap<>(Person.class, scenario.getPopulation().getPersons().size()); } + static Map buildCategories(AdvancedScoringConfigGroup scoring) { + + Map> categories = new HashMap<>(); + + // Collect all values + for (AdvancedScoringConfigGroup.ScoringParameters parameter : scoring.getScoringParameters()) { + for (Map.Entry kv : parameter.getParams().entrySet()) { + categories.computeIfAbsent(kv.getKey(), k -> new HashSet<>()).add(kv.getValue()); + } + } + + return categories.entrySet().stream() + .collect(HashMap::new, (m, e) -> m.put(e.getKey(), new Category(e.getValue())), HashMap::putAll); + } static DistanceGroup[] calcDistanceGroups(List dists, DoubleList distUtils) { @@ -202,7 +223,7 @@ public ScoringParameters getScoringParameters(Person person) { for (AdvancedScoringConfigGroup.ScoringParameters parameter : scoring.getScoringParameters()) { - if (parameter.matchObject(person.getAttributes())) { + if (parameter.matchObject(person.getAttributes(), categories)) { for (Map.Entry mode : parameter.getModeParams().entrySet()) { DistanceGroupModeUtilityParameters.DeltaBuilder b = @@ -227,11 +248,12 @@ public ScoringParameters getScoringParameters(Person person) { // Collect final adjustments information Object2DoubleMap values = info.computeIfAbsent(person.getId(), k -> new Object2DoubleOpenHashMap<>()); + // Write the overall constants, but only if they are different to the base values if (delta.constant != 0) - values.put(mode.getKey() + "_constant", delta.constant); + values.put(mode.getKey() + "_constant", p.constant); if (delta.dailyUtilityConstant != 0) - values.put(mode.getKey() + "_dailyConstant", delta.dailyUtilityConstant); + values.put(mode.getKey() + "_dailyConstant", p.dailyUtilityConstant); if (groups != null) { for (DistanceGroup group : groups) { diff --git a/src/main/python/estimate_mixed_plan_choice.py b/src/main/python/estimate_mixed_plan_choice.py index 3f6e4343..dbd643bd 100644 --- a/src/main/python/estimate_mixed_plan_choice.py +++ b/src/main/python/estimate_mixed_plan_choice.py @@ -11,7 +11,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Estimate the plan choice mixed logit model") - parser.add_argument("--input", help="Path to the input file", type=str, default="../../../plan-choices.csv") + parser.add_argument("--input", help="Path to the input file", type=str, default="../../../plan-choices-random.csv") parser.add_argument("--n-draws", help="Number of draws for the estimation", type=int, default=1500) parser.add_argument("--batch-size", help="Batch size for the estimation", type=int, default=None) parser.add_argument("--sample", help="Use sample of choice data", type=float, default=0.2) @@ -59,6 +59,7 @@ # ASC is present as mode_usage varnames = [f"{mode}_usage" for mode in modes if mode != "walk" and mode != "car"] + ["car_used"] # varnames += ["pt_ride_hours", "car_ride_hours", "bike_ride_hours"] + # varnames = ["car_used", "car_usage"] # Additive costs addit = df["costs"] + df["car_fixed_cost"] - df["pt_n_switches"] @@ -70,11 +71,12 @@ addit=addit, # randvars={"car_used": "tn"}, randvars={"car_used": "tn", "bike_usage": "n", "pt_usage": "n", "ride_usage": "n"}, - n_draws=args.n_draws, batch_size=args.batch_size, + fixedvars={"car_used": None}, + n_draws=args.n_draws, batch_size=args.batch_size, halton=True, skip_std_errs=True, optim_method='L-BFGS-B') else: - varnames += ["car_usage"] + #varnames += ["car_usage"] model = MultinomialLogit() model.fit(X=df[varnames], y=df['choice'], varnames=varnames,