Skip to content

Commit

Permalink
refactor the problem setup to allow init via function (#262)
Browse files Browse the repository at this point in the history
this will allow us to define a problem setup completely in Jupyter
an example was added to the docs demonstrating this
  • Loading branch information
zingale authored Sep 8, 2024
1 parent e4dffa7 commit 8e6acd6
Show file tree
Hide file tree
Showing 24 changed files with 553 additions and 121 deletions.
444 changes: 444 additions & 0 deletions docs/source/adding_a_problem_jupyter.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ new ideas.
:hidden:

compressible-rt-compare.ipynb
adding_a_problem_jupyter.ipynb
advection-error.ipynb
compressible-convergence.ipynb

Expand Down
5 changes: 1 addition & 4 deletions pyro/advection/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -35,8 +33,7 @@ def initialize(self):
self.particles = particles.Particles(self.cc_data, bc, n_particles, particle_generator)

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.advection.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

def method_compute_timestep(self):
"""
Expand Down
5 changes: 3 additions & 2 deletions pyro/advection/tests/test_advection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pyro.advection.simulation as sn
import pyro.advection.simulation as sim
from pyro.advection.problems import test
from pyro.util import runparams


Expand All @@ -19,7 +20,7 @@ def setup_method(self):
self.rp.params["mesh.ny"] = 8
self.rp.params["particles.do_particles"] = 0

self.sim = sn.Simulation("advection", "test", self.rp)
self.sim = sim.Simulation("advection", "test", test.init_data, self.rp)
self.sim.initialize()

def teardown_method(self):
Expand Down
5 changes: 1 addition & 4 deletions pyro/advection_fv4/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import pyro.advection_fv4.fluxes as flx
import pyro.mesh.array_indexer as ai
from pyro import advection_rk
Expand Down Expand Up @@ -32,8 +30,7 @@ def initialize(self):
self.particles = particles.Particles(self.cc_data, bc, n_particles, particle_generator)

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.advection_fv4.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

def substep(self, myd):
"""
Expand Down
5 changes: 1 addition & 4 deletions pyro/advection_nonuniform/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -55,8 +53,7 @@ def shift(velocity):
self.particles = particles.Particles(self.cc_data, bc, n_particles, particle_generator)

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.advection_nonuniform.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

# compute the required shift for each node using corresponding velocity at the node
shx = self.cc_data.get_var("x-shift")
Expand Down
5 changes: 3 additions & 2 deletions pyro/advection_nonuniform/tests/test_advection_nonuniform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pyro.advection_nonuniform.simulation as sn
import pyro.advection_nonuniform.simulation as sim
from pyro.advection_nonuniform.problems import test
from pyro.util import runparams


Expand All @@ -19,7 +20,7 @@ def setup_method(self):
self.rp.params["mesh.ny"] = 8
self.rp.params["particles.do_particles"] = 0

self.sim = sn.Simulation("advection_nonuniform", "test", self.rp)
self.sim = sim.Simulation("advection_nonuniform", "test", test.init_data, self.rp)
self.sim.initialize()

def teardown_method(self):
Expand Down
5 changes: 1 addition & 4 deletions pyro/burgers/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -43,8 +41,7 @@ def initialize(self):
self.particles = particles.Particles(self.cc_data, bc, n_particles, particle_generator)

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.burgers.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

def method_compute_timestep(self):
"""
Expand Down
5 changes: 1 addition & 4 deletions pyro/compressible/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -167,8 +165,7 @@ def initialize(self, *, extra_vars=None, ng=4):
self.cc_data.add_derived(derives.derive_primitives)

# initial conditions for the problem
problem = importlib.import_module(f"pyro.{self.solver_name}.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

if self.verbose > 0:
print(my_data)
Expand Down
9 changes: 5 additions & 4 deletions pyro/compressible/tests/test_compressible.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pytest
from numpy.testing import assert_array_equal

import pyro.compressible.simulation as sn
import pyro.compressible.simulation as sim
from pyro.compressible.problems import test
from pyro.util import runparams


Expand All @@ -26,7 +27,7 @@ def setup_method(self):
self.rp.params["eos.gamma"] = 1.4
self.rp.params["compressible.grav"] = 1.0

self.sim = sn.Simulation("compressible", "test", self.rp)
self.sim = sim.Simulation("compressible", "test", test.init_data, self.rp)
self.sim.initialize()

def teardown_method(self):
Expand All @@ -45,13 +46,13 @@ def test_prim(self):

# U -> q
gamma = self.sim.cc_data.get_aux("gamma")
q = sn.cons_to_prim(self.sim.cc_data.data, gamma, self.sim.ivars, self.sim.cc_data.grid)
q = sim.cons_to_prim(self.sim.cc_data.data, gamma, self.sim.ivars, self.sim.cc_data.grid)

assert q[:, :, self.sim.ivars.ip].min() == pytest.approx(1.0) and \
q[:, :, self.sim.ivars.ip].max() == pytest.approx(1.0)

# q -> U
U = sn.prim_to_cons(q, gamma, self.sim.ivars, self.sim.cc_data.grid)
U = sim.prim_to_cons(q, gamma, self.sim.ivars, self.sim.cc_data.grid)
assert_array_equal(U, self.sim.cc_data.data)

def test_derives(self):
Expand Down
8 changes: 5 additions & 3 deletions pyro/compressible_fv4/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

class Simulation(compressible_rk.Simulation):

def __init__(self, solver_name, problem_name, rp, *,
timers=None, data_class=fv.FV2d):
super().__init__(solver_name, problem_name, rp, timers=timers, data_class=data_class)
def __init__(self, solver_name, problem_name, problem_func, rp, *,
problem_finalize_func=None, timers=None, data_class=fv.FV2d):
super().__init__(solver_name, problem_name, problem_func, rp,
problem_finalize_func=problem_finalize_func,
timers=timers, data_class=data_class)

def substep(self, myd):
"""
Expand Down
5 changes: 3 additions & 2 deletions pyro/compressible_rk/tests/test_compressible_rk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pyro.compressible_rk.simulation as sn
import pyro.compressible_rk.simulation as sim
from pyro.compressible_rk.problems import test
from pyro.util import runparams


Expand All @@ -22,7 +23,7 @@ def setup_method(self):
self.rp.params["eos.gamma"] = 1.4
self.rp.params["compressible.grav"] = 1.0

self.sim = sn.Simulation("compressible", "test", self.rp)
self.sim = sim.Simulation("compressible", "test", test.init_data, self.rp)
self.sim.initialize()

def teardown_method(self):
Expand Down
8 changes: 2 additions & 6 deletions pyro/diffusion/simulation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
""" A simulation of diffusion """

import importlib
import math

import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -28,7 +25,7 @@ def initialize(self):
if my_grid.nx != my_grid.ny:
msg.fail("need nx = ny for diffusion problems")

n = int(math.log(my_grid.nx)/math.log(2.0))
n = int(np.log(my_grid.nx)/np.log(2.0))
if 2**n != my_grid.nx:
msg.fail("grid needs to be a power of 2")

Expand All @@ -50,8 +47,7 @@ def initialize(self):
self.cc_data = my_data

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.diffusion.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

def method_compute_timestep(self):
"""
Expand Down
5 changes: 3 additions & 2 deletions pyro/diffusion/tests/test_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pyro.diffusion.simulation as sn
import pyro.diffusion.simulation as sim
from pyro.diffusion.problems import test
from pyro.util import runparams


Expand All @@ -19,7 +20,7 @@ def setup_method(self):
self.rp.params["mesh.nx"] = 8
self.rp.params["mesh.ny"] = 8

self.sim = sn.Simulation("diffusion", "test", self.rp)
self.sim = sim.Simulation("diffusion", "test", test.init_data, self.rp)
self.sim.initialize()

def teardown_method(self):
Expand Down
5 changes: 1 addition & 4 deletions pyro/incompressible/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -64,8 +62,7 @@ def initialize(self, *, other_bc=False, aux_vars=()):
self.in_preevolve = False

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.{self.solver_name}.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.rp)
self.problem_func(self.cc_data, self.rp)

def preevolve(self):
"""
Expand Down
11 changes: 5 additions & 6 deletions pyro/lm_atm/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import importlib

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -37,9 +35,11 @@ def jp(self, shift, buf=0):

class Simulation(NullSimulation):

def __init__(self, solver_name, problem_name, rp, *, timers=None):
def __init__(self, solver_name, problem_name, problem_func, rp, *,
problem_finalize_func=None, timers=None):

NullSimulation.__init__(self, solver_name, problem_name, rp, timers=timers)
super().__init__(solver_name, problem_name, problem_func, rp,
problem_finalize_func=problem_finalize_func, timers=timers)

self.base = {}
self.aux_data = None
Expand Down Expand Up @@ -114,8 +114,7 @@ def initialize(self):
self.base["p0"] = Basestate(myg.ny, ng=myg.ng)

# now set the initial conditions for the problem
problem = importlib.import_module(f"pyro.lm_atm.problems.{self.problem_name}")
problem.init_data(self.cc_data, self.base, self.rp)
self.problem_func(self.cc_data, self.base, self.rp)

# Construct beta_0
gamma = self.rp.get_param("eos.gamma")
Expand Down
2 changes: 1 addition & 1 deletion pyro/particles/tests/test_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def setup_test(n_particles=50, extra_rp_params=None):
rp.params[param] = value

# set up sim
sim = NullSimulation("", "", rp)
sim = NullSimulation("", "", rp, None)

# set up grid
my_grid = grid_setup(rp)
Expand Down
52 changes: 44 additions & 8 deletions pyro/pyro_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,16 @@ def __init__(self, solver_name, *, from_commandline=False):
self.solver = importlib.import_module(solver_import)
self.solver_name = solver_name

# -------------------------------------------------------------------------
self.problem_name = None
self.problem_func = None
self.problem_params = None
self.problem_finalize = None

# custom problems

self.custom_problems = {}

# runtime parameters
# -------------------------------------------------------------------------

# parameter defaults
self.rp = RuntimeParameters()
Expand All @@ -80,6 +87,23 @@ def __init__(self, solver_name, *, from_commandline=False):

self.is_initialized = False

def add_problem(self, name, problem_func, *, problem_params=None):
"""Add a problem setup for this solver.
Parameters
----------
name : str
The descriptive name of the problem
problem_func : function
The function to initialize the state data
problem_params : dict
A dictionary of runtime parameters needed for the problem setup
"""

if problem_params is None:
problem_params = {}
self.custom_problems[name] = (problem_func, problem_params)

def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None):
"""
Initialize the specific problem
Expand All @@ -95,16 +119,27 @@ def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None
"""
# pylint: disable=attribute-defined-outside-init

