-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
calibration cli and script to run multiple simulations at once
- Loading branch information
Showing
5 changed files
with
174 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,29 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import optuna | ||
import traceback | ||
from . import run_create_csv | ||
from . import run_simulations | ||
|
||
from . import study_as_df | ||
def _add(subparsers, m): | ||
""" Adds module to as subcommand""" | ||
s1 = subparsers.add_parser(m.METADATA[0], help=m.METADATA[1]) | ||
m.setup(s1) | ||
s1.set_defaults(func=m.main) | ||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(prog="matsim-calibration", description="Calibration CLI") | ||
parser.add_argument('file', nargs=1, type=str, help="Path to input db") | ||
parser.add_argument("--name", type=str, default="calib", help="Calibration name") | ||
parser.add_argument("--output", default=None, help="Output path") | ||
args = parser.parse_args() | ||
|
||
study = optuna.load_study( | ||
study_name=args.name, | ||
storage="sqlite:///%s" % args.file[0], | ||
) | ||
parser = argparse.ArgumentParser(prog="matsim-calibration", description="MATSim calibration command line utility") | ||
|
||
if not args.output: | ||
args.output = args.file[0] + ".csv" | ||
subparsers = parser.add_subparsers(title="Subcommands") | ||
|
||
df = study_as_df(study) | ||
df.to_csv(args.output, index=False) | ||
_add(subparsers, run_create_csv) | ||
_add(subparsers, run_simulations) | ||
|
||
try: | ||
from .plot import plot_study | ||
plot_study(study) | ||
args = parser.parse_args() | ||
|
||
except ImportError: | ||
print("Could not plot study.") | ||
traceback.print_exc() | ||
if not hasattr(args, 'func'): | ||
parser.print_help() | ||
else: | ||
args.func(args) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import argparse | ||
|
||
METADATA = "create-csv", "Create plots and csv from calibration study." | ||
|
||
def setup(parser: argparse.ArgumentParser): | ||
parser.add_argument('file', nargs=1, type=str, help="Path to input db") | ||
parser.add_argument("--name", type=str, default="calib", help="Calibration name") | ||
parser.add_argument("--output", default=None, help="Output path") | ||
|
||
def main(args): | ||
|
||
import optuna | ||
from . import study_as_df | ||
|
||
study = optuna.load_study( | ||
study_name=args.name, | ||
storage="sqlite:///%s" % args.file[0], | ||
) | ||
|
||
if not args.output: | ||
args.output = args.file[0] + ".csv" | ||
|
||
df = study_as_df(study) | ||
df.to_csv(args.output, index=False) | ||
|
||
try: | ||
from .plot import plot_study | ||
plot_study(study) | ||
|
||
except ImportError: | ||
print("Could not plot study.") | ||
traceback.print_exc() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import argparse | ||
import os | ||
import subprocess | ||
import sys | ||
from os import makedirs | ||
from time import sleep | ||
from typing import Union, Callable | ||
|
||
METADATA = "run-simulations", "Utility to run multiple simulations at once." | ||
|
||
|
||
def process_results(directory): | ||
"""Process results of multiple simulations""" | ||
|
||
print("Processing results in %s" % directory) | ||
|
||
|
||
def run(jar: Union[str, os.PathLike], | ||
config: Union[str, os.PathLike], | ||
args: Union[str, Callable] = "", | ||
jvm_args="", | ||
runs: int = 10, | ||
worker_id: int = 0, | ||
workers: int = 1, | ||
seed: int = 4711, | ||
overwrite: bool = False, | ||
custom_cli: Callable = None, | ||
debug: bool = False): | ||
"""Run multiple simulations using different seeds at once. Simulations will be performed sequentially. | ||
For parallel execution, multiple workers must be started. | ||
:param jar: path to executable jar file of the scenario | ||
:param config: path to config file to run | ||
:param args: arguments to pass to the simulation | ||
:param jvm_args: arguments to pass to the JVM | ||
:param runs: number of simulations to run | ||
:param worker_id: id of this process | ||
:param workers: total number of processes | ||
:param seed: starting seed | ||
:param overwrite: overwrite existing output directory | ||
:param custom_cli: use custom command line interface | ||
:param debug: if true, output will be printed to console | ||
""" | ||
if not os.access(jar, os.R_OK): | ||
raise ValueError("Can not access JAR File: %s" % jar) | ||
|
||
if not os.access(config, os.R_OK): | ||
raise ValueError("Can not access config File: %s" % config) | ||
|
||
if not os.path.exists("eval-runs"): | ||
makedirs("eval-runs") | ||
|
||
for i in range(runs): | ||
if i % workers != worker_id: | ||
continue | ||
|
||
run_dir = "eval-runs/%03d" % i | ||
|
||
if os.path.exists(run_dir) and not overwrite: | ||
print("Run %s already exists, skipping." % run_dir) | ||
continue | ||
|
||
run_args = args(completed) if callable(args) else args | ||
|
||
# Same custom cli interface as calibration | ||
if custom_cli: | ||
cmd = custom_cli(jvm_args, jar, config, params_path, run_dir, trial.number, run_args) | ||
else: | ||
cmd = "java %s -jar %s run --config %s --output %s --runId %03d --config:global.randomSeed=%d %s" \ | ||
% (jvm_args, jar, config, run_dir, i, seed + i, run_args) | ||
|
||
# Extra whitespaces will break argument parsing | ||
cmd = cmd.strip() | ||
|
||
print("Running cmd %s" % cmd) | ||
|
||
if os.name != 'nt': | ||
cmd = cmd.split(" ") | ||
cmd = [c for c in cmd if c != ""] | ||
|
||
p = subprocess.Popen(cmd, | ||
stdout=sys.stdout if debug else subprocess.DEVNULL, | ||
stderr=sys.stderr if debug else subprocess.DEVNULL) | ||
|
||
try: | ||
while p.poll() is None: | ||
sleep(1) | ||
|
||
if p.returncode != 0: | ||
print("The scenario could not be run properly and returned with an error code.", file=sys.stderr) | ||
if not debug: | ||
print("Set debug=True and check the output for any errors.", file=sys.stderr) | ||
print("Alternatively run the cmd from the log above manually and check for errors.", | ||
file=sys.stderr) | ||
|
||
raise Exception("Process returned with error code: %s." % p.returncode) | ||
finally: | ||
p.terminate() | ||
|
||
process_results("eval_runs") | ||
|
||
|
||
def setup(parser: argparse.ArgumentParser): | ||
parser.add_argument("--jar", type=str, required=True, help="Path to executable JAR file") | ||
parser.add_argument("--config", type=str, required=True, help="Path to config file") | ||
parser.add_argument("--args", type=str, default="", help="Arguments to pass to the simulation") | ||
parser.add_argument("--jvm-args", type=str, default="", help="Arguments to pass to the JVM") | ||
parser.add_argument("--runs", type=int, default=10, help="Number of simulations to run") | ||
parser.add_argument("--worker-id", type=int, default=0, help="ID of this worker") | ||
parser.add_argument("--workers", type=int, default=1, help="Total number of workers") | ||
parser.add_argument("--seed", type=int, default=4711, help="Starting seed") | ||
parser.add_argument("--overwrite", action="store_true", help="Overwrite existing output directories") | ||
parser.add_argument("--debug", action="store_true", help="Print output to console") | ||
|
||
|
||
def main(args): | ||
run(args.jar, args.config, args.args, args.jvm_args, args.runs, args.worker_id, args.workers, args.seed, | ||
args.overwrite, debug=args.debug) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters