Skip to content

Commit

Permalink
get rid of command_line_params (#257)
Browse files Browse the repository at this point in the history
we don't need a separate RuntimeParameters method to parse a string,
instead we can just do this in pyro_sim.py and then use the existing
set_val method from a dict.  

This also allows us to simplify the confusing Pyro() interface.
  • Loading branch information
zingale authored Sep 6, 2024
1 parent 9c65825 commit 23cb4f4
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 120 deletions.
4 changes: 2 additions & 2 deletions presentations/pyro_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@
"from pyro import Pyro\n",
"pyro_sim = Pyro(\"advection\")\n",
"pyro_sim.initialize_problem(\"tophat\",\n",
" other_commands=[\"mesh.nx=8\", \"mesh.ny=8\",\n",
" \"vis.dovis=0\"])\n",
" inputs_dict={\"mesh.nx\": 8,\n",
" \"mesh.ny\": 8})\n",
"pyro_sim.run_sim()"
]
},
Expand Down
22 changes: 11 additions & 11 deletions pyro/pyro_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import pyro.util.io_pyro as io
import pyro.util.profile_pyro as profile
from pyro.util import compare, msg, runparams
from pyro.util import compare, msg
from pyro.util.runparams import RuntimeParameters, _get_val

valid_solvers = ["advection",
"advection_nonuniform",
Expand Down Expand Up @@ -71,16 +72,15 @@ def __init__(self, solver_name, *, from_commandline=False):
# -------------------------------------------------------------------------

# parameter defaults
self.rp = runparams.RuntimeParameters()
self.rp = RuntimeParameters()
self.rp.load_params(self.pyro_home + "_defaults")
self.rp.load_params(self.pyro_home + self.solver_name + "/_defaults")

self.tc = profile.TimerCollection()

self.is_initialized = False

def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None,
other_commands=None):
def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None):
"""
Initialize the specific problem
Expand All @@ -92,8 +92,6 @@ def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None
Filename containing problem's runtime parameters
inputs_dict : dict
Dictionary containing extra runtime parameters
other_commands : str
Other command line parameter options
"""
# pylint: disable=attribute-defined-outside-init

Expand Down Expand Up @@ -126,10 +124,6 @@ def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None
for k, v in inputs_dict.items():
self.rp.set_param(k, v)

# and any commandline overrides
if other_commands is not None:
self.rp.command_line_params(other_commands)

# write out the inputs.auto
self.rp.print_paramfile()

Expand Down Expand Up @@ -399,9 +393,15 @@ def main():
else:
pyro = Pyro(args.solver[0], from_commandline=True)

other = {}
for param_string in args.other:
k, v = param_string.split("=")
other[k] = _get_val(v)

print(other)
pyro.initialize_problem(problem_name=args.problem[0],
inputs_file=args.param[0],
other_commands=args.other)
inputs_dict=other)
pyro.run_sim()


