Skip to content

Commit

Permalink
added tests and improved naming
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Mar 16, 2024
1 parent ce5bf59 commit 9b8f1a9
Show file tree
Hide file tree
Showing 10 changed files with 1,879 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.matsim.core.network.NetworkUtils;
import org.matsim.core.network.algorithms.MultimodalNetworkCleaner;
import org.matsim.core.network.filter.NetworkFilterManager;
import org.matsim.core.router.FastDijkstraFactory;
import org.matsim.core.router.DijkstraFactory;
import org.matsim.core.router.costcalculators.OnlyTimeDependentTravelDisutility;
import org.matsim.core.router.util.LeastCostPathCalculator;
import org.matsim.core.router.util.TravelTime;
Expand Down Expand Up @@ -293,7 +293,7 @@ private Network createCityNetwork(Network network) {
private LeastCostPathCalculator createRandomizedRouter(Network network, TravelTime tt) {

OnlyTimeDependentTravelDisutility util = new OnlyTimeDependentTravelDisutility(tt);
return new FastDijkstraFactory(false).createPathCalculator(network, util, tt);
return new DijkstraFactory(false).createPathCalculator(network, util, tt);
}

private static final class RandomizedTravelTime implements TravelTime {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class ApplyNetworkParams implements MATSimAppCommand {
private final OutputOptions output = OutputOptions.ofCommand(ApplyNetworkParams.class);

@CommandLine.Parameters(arity = "1..*", description = "Type of parameters to apply. Available: ${COMPLETION-CANDIDATES}")
private Set<Parameter> params;
private Set<NetworkAttribute> params;

@CommandLine.Option(names = "--input-params", description = "Path to parameter json")
private String inputParams;
Expand All @@ -53,7 +53,7 @@ public class ApplyNetworkParams implements MATSimAppCommand {
private double[] speedFactorBounds;

private NetworkModel model;
private NetworkParamsOpt.Request paramsOpt;
private NetworkParams paramsOpt;

private int warn = 0;

Expand Down Expand Up @@ -88,8 +88,10 @@ public Integer call() throws Exception {
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setSerializationInclusion(JsonInclude.Include.NON_DEFAULT);

try (BufferedReader in = IOUtils.getBufferedReader(inputParams)) {
paramsOpt = mapper.readValue(in, NetworkParamsOpt.Request.class);
if (inputParams != null) {
try (BufferedReader in = IOUtils.getBufferedReader(inputParams)) {
paramsOpt = mapper.readValue(in, NetworkParams.class);
}
}

Map<Id<Link>, Feature> features = NetworkParamsOpt.readFeatures(input.getPath("features.csv"), network.getLinks().size());
Expand All @@ -115,7 +117,7 @@ private void applyChanges(Link link, String junctionType, Object2DoubleMap<Strin

boolean modified = false;

if (params.contains(Parameter.capacity)) {
if (params.contains(NetworkAttribute.capacity)) {

FeatureRegressor capacity = model.capacity(junctionType);

Expand All @@ -140,11 +142,10 @@ private void applyChanges(Link link, String junctionType, Object2DoubleMap<Strin
}

link.setCapacity(link.getNumberOfLanes() * perLane);

}


if (params.contains(Parameter.freespeed)) {
if (params.contains(NetworkAttribute.freespeed)) {

double speedFactor = 1.0;
FeatureRegressor speedModel = model.speedFactor(junctionType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public static void main(String[] args) {

static Result applyAndEvaluateParams(
Network network, NetworkModel model, Object2DoubleMap<SampleValidationRoutes.FromToNodes> validationSet, Map<Id<Link>, Feature> features,
double[] speedFactorBounds, Request request, String save) throws IOException {
double[] speedFactorBounds, NetworkParams request, String save) throws IOException {

Map<Id<Link>, double[]> attributes = new HashMap<>();

Expand Down Expand Up @@ -169,13 +169,13 @@ public Integer call() throws Exception {


log.info("Model score:");
Result r = applyAndEvaluateParams(network, model, validationSet, features, speedFactorBounds, new Request(0), save(getParamsName(null)));
Result r = applyAndEvaluateParams(network, model, validationSet, features, speedFactorBounds, new NetworkParams(0), save(getParamsName(null)));
writeResult(csv, null, r);

if (params != null) {
log.info("Model with parameter score:");
r = applyAndEvaluateParams(network, model, validationSet, features, speedFactorBounds,
mapper.readValue(params.toFile(), Request.class), save(getParamsName(params)));
mapper.readValue(params.toFile(), NetworkParams.class), save(getParamsName(params)));
writeResult(csv, params, r);
}

Expand Down Expand Up @@ -210,12 +210,12 @@ private void evalSpeedFactors(Path eval, String save) throws IOException {

String networkName = FilenameUtils.getName(input.getNetworkPath());

Request best = null;
NetworkParams best = null;
double bestScore = Double.POSITIVE_INFINITY;
double[] bounds = {0, 1};

for (int i = 25; i <= 100; i++) {
Request req = new Request(i / 100d);
NetworkParams req = new NetworkParams(i / 100d);
Result res = applyAndEvaluateParams(network, model, validationSet, features, bounds, req, null);
csv.printRecord(networkName, i / 100d, res.mae(), res.rmse());
if (best == null || res.mae() < bestScore) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.matsim.application.analysis.traffic.traveltime.SampleValidationRoutes;
import org.matsim.application.options.InputOptions;
import org.matsim.application.prepare.network.opt.NetworkParamsOpt.Feature;
import org.matsim.application.prepare.network.opt.NetworkParamsOpt.Request;
import org.matsim.application.prepare.network.opt.NetworkParamsOpt.Result;
import picocli.CommandLine;

Expand Down Expand Up @@ -122,7 +121,7 @@ public Integer call() throws Exception {
return 0;
}

private Result applyAndEvaluateParams(Request request, String save) throws IOException {
private Result applyAndEvaluateParams(NetworkParams request, String save) throws IOException {
return EvalFreespeedParams.applyAndEvaluateParams(network, model, validationSet, features, speedFactorBounds,
request, save);
}
Expand All @@ -132,7 +131,7 @@ private final class Backend implements HttpRequestHandler, ExceptionListener {
@Override
public void handle(ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) throws IOException {

Request req = mapper.readValue(request.getEntity().getContent(), Request.class);
NetworkParams req = mapper.readValue(request.getEntity().getContent(), NetworkParams.class);

Result stats = applyAndEvaluateParams(req, null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
/**
* Enum of network parameters.
*/
public enum Parameter {
public enum NetworkAttribute {
freespeed,
capacity
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.matsim.application.prepare.network.opt;

import com.fasterxml.jackson.annotation.JsonAnyGetter;
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonIgnore;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Object containing parameters for a model. Can be used to serialize and deserialize parameters.
*/
final class NetworkParams {

double f;

@JsonIgnore
Map<String, double[]> params = new HashMap<>();

/**
* Used by jackson
*/
public NetworkParams() {
}

public NetworkParams(double f) {
this.f = f;
}

@JsonAnyGetter
public double[] getParams(String type) {
return params.get(type);
}

@JsonAnySetter
public void setParams(String type, double[] params) {
this.params.put(type, params);
}

public boolean hasParams() {
return !params.isEmpty();
}

@Override
public String toString() {
if (f == 0)
return "Request{" + params.entrySet().stream()
.map(e -> e.getKey() + "=" + e.getValue().length).collect(Collectors.joining(",")) + '}';

return "Request{f=" + f + "}";
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package org.matsim.application.prepare.network.opt;

import com.fasterxml.jackson.annotation.JsonAnyGetter;
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.annotation.JsonIgnore;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2DoubleMap;
Expand All @@ -18,7 +15,7 @@
import org.matsim.api.core.v01.network.Network;
import org.matsim.api.core.v01.network.Node;
import org.matsim.application.analysis.traffic.traveltime.SampleValidationRoutes;
import org.matsim.core.router.FastDijkstraFactory;
import org.matsim.core.router.DijkstraFactory;
import org.matsim.core.router.costcalculators.OnlyTimeDependentTravelDisutility;
import org.matsim.core.router.util.LeastCostPathCalculator;
import org.matsim.core.trafficmonitoring.FreeSpeedTravelTime;
Expand All @@ -33,7 +30,6 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Private helper class with utility functions.
Expand Down Expand Up @@ -120,7 +116,7 @@ static Object2DoubleMap<SampleValidationRoutes.FromToNodes> readValidation(List<
static Result evaluate(Network network, Object2DoubleMap<SampleValidationRoutes.FromToNodes> validationSet, Map<Id<Link>, Feature> features, Map<Id<Link>, double[]> attributes, String save) throws IOException {
FreeSpeedTravelTime tt = new FreeSpeedTravelTime();
OnlyTimeDependentTravelDisutility util = new OnlyTimeDependentTravelDisutility(tt);
LeastCostPathCalculator router = new FastDijkstraFactory(false).createPathCalculator(network, util, tt);
LeastCostPathCalculator router = new DijkstraFactory(false).createPathCalculator(network, util, tt);

SummaryStatistics rmse = new SummaryStatistics();
SummaryStatistics mse = new SummaryStatistics();
Expand Down Expand Up @@ -174,50 +170,6 @@ static Result evaluate(Network network, Object2DoubleMap<SampleValidationRoutes.
return new Result(rmse.getMean(), mse.getMean(), data);
}

/**
* JSON request containing desired parameters.
*/
static final class Request {

double f;

@JsonIgnore
Map<String, double[]> params = new HashMap<>();

/**
* Used by jackson
*/
public Request() {
}

public Request(double f) {
this.f = f;
}

@JsonAnyGetter
public double[] getParams(String type) {
return params.get(type);
}

@JsonAnySetter
public void setParams(String type, double[] params) {
this.params.put(type, params);
}

public boolean hasParams() {
return !params.isEmpty();
}

@Override
public String toString() {
if (f == 0)
return "Request{" + params.entrySet().stream()
.map(e -> e.getKey() + "=" + e.getValue().length).collect(Collectors.joining(",")) + '}';

return "Request{f=" + f + "}";
}
}

record Feature(String junctionType, Object2DoubleMap<String> features) {
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.matsim.application.prepare.network.opt;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.matsim.contrib.sumo.SumoNetworkConverter;
import org.matsim.testcases.MatsimTestUtils;

import java.nio.file.Path;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

class ApplyNetworkParamsTest {

@RegisterExtension
MatsimTestUtils utils = new MatsimTestUtils();

@Test
void apply() throws Exception {

Path networkPath = Path.of(utils.getPackageInputDirectory()).resolve("osm.net.xml");

Path output = Path.of(utils.getOutputDirectory());

SumoNetworkConverter converter = SumoNetworkConverter.newInstance(List.of(networkPath),
output.resolve("network.xml"),
"EPSG:4326", "EPSG:4326");

converter.call();

assertThat(output.resolve("network.xml")).exists();
assertThat(output.resolve("network-ft.csv")).exists();

new ApplyNetworkParams().execute(
"capacity", "freespeed",
"--network", output.resolve("network.xml").toString(),
"--input-features", output.resolve("network-ft.csv").toString(),
"--output", output.resolve("network-opt.xml").toString(),
"--model", "org.matsim.application.prepare.network.opt.ref.GermanyNetworkParams"
);

assertThat(output.resolve("network-opt.xml")).exists();

}
}
Loading

0 comments on commit 9b8f1a9

Please sign in to comment.