Skip to content

Commit

Permalink
improve category matching in scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Mar 22, 2024
1 parent 88caf0c commit d0201cb
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,23 @@ public static final class ScoringParameters extends ReflectiveConfigGroup {
*/
private final Map<String, ModeParams> modeParams = new HashMap<>();

// TODO: option to match as list

public ScoringParameters() {
super(GROUP_NAME, true);
}

/**
* Checks if the given attributes match the config. If true these parameters are applicable to tbe object.
*/
public boolean matchObject(Attributes attr) {

// TODO: special case int <-> double and numbers
// boolean values
// allow lists

// TODO: matching is not yet correct
// TODO: add test
public boolean matchObject(Attributes attr, Map<String, Category> categories) {

for (Map.Entry<String, String> e : this.getParams().entrySet()) {
// might be null if not defined
Object objValue = attr.getAttribute(e.getKey());
String category = categories.get(e.getKey()).categorize(objValue);

// compare as string
if (!Objects.toString(objValue).equals(e.getValue()))
if (!Objects.toString(category).equals(e.getValue()))
return false;

}

return true;
Expand Down
133 changes: 133 additions & 0 deletions src/main/java/org/matsim/run/scoring/Category.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package org.matsim.run.scoring;

import java.util.*;

/**
* Categorize values into groups.
*/
public final class Category {

private static final Set<String> TRUE = Set.of("true", "yes", "1", "on", "y", "j", "ja");
private static final Set<String> FALSE = Set.of("false", "no", "0", "off", "n", "nein");

/**
* Unique values of the category.
*/
private final Set<String> values;

/**
* Groups of values that have been subsumed under a single category.
* These are values separated by ,
*/
private final Map<String, String> grouped;

/**
* Range categories.
*/
private final List<Range> ranges;

public Category(Set<String> values) {
this.values = values;
this.grouped = new HashMap<>();
for (String v : values) {
if (v.contains(",")) {
String[] grouped = v.split(",");
for (String g : grouped) {
this.grouped.put(g, v);
}
}
}

boolean range = this.values.stream().allMatch(v -> v.contains("-") || v.contains("+"));
if (range) {
ranges = new ArrayList<>();
for (String value : this.values) {
if (value.contains("-")) {
String[] parts = value.split("-");
ranges.add(new Range(Double.parseDouble(parts[0]), Double.parseDouble(parts[1]), value));
} else if (value.contains("+")) {
ranges.add(new Range(Double.parseDouble(value.replace("+", "")), Double.POSITIVE_INFINITY, value));
}
}

ranges.sort(Comparator.comparingDouble(r -> r.left));
} else
ranges = null;


// Check if all values are boolean
if (values.stream().allMatch(v -> TRUE.contains(v.toLowerCase()) || FALSE.contains(v.toLowerCase()))) {
for (String value : values) {
Set<String> group = TRUE.contains(value.toLowerCase()) ? TRUE : FALSE;
for (String g : group) {
this.grouped.put(g, value);
}
}
}
}

/**
* Categorize a single value.
*/
public String categorize(Object value) {

if (value == null)
return null;

if (value instanceof Boolean) {
// Booleans and synonyms are in the group map
return categorize(((Boolean) value).toString().toLowerCase());
} else if (value instanceof Number) {
return categorizeNumber((Number) value);
} else {
String v = value.toString();
if (values.contains(v))
return v;
else if (grouped.containsKey(v))
return grouped.get(v);

try {
double d = Double.parseDouble(v);
return categorizeNumber(d);
} catch (NumberFormatException e) {
return null;
}
}
}

private String categorizeNumber(Number value) {

if (ranges != null) {
for (Range r : ranges) {
if (value.doubleValue() >= r.left && value.doubleValue() < r.right)
return r.label;
}
}

// Match string representation
String v = value.toString();
if (values.contains(v))
return v;
else if (grouped.containsKey(v))
return grouped.get(v);


// Convert the number to a whole number, which will have a different string representation
if (value instanceof Float || value instanceof Double) {
return categorizeNumber(value.longValue());
}

return null;
}

/**
* @param left Left bound of the range.
* @param right Right bound of the range. (exclusive)
* @param label Label of this group.
*/
private record Range(double left, double right, String label) {


}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@ public class IndividualPersonScoringParameters implements ScoringParametersForPe
* Cache and reuse distance group arrays.
*/
private final Map<DoubleList, DistanceGroup[]> distGroups = new ConcurrentHashMap<>();

/**
* Categories from config group.
*/
private final Map<String, Category> categories;

/**
* Thread-local random number generator.
*/
private final ThreadLocal<Context> rnd = ThreadLocal.withInitial(() -> new Context());
private final ThreadLocal<Context> rnd = ThreadLocal.withInitial(Context::new);
private final Scenario scenario;
private final ScoringConfigGroup basicScoring;
private final TransitConfigGroup transitConfig;
Expand All @@ -75,9 +81,24 @@ public IndividualPersonScoringParameters(Scenario scenario) {
this.scoring = ConfigUtils.addOrGetModule(scenario.getConfig(), AdvancedScoringConfigGroup.class);
this.transitConfig = scenario.getConfig().transit();
this.globalAvgIncome = computeAvgIncome(scenario.getPopulation());
this.categories = buildCategories(this.scoring);
this.cache = new IdMap<>(Person.class, scenario.getPopulation().getPersons().size());
}

static Map<String, Category> buildCategories(AdvancedScoringConfigGroup scoring) {

Map<String, Set<String>> categories = new HashMap<>();

// Collect all values
for (AdvancedScoringConfigGroup.ScoringParameters parameter : scoring.getScoringParameters()) {
for (Map.Entry<String, String> kv : parameter.getParams().entrySet()) {
categories.computeIfAbsent(kv.getKey(), k -> new HashSet<>()).add(kv.getValue());
}
}

return categories.entrySet().stream()
.collect(HashMap::new, (m, e) -> m.put(e.getKey(), new Category(e.getValue())), HashMap::putAll);
}

static DistanceGroup[] calcDistanceGroups(List<Integer> dists, DoubleList distUtils) {

Expand Down Expand Up @@ -202,7 +223,7 @@ public ScoringParameters getScoringParameters(Person person) {

for (AdvancedScoringConfigGroup.ScoringParameters parameter : scoring.getScoringParameters()) {

if (parameter.matchObject(person.getAttributes())) {
if (parameter.matchObject(person.getAttributes(), categories)) {
for (Map.Entry<String, AdvancedScoringConfigGroup.ModeParams> mode : parameter.getModeParams().entrySet()) {

DistanceGroupModeUtilityParameters.DeltaBuilder b =
Expand All @@ -227,11 +248,12 @@ public ScoringParameters getScoringParameters(Person person) {
// Collect final adjustments information
Object2DoubleMap<String> values = info.computeIfAbsent(person.getId(), k -> new Object2DoubleOpenHashMap<>());

// Write the overall constants, but only if they are different to the base values
if (delta.constant != 0)
values.put(mode.getKey() + "_constant", delta.constant);
values.put(mode.getKey() + "_constant", p.constant);

if (delta.dailyUtilityConstant != 0)
values.put(mode.getKey() + "_dailyConstant", delta.dailyUtilityConstant);
values.put(mode.getKey() + "_dailyConstant", p.dailyUtilityConstant);

if (groups != null) {
for (DistanceGroup group : groups) {
Expand Down
8 changes: 5 additions & 3 deletions src/main/python/estimate_mixed_plan_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Estimate the plan choice mixed logit model")
parser.add_argument("--input", help="Path to the input file", type=str, default="../../../plan-choices.csv")
parser.add_argument("--input", help="Path to the input file", type=str, default="../../../plan-choices-random.csv")
parser.add_argument("--n-draws", help="Number of draws for the estimation", type=int, default=1500)
parser.add_argument("--batch-size", help="Batch size for the estimation", type=int, default=None)
parser.add_argument("--sample", help="Use sample of choice data", type=float, default=0.2)
Expand Down Expand Up @@ -59,6 +59,7 @@
# ASC is present as mode_usage
varnames = [f"{mode}_usage" for mode in modes if mode != "walk" and mode != "car"] + ["car_used"]
# varnames += ["pt_ride_hours", "car_ride_hours", "bike_ride_hours"]
# varnames = ["car_used", "car_usage"]

# Additive costs
addit = df["costs"] + df["car_fixed_cost"] - df["pt_n_switches"]
Expand All @@ -70,11 +71,12 @@
addit=addit,
# randvars={"car_used": "tn"},
randvars={"car_used": "tn", "bike_usage": "n", "pt_usage": "n", "ride_usage": "n"},
n_draws=args.n_draws, batch_size=args.batch_size,
fixedvars={"car_used": None},
n_draws=args.n_draws, batch_size=args.batch_size, halton=True, skip_std_errs=True,
optim_method='L-BFGS-B')

else:
varnames += ["car_usage"]
#varnames += ["car_usage"]

model = MultinomialLogit()
model.fit(X=df[varnames], y=df['choice'], varnames=varnames,
Expand Down

0 comments on commit d0201cb

Please sign in to comment.