Skip to content

Commit

Permalink
minor speedup steady state
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Dec 27, 2022
1 parent 3d4599c commit be0959b
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 109 deletions.
109 changes: 32 additions & 77 deletions econpizza/parser/build_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
from grgrlib.jaxed import *
from .het_agent_funcs import _backwards_sweep, _forwards_sweep, _final_step, _second_sweep, _stacked_func_dist


def get_func_stst_raw(func_pre_stst, func_backw, func_stst_dist, func_eqns, shocks, init_vf, decisions_output_init, exog_grid_vars_init, tol_backw, maxit_backw, tol_forw, maxit_forw):
Expand All @@ -13,44 +14,44 @@ def get_func_stst_raw(func_pre_stst, func_backw, func_stst_dist, func_eqns, shoc
zshock = jnp.zeros(len(shocks))

def cond_func(cont):
(vf, _, _), vf_old, cnt = cont
(vf, _, _), (vf_old, cnt), _ = cont
cond0 = jnp.abs(vf - vf_old).max() > tol_backw
cond1 = cnt < maxit_backw
return cond0 & cond1
return jnp.logical_and(cond0, cond1)

def body_func_raw(cont, x, par):
(vf, _, _), _, cnt = cont
return func_backw(x, x, x, x, vf, zshock, par), vf, cnt + 1
def body_func(cont):
(vf, _, _), (_, cnt), (x, par) = cont
return func_backw(x, x, x, x, vf, zshock, par), (vf, cnt + 1), (x, par)

def func_backw_ext(x, par):
def find_stat_vf(x, par):

def body_func(cont): return body_func_raw(cont, x, par)

(vf, decisions_output, exog_grid_vars), _, cnt = jax.lax.while_loop(
cond_func, body_func, ((init_vf, decisions_output_init, exog_grid_vars_init), init_vf+1, 0))
(vf, decisions_output, exog_grid_vars), (_, cnt), _ = jax.lax.while_loop(
cond_func, body_func, ((init_vf, decisions_output_init, exog_grid_vars_init), (init_vf+1, 0), (x, par)))

return vf, decisions_output, exog_grid_vars, cnt

def func_stst_raw(y, full_output=False):
def func_stst_raw(y):

x, par = func_pre_stst(y)
x = x[..., None]

if not func_stst_dist:
return func_eqns(x, x, x, x, zshock, par)
return func_eqns(x, x, x, x, zshock, par), None

vf, decisions_output, exog_grid_vars, cnt_backw = func_backw_ext(
vf, decisions_output, exog_grid_vars, cnt_backw = find_stat_vf(
x, par)
dist, cnt_forw = func_stst_dist(decisions_output, tol_forw, maxit_forw)

# TODO: for more than one dist this should be a loop...
decisions_output_array = decisions_output[..., None]
dist_array = dist[..., None]

if full_output:
return (vf, decisions_output, exog_grid_vars, cnt_backw), (dist, cnt_forw)
aux = (vf, decisions_output, exog_grid_vars,
cnt_backw), (dist, cnt_forw)
out = func_eqns(x, x, x, x, zshock, par,
dist_array, decisions_output_array)

return func_eqns(x, x, x, x, zshock, par, dist_array, decisions_output_array)
return out, aux

return func_stst_raw

Expand All @@ -60,67 +61,21 @@ def get_stacked_func_dist(pars, func_backw, func_dist, func_eqns, stst, vfSS, di
"""

nshpe = (nvars, horizon-1)

def backwards_step(carry, i):

vf, X, shocks = carry
vf, decisions_output, exog_grid_vars = func_backw(
X[:, i], X[:, i+1], X[:, i+2], stst, vf, shocks[:, i], pars)

return (vf, X, shocks), decisions_output

def backwards_sweep(x, x0, shocks):

X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T

_, decisions_output_storage = jax.lax.scan(
backwards_step, (vfSS, X, shocks), jnp.arange(horizon-1), reverse=True)
decisions_output_storage = jnp.moveaxis(
decisions_output_storage, 0, -1)

return decisions_output_storage

def forwards_step(carry, i):

dist_old, decisions_output_storage = carry
dist = func_dist(dist_old, decisions_output_storage[..., i])

return (dist, decisions_output_storage), dist_old

def forwards_sweep(decisions_output_storage):

_, dists_storage = jax.lax.scan(
forwards_step, (distSS, decisions_output_storage), jnp.arange(horizon-1))
dists_storage = jnp.moveaxis(dists_storage, 0, -1)

return dists_storage

def final_step(x, dists_storage, decisions_output_storage, x0, shocks):

X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T
out = func_eqns(X[:, :-2].reshape(nshpe), X[:, 1:-1].reshape(nshpe), X[:, 2:].reshape(
nshpe), stst, shocks, pars, dists_storage, decisions_output_storage)

return out

def second_sweep(x, decisions_output_storage, x0, shocks):

# forwards step
dists_storage = forwards_sweep(decisions_output_storage)
# final step
out = final_step(x, dists_storage,
decisions_output_storage, x0, shocks)

return out

def stacked_func_dist(x, x0, shocks):

# backwards step
decisions_output_storage = backwards_sweep(x, x0, shocks)
# combined step
out = second_sweep(x, decisions_output_storage, x0, shocks)

return out
# build partials of input functions
func_backw = jax.tree_util.Partial(func_backw, XSS=stst, pars=pars)
func_dist = jax.tree_util.Partial(func_dist)

# build partials of output functions
backwards_sweep = jax.tree_util.Partial(
_backwards_sweep, stst=stst, vfSS=vfSS, horizon=horizon, func_backw=func_backw)
forwards_sweep = jax.tree_util.Partial(
_forwards_sweep, distSS=distSS, horizon=horizon, func_dist=func_dist)
final_step = jax.tree_util.Partial(
_final_step, stst=stst, horizon=horizon, nshpe=nshpe, pars=pars, func_eqns=func_eqns)
second_sweep = jax.tree_util.Partial(
_second_sweep, forwards_sweep=forwards_sweep, final_step=final_step)
stacked_func_dist = jax.tree_util.Partial(
_stacked_func_dist, backwards_sweep=backwards_sweep, second_sweep=second_sweep)

return stacked_func_dist, backwards_sweep, forwards_sweep, second_sweep

Expand Down
73 changes: 73 additions & 0 deletions econpizza/parser/het_agent_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Subfunctions of stacked_func_dist
"""

import jax
import jax.numpy as jnp


def backwards_step(carry, i):

vf, X, shocks, func_backw, stst = carry
vf, decisions_output, exog_grid_vars = func_backw(
X[:, i], X[:, i+1], X[:, i+2], VFPrime=vf, shocks=shocks[:, i])

return (vf, X, shocks, func_backw, stst), decisions_output


def _backwards_sweep(x, x0, shocks, stst, vfSS, horizon, func_backw):

X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T

_, decisions_output_storage = jax.lax.scan(
backwards_step, (vfSS, X, shocks, func_backw, stst), jnp.arange(horizon-1), reverse=True)
decisions_output_storage = jnp.moveaxis(
decisions_output_storage, 0, -1)

return decisions_output_storage


def forwards_step(carry, i):

dist_old, decisions_output_storage, func_dist = carry
dist = func_dist(dist_old, decisions_output_storage[..., i])

return (dist, decisions_output_storage, func_dist), dist_old


def _forwards_sweep(decisions_output_storage, distSS, horizon, func_dist):

_, dists_storage = jax.lax.scan(
forwards_step, (distSS, decisions_output_storage, func_dist), jnp.arange(horizon-1))
dists_storage = jnp.moveaxis(dists_storage, 0, -1)

return dists_storage


def _final_step(x, dists_storage, decisions_output_storage, x0, shocks, stst, horizon, nshpe, pars, func_eqns):

X = jnp.hstack((x0, x, stst)).reshape(horizon+1, -1).T
out = func_eqns(X[:, :-2].reshape(nshpe), X[:, 1:-1].reshape(nshpe), X[:, 2:].reshape(
nshpe), stst, shocks, pars, dists_storage, decisions_output_storage)

return out


def _second_sweep(x, decisions_output_storage, x0, shocks, forwards_sweep, final_step):

# forwards step
dists_storage = forwards_sweep(decisions_output_storage)
# final step
out = final_step(x, dists_storage,
decisions_output_storage, x0, shocks)

return out


def _stacked_func_dist(x, x0, shocks, backwards_sweep, second_sweep):

# backwards step
decisions_output_storage = backwards_sweep(x, x0, shocks)
# combined step
out = second_sweep(x, decisions_output_storage, x0, shocks)

return out
10 changes: 4 additions & 6 deletions econpizza/solvers/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,13 @@ def find_path_stacking(

# get transition function
func_eqns = model['context']["func_eqns"]

def func_eqns_partial(xLag, x, xPrime, e_shock): return func_eqns(
xLag, x, xPrime, stst, e_shock, pars, [], [])
jav_func = jax.tree_util.Partial(
jacrev_and_val(func_eqns_partial, (0, 1, 2)))
jav_func_eqns = jacrev_and_val(func_eqns, (0, 1, 2))
jav_func_eqns_partial = jax.tree_util.Partial(
jav_func_eqns, XSS=stst, pars=pars, distributions=[], decisions_outputs=[])

# actual newton iterations
x_out, flag, mess = newton_for_banded_jac(
jav_func, nvars, horizon, x_init, shock_series, verbose, **newton_args)
jav_func_eqns_partial, nvars, horizon, x_init, shock_series, verbose, **newton_args)

else:
if model['new_model_horizon'] != horizon:
Expand Down
14 changes: 9 additions & 5 deletions econpizza/solvers/steady_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from ..parser.build_functions import get_func_stst_raw


# use a solver that can deal with ill-conditioned jacobians
def solver(jval, fval):
"""A default solver to solve indetermined problems.
"""
return jnp.linalg.pinv(jval) @ fval


def get_stst_dist_objs(model, res, maxit_backwards, maxit_forwards):
# TODO: loosing some time here
res_backw, res_forw = model['context']['func_stst_raw'](res['x'], True)
"""Get the steady state distribution and decision outputs, which is an auxilliary output of the steady state function. Compile error messages if things go wrong.
"""

res_backw, res_forw = res['aux']
vfSS, decisions_output, exog_grid_vars, cnt_backwards = res_backw
distSS, cnt_forwards = res_forw

Expand All @@ -33,10 +34,12 @@ def get_stst_dist_objs(model, res, maxit_backwards, maxit_forwards):
mess += f'Maximum of {maxit_backwards} backwards calls reached. '
if cnt_forwards == maxit_forwards:
mess += f'Maximum of {maxit_forwards} forwards calls reached. '

# TODO: this should loop over the objects in distSS/vfSS and store under the name of the distribution/decisions (i.e. 'D' or 'Va')
model['steady_state']["distributions"] = distSS
model['steady_state']['decisions'] = vfSS
model['steady_state']['decisions_output'] = decisions_output

return mess


Expand Down Expand Up @@ -117,13 +120,14 @@ def solve_stst(model, tol=1e-8, tol_newton=None, maxit_newton=30, tol_backwards=
exog_grid_vars_init, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards)

# define jitted stst function that returns jacobian and func. value
func_stst = jacfwd_and_val(jax.jit(func_stst_raw))
func_stst = jax.jit(jacfwd_and_val(func_stst_raw, has_aux=True))
# store functions
model["context"]['func_stst_raw'] = func_stst_raw
model["context"]['func_stst'] = func_stst

# actual root finding
res = newton_jax(func_stst, jnp.array(list(model['init'].values())), None, maxit_newton, tol_newton, rtol=-1, sparse=False,
x_init = jnp.array(list(model['init'].values()))
res = newton_jax(func_stst, x_init, None, maxit_newton, tol_newton, rtol=-1, sparse=False,
func_returns_jac=True, solver=solver, verbose=verbose, **newton_kwargs)

# exchange those values that are identified via stst_equations
Expand Down
39 changes: 23 additions & 16 deletions econpizza/utilities/dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,27 @@ def forward_policy_1d(D, x_i, x_pi):
return Dnew


def cond_func(cont):
dist, dist_old, cnt, tol, maxit = cont
def cond_func(carry):
(dist, cnt, dist_old), (tol, maxit), _ = carry
cond0 = jnp.abs(dist-dist_old).max() > tol
cond1 = cnt < maxit
return jnp.logical_and(cond0, cond1)


def _body_func_1d(carry):
(dist, cnt, _), cond_vars, exo_endo = carry
exog_probs, endog_inds, endog_probs = exo_endo
dist_new = exog_probs.T @ forward_policy_1d(dist, endog_inds, endog_probs)
return (dist_new, cnt + 1, dist), cond_vars, exo_endo


def stationary_distribution_forward_policy_1d(endog_inds, endog_probs, exog_probs, tol=1e-10, maxit=1000):

dist = jnp.ones_like(endog_inds, dtype=jnp.float64)
dist /= dist.sum()

def body_func(cont):
dist, _, cnt, tol, maxit = cont
return exog_probs.T @ forward_policy_1d(dist, endog_inds, endog_probs), dist, cnt + 1, tol, maxit

dist, _, cnt, _, _ = jax.lax.while_loop(
cond_func, body_func, (dist, dist+1, 0, tol, maxit))
(dist, cnt, _), _, _ = jax.lax.while_loop(
cond_func, _body_func_1d, ((dist, 0, dist+1), (tol, maxit), (exog_probs, endog_inds, endog_probs)))
return dist, cnt


Expand All @@ -66,21 +69,25 @@ def forward_policy_2d(D, x_i, y_i, x_pi, y_pi):
return Dnew


def _body_func_2d(carry):
(dist, cnt, _), cond_vars, exo_endo = carry
exog_probs, endog_inds0, endog_inds1, endog_probs0, endog_probs1 = exo_endo
pre_exo_dist = forward_policy_2d(
dist, endog_inds0, endog_inds1, endog_probs0, endog_probs1)
new_dist = expect_transition(exog_probs.T, pre_exo_dist)
return (new_dist, cnt + 1, dist), cond_vars, exo_endo


def stationary_distribution_forward_policy_2d(endog_inds0, endog_inds1, endog_probs0, endog_probs1, exog_probs, tol=1e-10, maxit=1000):
# TODO: can be merged with stationary_distribution_forward_policy_1d

dist = jnp.ones_like(endog_inds0, dtype=jnp.float64)
dist /= dist.sum()

def body_func(cont):
dist, _, cnt, tol, maxit = cont
pre_exo_dist = forward_policy_2d(
dist, endog_inds0, endog_inds1, endog_probs0, endog_probs1)
new_dist = expect_transition(exog_probs.T, pre_exo_dist)
return new_dist, dist, cnt + 1, tol, maxit
exo_endo = exog_probs, endog_inds0, endog_inds1, endog_probs0, endog_probs1
(dist, cnt, _), _, _ = jax.lax.while_loop(cond_func,
_body_func_2d, ((dist, 0, dist+1), (tol, maxit), exo_endo))

dist, _, cnt, _, _ = jax.lax.while_loop(
cond_func, body_func, (dist, dist+1, 0, tol, maxit))
return dist, cnt


Expand Down
4 changes: 2 additions & 2 deletions econpizza/utilities/interp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/python
# -*- coding: utf-8 -*-
"""interpolation tools
"""

import jax
import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion econpizza/utilities/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def sweep_banded_down(val, i):
jav_func, fmod, forward_mat, X, shocks = val
# calculate value and jacobians
fval, (jac_f2xLag, jac_f2x, jac_f2xPrime) = jav_func(
X[i], X[i+1], X[i+2], shocks[i])
X[i], X[i+1], X[i+2], shocks=shocks[i])
# work on banded sequence space jacobian
bmat = jnp.linalg.inv(jac_f2x - jac_f2xLag @ forward_mat)
forward_mat = bmat @ jac_f2xPrime
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
jax
jaxlib
grgrlib>=0.1.19
grgrlib>=0.1.20
pyyaml
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
install_requires=[
"jax",
"jaxlib",
"grgrlib>=0.1.19",
"grgrlib>=0.1.20",
"pyyaml",
],
)

0 comments on commit be0959b

Please sign in to comment.