problem = importlib.import_module("pyro.{}.problems.{}".format(self.solver_name, problem_name))
if problem_name in self.custom_problems:
# this is a problem we added via self.add_problem
self.problem_name = problem_name
self.problem_func, self.problem_params = self.custom_problems[problem_name]
self.problem_finalize = None

else:
problem = importlib.import_module("pyro.{}.problems.{}".format(self.solver_name, problem_name))
self.problem_name = problem_name
self.problem_func = problem.init_data
self.problem_params = problem.PROBLEM_PARAMS
self.problem_finalize = problem.finalize

if inputs_file is None:
inputs_file = problem.DEFAULT_INPUTS

# problem-specific runtime parameters
for k, v in problem.PROBLEM_PARAMS.items():
for k, v in self.problem_params.items():
self.rp.set_param(k, v, no_new=False)

# now read in the inputs file
if inputs_file is None:
inputs_file = problem.DEFAULT_INPUTS

if inputs_file is not None:
if not os.path.isfile(inputs_file):
# check if the param file lives in the solver's problems directory
Expand Down Expand Up @@ -138,7 +173,8 @@ def initialize_problem(self, problem_name, *, inputs_file=None, inputs_dict=None
# data and know about the runtime parameters and which problem we
# are running
self.sim = self.solver.Simulation(
self.solver_name, problem_name, self.rp, timers=self.tc)
self.solver_name, self.problem_name, self.problem_func, self.rp,
problem_finalize_func=self.problem_finalize, timers=self.tc)

self.sim.initialize()
self.sim.preevolve()
Expand Down
Loading

0 comments on commit 8e6acd6

Please sign in to comment.