misc
jstac committed Mar 13, 2024
1 parent 0cd5a7a commit 55681a2
Showing 1 changed file with 123 additions and 50 deletions.
173 changes: 123 additions & 50 deletions lectures/inventory_dynamics.md
@@ -36,11 +36,14 @@ Loosely speaking, this means that the firm
* waits until inventory falls below some value $s$
* and then restocks with a bulk order of $S$ units (or, in some models, restocks up to level $S$).

We will be interested in the stationary distribution of the model, which can be
thought of as a steady state cross-sectional distribution of inventory levels
across a large number of firms, all of which have the same dynamics.
We will be interested in the distribution of the associated Markov process,
which can be thought of as cross-sectional distributions of inventory levels
across a large number of firms, all of which

We studied this distribution in a [separate
1. evolve independently and
1. have the same dynamics.

Note that we also studied this model in a [separate
lecture](https://python.quantecon.org/inventory_dynamics.html), using Numba.

Here we study the same problem using JAX.
@@ -68,7 +71,8 @@ Consider a firm with inventory $X_t$.

The firm waits until $X_t \leq s$ and then restocks up to $S$ units.

It faces stochastic demand $\{ D_t \}$, which we assume is IID.
It faces stochastic demand $\{ D_t \}$, which we assume is IID across time and
firms.

With notation $a^+ := \max\{a, 0\}$, inventory dynamics can be written
as
@@ -120,15 +124,16 @@ This process gives us $\psi_T$, a distribution of firm inventory levels.
We will then use various methods to visualize $\psi_T$, such as histograms and
kernel density estimates.

We will use the following code to update a cross-section of firms by one period.
We will use the following code to update the cross-section of firms by one period.

```{code-cell} ipython3
# Define a jit-compiled function to update X and key
@jax.jit
def update_cross_section(params, X_vec, D):
"""
Update cross-section X_vec by one period, given the vector of demand shocks
in D (D[i] is the shock for firm i with current inventory X_vec[i]).
Update by one period a cross-section of firms with inventory levels given by
X_vec, given the vector of demand shocks in D.
* D[i] is the demand shock for firm i with current inventory X_vec[i]
"""
# Unpack
@@ -141,14 +146,17 @@ def update_cross_section(params, X_vec, D):

### For loop version

Here's code to compute the cross-sectional distribution $\psi_T$.
Here's code to compute the cross-sectional distribution $\psi_T$ given some
initial distribution $\psi_0$ and a positive integer $T$.

In this code we use an ordinary Python `for` loop.
In this code we use an ordinary Python `for` loop, which is reasonable here because
the efficiency of outer loops has less influence on runtime than the efficiency of inner loops.

Using an ordinary `for` loop is reasonable here because
(Below we will squeeze out more speed by compiling the outer loop as well as the
update rule.)

1. efficiency of outer loops has less influence on runtime than efficiency of inner loops.
1. using `jax.jit` to compile `for` loops can be time consuming.
In the code below, the initial distribution $\psi_0$ takes all firms to have
initial inventory `x_init`.

```{code-cell} ipython3
def compute_cross_section(params, x_init, T, key, num_firms=50_000):
@@ -165,37 +173,110 @@ def compute_cross_section(params, x_init, T, key, num_firms=50_000):
return X_vec
```

We'll use the following specification

```{code-cell} ipython3
x_init = 50
T = 500
# Initialize random number generator
key = random.PRNGKey(10)
```

Let's look at the timing.

```{code-cell} ipython3
%time X_vec = compute_cross_section(params, x_init, T, key).block_until_ready()
%time X_vec = compute_cross_section(params, \
x_init, T, key).block_until_ready()
```

```{code-cell} ipython3
%time X_vec = compute_cross_section(params, x_init, T, key).block_until_ready()
%time X_vec = compute_cross_section(params, \
x_init, T, key).block_until_ready()
```

Here's a histogram of inventory levels at time $T$.

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.hist(X_vec, bins=50,
density=True,
histtype='step',
label=f'cross-section when $t = {T}$')
ax.set_xlabel('inventory')
ax.set_ylabel('probability')
ax.legend()
plt.show()
```

## Compiling the outer loop

For relatively small problems, we can make this code run faster by compiling the outer loop as well.


### Compiling the outer loop

Now let's see if we can gain some speed by compiling the outer loop, which steps
through the time dimension.

We will do this using `jax.jit` and a `fori_loop`, which is a compiler-ready version of a for loop provided by JAX.
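
To make the loop semantics concrete, here is a minimal sketch (not part of the lecture code) of how `lax.fori_loop(lower, upper, body_fun, init_val)` threads a carried value through a loop body. The names `body`, `total` and `sum_below_five` are illustrative only.

```python
import jax
from jax import lax

def body(i, total):
    # i is the loop index; total is the value carried between iterations
    return total + i

# Behaves like: total = 0; for i in range(0, 5): total = body(i, total)
result = lax.fori_loop(0, 5, body, 0)

# The same loop can sit inside a jitted function, so the whole loop is compiled
@jax.jit
def sum_below_five():
    return lax.fori_loop(0, 5, body, 0)

result_jit = sum_below_five()
print(int(result), int(result_jit))
```

The carried value can be any pytree, such as a tuple `(X, key)`, provided its structure, shapes and dtypes stay the same across iterations.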



```{code-cell} ipython3
def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):

    s, S, mu, sigma = params.s, params.S, params.mu, params.sigma
    X = jnp.full((num_firms, ), x_init)

    # Define the function for each update
    def update_cross_section(i, inputs):
        X, key = inputs
        Z = random.normal(key, shape=(num_firms,))
        D = jnp.exp(mu + sigma * Z)
        X = jnp.where(X <= s,
                      jnp.maximum(S - D, 0),
                      jnp.maximum(X - D, 0))
        key, subkey = random.split(key)
        return X, subkey

    # Use lax.fori_loop to step the cross-section forward T periods
    X, key = lax.fori_loop(0, T, update_cross_section, (X, key))

    return X

# Compile taking T and num_firms as static (changes trigger recompile)
compute_cross_section_fori = jax.jit(
    compute_cross_section_fori, static_argnums=(2, 4))
```
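
As a rough illustration of why `num_firms` is marked static, the sketch below jit-compiles a small function whose output shape depends on one of its arguments. The function name `constant_cross_section` is hypothetical and not part of the lecture. Array shapes must be known at compile time, so an argument that sets a shape is passed via `static_argnums`, and each new value triggers a fresh compilation.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def constant_cross_section(x_init, num_firms):
    # num_firms fixes the array shape, so it must be a compile-time constant
    return jnp.full((num_firms,), x_init)

X_small = constant_cross_section(50.0, 4)   # compiled for num_firms = 4
X_large = constant_cross_section(50.0, 8)   # new static value, so a recompile
print(X_small.shape, X_large.shape)
```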

Let's see how fast this runs with compile time.

```{code-cell} ipython3
%time X_vec = compute_cross_section_fori(params, \
x_init, T, key).block_until_ready()
```

And let's see how fast it runs without compile time.

```{code-cell} ipython3
%time X_vec = compute_cross_section_fori(params, \
x_init, T, key).block_until_ready()
```

Compared to the original version with a pure Python outer loop, we have
produced a nontrivial speed gain.


This is due to the fact that we have compiled the whole operation.
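
The timings above rely on IPython's `%time` magic. Outside a notebook, the same comparison between a first (compiling) call and a second (cached) call can be sketched with `time.perf_counter`. The function `f` and the helper `timed_call` are illustrative stand-ins, not the lecture's simulation code.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.cos(x) ** 2)

x = jnp.linspace(0.0, 1.0, 1_000_000)

def timed_call(fn, *args):
    # block_until_ready waits for the asynchronous computation to finish,
    # so the elapsed time covers the actual work
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start

out_first, t_first = timed_call(f, x)     # includes compile time
out_second, t_second = timed_call(f, x)   # reuses the compiled function
print(f"first call: {t_first:.4f}s, second call: {t_second:.4f}s")
```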




### Further vectorization

For relatively small problems, we can make this code run even faster by generating
all random variables at once.

This improves efficiency because we are taking more operations out of the loop.

```{code-cell} ipython3
def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
@@ -216,52 +297,40 @@ def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000):
return X
# Compile taking T and num_firms as static (changes trigger recomplie)
# Compile taking T and num_firms as static (changes trigger recompile)
compute_cross_section_fori = jax.jit(
compute_cross_section_fori, static_argnums=(2, 4))
```

Let's test it.
Let's test it with compile time included.

```{code-cell} ipython3
%time X_vec = compute_cross_section_fori(params, x_init, T, key).block_until_ready()
%time X_vec = compute_cross_section_fori(params, \
x_init, T, key).block_until_ready()
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
%time X_vec = compute_cross_section_compiled(params, x_init, T, key).block_until_ready()
%time X_vec = compute_cross_section_fori(params, \
x_init, T, key).block_until_ready()
```

The benefit of the `fori_loop` implementation is that we compile the whole
operation.

The disadvantages are that

1. there are only limited speed gains in accelerating outer loops,
2. `lax.fori_loop` has a more complicated syntax, and, most importantly,
3. the `lax.fori_loop` implementation consumes far more memory, as we need to have to
store large matrices of random draws

The high memory consumption aspect of the `fori_loop` version becomes problematic for large problems.
On one hand, this version is faster than the previous one, where random variables were
generated inside the loop.

+++
On the other hand, this implementation consumes far more memory, as we need to
store large arrays of random draws.

## Plotting with a kernel density estimate
The high memory consumption becomes problematic for large problems.
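
A back-of-the-envelope calculation shows the scale of the difference. It assumes 32-bit floats (JAX's default), one stored shock per firm per period, and the values `T = 500` and `num_firms = 50_000` used in this lecture. The variable names are illustrative.

```python
T = 500
num_firms = 50_000
bytes_per_float32 = 4

# Drawing shocks inside the loop: one vector of shocks held at a time
loop_version_bytes = num_firms * bytes_per_float32

# Drawing all shocks up front: a T x num_firms array held in memory
vectorized_bytes = T * num_firms * bytes_per_float32

print(f"shocks drawn inside loop: {loop_version_bytes / 1e6:.1f} MB")
print(f"shocks drawn up front:    {vectorized_bytes / 1e6:.1f} MB")
```

Under these assumptions the stored draws grow linearly in both `T` and `num_firms`, so scaling either by a factor of 100 takes the requirement from roughly 100 MB to roughly 10 GB.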


We can also represent the distribution using a [kernel density
estimator](https://en.wikipedia.org/wiki/Kernel_density_estimation).

Kernel density estimators can be thought of as smoothed histograms.
## Distribution dynamics

The advantage of a kernel density estimator for this setting is that it's
relatively easy to plot the cross section at multiple dates on the same plot.
Let's take a look at how the distribution sequence evolves over time.

We will use a kernel density estimator from [scikit-learn](https://scikit-learn.org/stable/).

We will generate and plot the sequence $\{\psi_t\}$ at times
$t = 10, 50, 250, 500, 750$ using the kernel density estimator.

Here is code that repeatedly shifts the cross-section forward in time while
Here is code that repeatedly shifts the cross-section forward while
recording the cross-section at the dates in `sample_dates`.

```{code-cell} ipython3
@@ -298,6 +367,8 @@ key = random.PRNGKey(10)
sample_dates, key).block_until_ready()
```

Let's plot the output.

```{code-cell} ipython3
fig, ax = plt.subplots()
@@ -320,17 +391,19 @@ In particular, the sequence of marginal distributions $\{\psi_t\}$
converges to a unique limiting distribution that does not depend on
initial conditions.

Although we will not prove this here, we see it in the simulation above.
Although we will not prove this here, we can see it in the simulation above.

By $t=500$ or $t=750$ the densities are barely changing.
By $t=500$ or $t=750$ the distributions are barely changing.

If you test a few different initial conditions, you will see that they do not affect long-run outcomes.
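
One way to run such a test is sketched below. It is a self-contained simulation, not the lecture's own function. The parameter values `s, S, mu, sigma`, the horizon `T = 2_000`, the number of firms and the helper name `simulate` are all assumptions chosen for illustration. Demand is taken to be lognormal, $D = \exp(\mu + \sigma Z)$, as in the update code earlier.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax, random

# Placeholder parameter values, assumed for illustration only
s, S, mu, sigma = 10.0, 100.0, 1.0, 0.5

@partial(jax.jit, static_argnums=(1, 3))
def simulate(x_init, T, key, num_firms):
    # All firms start from the same inventory level x_init
    X = jnp.full((num_firms,), x_init, dtype=jnp.float32)

    def step(t, carry):
        X, key = carry
        key, subkey = random.split(key)
        D = jnp.exp(mu + sigma * random.normal(subkey, (num_firms,)))
        X = jnp.where(X <= s,
                      jnp.maximum(S - D, 0.0),
                      jnp.maximum(X - D, 0.0))
        return X, key

    X, _ = lax.fori_loop(0, T, step, (X, key))
    return X

initial_conditions = (20.0, 50.0, 90.0)
keys = random.split(random.PRNGKey(0), len(initial_conditions))

means = []
for x_init, k in zip(initial_conditions, keys):
    X = simulate(x_init, 2_000, k, 10_000)
    means.append(float(X.mean()))
    print(f"x_init = {x_init:5.1f}: mean inventory at T = {means[-1]:.2f}")
```

If the cross-section has converged, the printed means should be close to one another, differing only by sampling noise.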





## Restock frequency

As an exercise, let's study the probability of firms needing to restock over a
given time period.
As an exercise, let's study the probability that firms need to restock over a given time period.

In the exercise, we will
