diff --git a/pyro/compressible/interface.py b/pyro/compressible/interface.py index 877c88790..dbff86b61 100644 --- a/pyro/compressible/interface.py +++ b/pyro/compressible/interface.py @@ -3,7 +3,7 @@ @njit(cache=True) -def states(idir, grid, dt, +def states(idir, ng, dx, dt, irho, iu, iv, ip, ix, nspec, gamma, qv, dqv): r""" @@ -65,8 +65,10 @@ def states(idir, grid, dt, ---------- idir : int Are we predicting to the edges in the x-direction (1) or y-direction (2)? - grid : Grid2d, Cartesian2d, or SphericalPolar - The grid object. + ng : int + The number of ghost cells + dx : ndarray + The cell spacing dt : float The timestep irho, iu, iv, ip, ix : int @@ -92,13 +94,16 @@ def states(idir, grid, dt, q_l = np.zeros_like(qv) q_r = np.zeros_like(qv) - ns = nvar - nspec + nx = qx - 2 * ng + ny = qy - 2 * ng + ilo = ng + ihi = ng + nx + jlo = ng + jhi = ng + ny - if idir == 1: - dtdx = dt / grid.Lx.v() - else: - dtdx = dt / grid.Ly.v() + ns = nvar - nspec + dtdx = dt / dx dtdx4 = 0.25 * dtdx lvec = np.zeros((nvar, nvar)) @@ -108,8 +113,8 @@ def states(idir, grid, dt, betar = np.zeros(nvar) # this is the loop over zones. For zone i, we see q_l[i+1] and q_r[i] - for i in range(grid.ilo - 2, grid.ihi + 2): - for j in range(grid.jlo - 2, grid.jhi + 2): + for i in range(ilo - 2, ihi + 2): + for j in range(jlo - 2, jhi + 2): dq = dqv[i, j, :] q = qv[i, j, :] diff --git a/pyro/compressible/unsplit_fluxes.py b/pyro/compressible/unsplit_fluxes.py index 9c5a04493..7e93e6ea9 100644 --- a/pyro/compressible/unsplit_fluxes.py +++ b/pyro/compressible/unsplit_fluxes.py @@ -217,7 +217,7 @@ def unsplit_fluxes(my_data, my_aux, rp, ivars, solid, tc, dt): tm_states = tc.timer("interfaceStates") tm_states.begin() - V_l, V_r = ifc.states(1, myg, dt, + V_l, V_r = ifc.states(1, myg.ng, myg.Lx, dt, ivars.irho, ivars.iu, ivars.iv, ivars.ip, ivars.ix, ivars.naux, gamma, @@ -236,7 +236,7 @@ def unsplit_fluxes(my_data, my_aux, rp, ivars, solid, tc, dt): # left and right primitive variable states tm_states.begin() - _V_l, _V_r = ifc.states(2, myg, dt, + _V_l, _V_r = ifc.states(2, myg.ng, myg.Ly, dt, ivars.irho, ivars.iu, ivars.iv, ivars.ip, ivars.ix, ivars.naux, gamma,