Expand Down
62 changes: 31 additions & 31 deletions pyro/solver-test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
"solver = \"advection\"\n",
"problem_name = \"smooth\"\n",
"param_file = \"inputs.smooth\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=8\"]"
"params = {\"driver.max_steps\":1, \"mesh.nx\": 8, \"mesh.ny\": 8}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"tags": [
"nbval-ignore-output"
Expand Down Expand Up @@ -68,14 +68,14 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"scrolled": true
},
Expand Down Expand Up @@ -121,7 +121,7 @@
"solver = \"advection_nonuniform\"\n",
"problem_name = \"slotted\"\n",
"param_file = \"inputs.slotted\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=8\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 8, \"mesh.ny\": 8}"
]
},
{
Expand Down Expand Up @@ -162,7 +162,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -213,7 +213,7 @@
"solver = \"advection_fv4\"\n",
"problem_name = \"smooth\"\n",
"param_file = \"inputs.smooth\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=8\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 8, \"mesh.ny\": 8}"
]
},
{
Expand Down Expand Up @@ -247,7 +247,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -298,7 +298,7 @@
"solver = \"advection_rk\"\n",
"problem_name = \"tophat\"\n",
"param_file = \"inputs.tophat\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=8\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 8, \"mesh.ny\": 8}"
]
},
{
Expand Down Expand Up @@ -332,7 +332,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -376,19 +376,19 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"solver = \"compressible\"\n",
"problem_name = \"rt\"\n",
"param_file = \"inputs.rt\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=24\", \"driver.verbose=0\", \"compressible.riemann=CGF\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 8, \"mesh.ny\": 24, \"compressible.riemann\": \"CGF\"}"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"metadata": {
"tags": [
"nbval-ignore-output"
Expand Down Expand Up @@ -425,14 +425,14 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -485,19 +485,19 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"solver = \"compressible_fv4\"\n",
"problem_name = \"kh\"\n",
"param_file = \"inputs.kh\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=8\", \"driver.verbose=0\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 8, \"mesh.ny\": 8}"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 21,
"metadata": {
"tags": [
"nbval-ignore-output"
Expand Down Expand Up @@ -526,14 +526,14 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 22,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -570,14 +570,14 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"solver = \"compressible_rk\"\n",
"problem_name = \"quad\"\n",
"param_file = \"inputs.quad\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=16\", \"mesh.ny=16\", \"driver.verbose=0\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 16, \"mesh.ny\": 16}"
]
},
{
Expand Down Expand Up @@ -611,7 +611,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -670,7 +670,7 @@
"solver = \"compressible_sdc\"\n",
"problem_name = \"sod\"\n",
"param_file = \"inputs.sod.y\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=4\", \"mesh.ny=16\", \"driver.verbose=0\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 4, \"mesh.ny\": 16}"
]
},
{
Expand Down Expand Up @@ -715,7 +715,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -774,7 +774,7 @@
"solver = \"diffusion\"\n",
"problem_name = \"gaussian\"\n",
"param_file = \"inputs.gaussian\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=16\", \"mesh.ny=16\", \"driver.verbose=0\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 16, \"mesh.ny\": 16}"
]
},
{
Expand Down Expand Up @@ -818,7 +818,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -877,7 +877,7 @@
"solver = \"incompressible\"\n",
"problem_name = \"shear\"\n",
"param_file = \"inputs.shear\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=8\", \"mesh.ny=8\", \"driver.verbose=0\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 8, \"mesh.ny\": 8}"
]
},
{
Expand Down Expand Up @@ -921,7 +921,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down Expand Up @@ -972,7 +972,7 @@
"solver = \"lm_atm\"\n",
"problem_name = \"bubble\"\n",
"param_file = \"inputs.bubble\"\n",
"other_commands = [\"driver.max_steps=1\", \"mesh.nx=16\", \"mesh.ny=16\", \"driver.verbose=0\"]"
"params = {\"driver.max_steps\": 1, \"mesh.nx\": 16, \"mesh.ny\": 16}"
]
},
{
Expand Down Expand Up @@ -1006,7 +1006,7 @@
],
"source": [
"pyro_sim = Pyro(solver)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, other_commands=other_commands)\n",
"pyro_sim.initialize_problem(problem_name, inputs_file=param_file, inputs_dict=params)\n",
"pyro_sim.run_sim()\n",
"pyro_sim.sim.dovis()"
]
Expand Down
179 changes: 136 additions & 43 deletions pyro/solver_test_swe.ipynb

Large diffs are not rendered by default.

30 changes: 0 additions & 30 deletions pyro/util/runparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,36 +163,6 @@ def load_params(self, pfile, *, no_new=False):

self.param_comments[key] = comment.strip()

def command_line_params(self, cmd_strings):
"""
finds dictionary pairs from a string that came from the
commandline. Stores the parameters in only if they
already exist.
we expect things in the string in the form:
["sec.opt=value", "sec.opt=value"]
with each opt an element in the list
Parameters
----------
cmd_strings : list
The list of strings containing runtime parameter pairs
"""

for item in cmd_strings:

# break it apart
key, value = item.split("=")

# we only want to override existing keys/values
if key not in self.params:
msg.warning("warning, key: %s not defined" % (key))
continue

# check in turn whether this is an integer, float, or string
self.params[key] = _get_val(value)

def get_param(self, key):
"""
returns the value of the runtime parameter corresponding to the
Expand Down
7 changes: 4 additions & 3 deletions pyro/util/tests/test_runparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ def test_get_param(self):
assert self.rp.get_param("test3.param") == "this is a test"
assert self.rp.get_param("test3.i1") == 1

def test_command_line_params(self):
def test_dict(self):

param_string = "test.p1=q test3.i1=2"
params = {"test.p1": "q", "test3.i1": 2}

self.rp.command_line_params(param_string.split())
for k, v in params.items():
self.rp.set_param(k, v, no_new=False)

assert self.rp.get_param("test.p1") == "q"
assert self.rp.get_param("test3.i1") == 2
Expand Down

0 comments on commit 23cb4f4

Please sign in to comment.