Commit 2afed31

jstac and claude authored
Refactor cross-section simulation: reverse loop structure for better performance (#734)
This commit refactors the cross-sectional agent simulation in both McCall model lectures to use a more efficient loop structure.

Changes:
- Replaced old approach (loop over time, vectorize over agents at each step) with new approach (vectorize over agents, loop over time per agent)
- Added sim_agent() function that uses lax.fori_loop to simulate a single agent forward T time steps
- Added sim_agents_vmap to vectorize sim_agent across multiple agents
- Updated simulate_cross_section() to use the new implementation
- Updated plot_cross_sectional_unemployment() to use sim_agents_vmap
- Added explanatory text clarifying differences between simulate_employment_path() and sim_agent()

Performance: The new approach has comparable or slightly better performance while being more modular and conceptually cleaner.

Files modified:
- mccall_model_with_sep_markov.md (discrete wage case)
- mccall_fitted_vfi.md (continuous wage case)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <[email protected]>
1 parent ed8e89a commit 2afed31
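
The loop structure described in the commit message (vectorize over agents, loop over time per agent) can be sketched as follows. This is a minimal illustration, not the lectures' code: `toy_update_agent`, `ALPHA`, and `W_BAR` are hypothetical stand-ins for the lectures' `update_agent`, model parameters, and reservation wage, and `sim_agent` here takes no `model` argument.

```python
import jax
import jax.numpy as jnp
from jax import lax

ALPHA = 0.1  # hypothetical separation probability (assumption for this sketch)
W_BAR = 1.0  # hypothetical reservation wage (assumption for this sketch)


def toy_update_agent(key, status, wage):
    """Hypothetical one-period update; a stand-in for the lectures' update_agent."""
    sep_key, offer_key = jax.random.split(key)
    separated = jax.random.uniform(sep_key) < ALPHA
    new_offer = jnp.exp(jax.random.normal(offer_key))
    employed = status == 1
    # Employed agents keep the job unless separated;
    # unemployed agents accept the current offer if it is at least W_BAR.
    employed_next = jnp.where(employed, ~separated, wage >= W_BAR)
    new_status = employed_next.astype(jnp.int32)
    # Agents who end up unemployed hold a fresh wage offer next period.
    new_wage = jnp.where(employed_next, wage, new_offer)
    return new_status, new_wage


@jax.jit
def sim_agent(key, initial_status, initial_wage, T):
    """Simulate one agent forward T periods and return only the final state."""
    def update(t, loop_state):
        status, wage = loop_state
        step_key = jax.random.fold_in(key, t)  # distinct key for each period
        return toy_update_agent(step_key, status, wage)

    return lax.fori_loop(0, T, update, (initial_status, initial_wage))


# Vectorize over agents: keys and initial states are batched, T is shared.
sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None))

n_agents, T = 1_000, 50
key = jax.random.PRNGKey(0)
init_key, wage_key = jax.random.split(key)
agent_keys = jax.random.split(init_key, n_agents)  # one key per agent
initial_wages = jnp.exp(jax.random.normal(wage_key, (n_agents,)))
initial_status = jnp.zeros(n_agents, dtype=jnp.int32)

final_status, final_wages = sim_agents_vmap(
    agent_keys, initial_status, initial_wages, T
)
unemployment_rate = 1 - jnp.mean(final_status)
print(unemployment_rate)
```

Each agent owns one key, and `fold_in` derives the per-period key from it, so the time loop never has to split keys across agents as the previous implementation did.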

2 files changed: +115 -97 lines changed

lectures/mccall_fitted_vfi.md

Lines changed: 66 additions & 54 deletions
@@ -611,55 +611,50 @@ When employed, the agent faces job separation with probability $\alpha$ each per

 Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.

-We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
+To do this efficiently, we need a different approach than `simulate_employment_path` defined above.

-```{code-cell} ipython3
-# Create vectorized version of update_agent
-update_agents_vmap = jax.vmap(
-    update_agent, in_axes=(0, 0, 0, None, None)
-)
-```
+The key differences are:

-Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
-
-```{code-cell} ipython3
-@partial(jax.jit, static_argnums=(3, 4))
-def _simulate_cross_section_compiled(
-    key: jnp.ndarray,
-    model: Model,
-    w_bar: float,
-    n_agents: int,
-    T: int
-):
-    """JIT-compiled core simulation loop using lax.fori_loop.
-    Returns only the final employment state to save memory."""
-    c, α, β, ρ, ν, γ, w_grid, z_draws = model
+- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive
+- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics
+- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents

-    # Initialize arrays
-    init_key, subkey = jax.random.split(key)
-    wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
-    status = jnp.zeros(n_agents, dtype=jnp.int32)
+We first define a function that simulates a single agent forward T time steps:

-    def update(t, loop_state):
-        status, wages = loop_state
+```{code-cell} ipython3
+@jax.jit
+def sim_agent(key, initial_status, initial_wage, model, w_bar, T):
+    """
+    Simulate a single agent forward T time steps using lax.fori_loop.

-        # Shift loop state forwards
-        step_key = jax.random.fold_in(init_key, t)
-        agent_keys = jax.random.split(step_key, n_agents)
+    Uses fold_in to generate a new key at each time step.

-        status, wages = update_agents_vmap(
-            agent_keys, status, wages, model, w_bar
-        )
+    Parameters:
+    - key: JAX random key for this agent
+    - initial_status: Initial employment status (0 or 1)
+    - initial_wage: Initial wage
+    - model: Model instance
+    - w_bar: Reservation wage
+    - T: Number of time periods to simulate

-        return status, wages
+    Returns:
+    - final_status: Employment status after T periods
+    - final_wage: Wage after T periods
+    """
+    def update(t, loop_state):
+        status, wage = loop_state
+        step_key = jax.random.fold_in(key, t)
+        status, wage = update_agent(step_key, status, wage, model, w_bar)
+        return status, wage

-    # Run simulation using fori_loop
-    initial_loop_state = (status, wages)
+    initial_loop_state = (initial_status, initial_wage)
     final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
+    final_status, final_wage = final_loop_state
+    return final_status, final_wage
+

-    # Return only final employment state
-    final_is_employed, _ = final_loop_state
-    return final_is_employed
+# Create vectorized version of sim_agent to process multiple agents in parallel
+sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None))


 def simulate_cross_section(
@@ -669,30 +664,36 @@ def simulate_cross_section(
     seed: int = 42
 ) -> float:
     """
-    Simulate employment paths for many agents and return final unemployment rate.
+    Simulate cross-section of agents and return unemployment rate.

-    Parameters:
-    - model: Model instance with parameters
-    - n_agents: Number of agents to simulate
-    - T: Number of periods to simulate
-    - seed: Random seed for reproducibility
+    This approach:
+    1. Generates n_agents random keys
+    2. Calls sim_agent for each agent (vectorized via vmap)
+    3. Collects the final states to produce the cross-section

-    Returns:
-    - unemployment_rate: Fraction of agents unemployed at time T
+    Returns the cross-sectional unemployment rate.
     """
+    c, α, β, ρ, ν, γ, w_grid, z_draws = model
+
     key = jax.random.PRNGKey(seed)

     # Solve for optimal reservation wage
     w_bar = get_reservation_wage(model)

-    # Run JIT-compiled simulation
-    final_status = _simulate_cross_section_compiled(
-        key, model, w_bar, n_agents, T
+    # Initialize arrays
+    init_key, subkey = jax.random.split(key)
+    initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
+    initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
+
+    # Generate n_agents random keys
+    agent_keys = jax.random.split(init_key, n_agents)
+
+    # Simulate each agent forward T steps (vectorized)
+    final_status, final_wages = sim_agents_vmap(
+        agent_keys, initial_status_vec, initial_wages, model, w_bar, T
     )

-    # Calculate unemployment rate at final period
     unemployment_rate = 1 - jnp.mean(final_status)
-
     return unemployment_rate
 ```

@@ -743,12 +744,23 @@ def plot_cross_sectional_unemployment(
     Generate histogram of cross-sectional unemployment at a specific time.

     """
+    c, α, β, ρ, ν, γ, w_grid, z_draws = model

     # Get final employment state directly
     key = jax.random.PRNGKey(42)
     w_bar = get_reservation_wage(model)
-    final_status = _simulate_cross_section_compiled(
-        key, model, w_bar, n_agents, t_snapshot
+
+    # Initialize arrays
+    init_key, subkey = jax.random.split(key)
+    initial_wages = jnp.exp(jax.random.normal(subkey, (n_agents,)) * ν)
+    initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
+
+    # Generate n_agents random keys
+    agent_keys = jax.random.split(init_key, n_agents)
+
+    # Simulate each agent forward T steps (vectorized)
+    final_status, _ = sim_agents_vmap(
+        agent_keys, initial_status_vec, initial_wages, model, w_bar, t_snapshot
     )

     # Calculate unemployment rate

lectures/mccall_model_with_sep_markov.md

Lines changed: 49 additions & 43 deletions
@@ -748,55 +748,50 @@ Often the second approach is better for our purposes, since it's easier to paral

 Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.

-We first create a vectorized version of `update_agent` to efficiently update all agents in parallel:
+To do this efficiently, we need a different approach than `simulate_employment_path` defined above.

-```{code-cell} ipython3
-# Create vectorized version of update_agent.
-# Vectorize over key, status, wage_idx
-update_agents_vmap = jax.vmap(
-    update_agent, in_axes=(0, 0, 0, None, None)
-)
-```
+The key differences are:

-Next we define the core simulation function, which uses `lax.fori_loop` to efficiently iterate many agents forward in time:
+- `simulate_employment_path` records the entire history (all T periods) for a single agent, which is useful for visualization but memory-intensive
+- The new function `sim_agent` below only tracks and returns the final state, which is all we need for cross-sectional statistics
+- `sim_agent` uses `lax.fori_loop` instead of a Python loop, making it JIT-compilable and suitable for vectorization across many agents
+
+We first define a function that simulates a single agent forward T time steps:

 ```{code-cell} ipython3
 @jax.jit
-def _simulate_cross_section_compiled(
-    key: jnp.ndarray,
-    model: Model,
-    w_bar: float,
-    initial_wage_indices: jnp.ndarray,
-    initial_status_vec: jnp.ndarray,
-    T: int
-):
+def sim_agent(key, initial_status, initial_wage_idx, model, w_bar, T):
     """
-    JIT-compiled core simulation loop for shifting the cross section
-    using lax.fori_loop. Returns the final employment employment status
-    cross-section.
+    Simulate a single agent forward T time steps using lax.fori_loop.

-    """
-    n, w_vals, P, P_cumsum, β, c, α, γ = model
-    n_agents = len(initial_wage_indices)
+    Uses fold_in to generate a new key at each time step.

+    Parameters:
+    - key: JAX random key for this agent
+    - initial_status: Initial employment status (0 or 1)
+    - initial_wage_idx: Initial wage index
+    - model: Model instance
+    - w_bar: Reservation wage
+    - T: Number of time periods to simulate

+    Returns:
+    - final_status: Employment status after T periods
+    - final_wage_idx: Wage index after T periods
+    """
     def update(t, loop_state):
-        " Shift loop state forwards "
-        status, wage_indices = loop_state
+        status, wage_idx = loop_state
         step_key = jax.random.fold_in(key, t)
-        agent_keys = jax.random.split(step_key, n_agents)
-        status, wage_indices = update_agents_vmap(
-            agent_keys, status, wage_indices, model, w_bar
-        )
-        return status, wage_indices
+        status, wage_idx = update_agent(step_key, status, wage_idx, model, w_bar)
+        return status, wage_idx

-    # Run simulation using fori_loop
-    initial_loop_state = (initial_status_vec, initial_wage_indices)
+    initial_loop_state = (initial_status, initial_wage_idx)
     final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
+    final_status, final_wage_idx = final_loop_state
+    return final_status, final_wage_idx

-    # Return only final employment state
-    final_is_employed, _ = final_loop_state
-    return final_is_employed
+
+# Create vectorized version of sim_agent to process multiple agents in parallel
+sim_agents_vmap = jax.vmap(sim_agent, in_axes=(0, 0, 0, None, None, None))


 def simulate_cross_section(
@@ -806,11 +801,14 @@ def simulate_cross_section(
     seed: int = 42 # For reproducibility
 ) -> float:
     """
-    Wrapper function for _simulate_cross_section_compiled.
+    Simulate cross-section of agents and return unemployment rate.

-    Push forward a cross-section for T periods and return the final
-    cross-sectional unemployment rate.
+    This approach:
+    1. Generates n_agents random keys
+    2. Calls sim_agent for each agent (vectorized via vmap)
+    3. Collects the final states to produce the cross-section

+    Returns the cross-sectional unemployment rate.
     """
     key = jax.random.PRNGKey(seed)

@@ -822,8 +820,12 @@ def simulate_cross_section(
     initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
     initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

-    final_status = _simulate_cross_section_compiled(
-        key, model, w_bar, initial_wage_indices, initial_status_vec, T
+    # Generate n_agents random keys
+    agent_keys = jax.random.split(key, n_agents)
+
+    # Simulate each agent forward T steps (vectorized)
+    final_status, final_wage_idx = sim_agents_vmap(
+        agent_keys, initial_status_vec, initial_wage_indices, model, w_bar, T
     )

     unemployment_rate = 1 - jnp.mean(final_status)
@@ -834,7 +836,7 @@ This function generates a histogram showing the distribution of employment statu

 ```{code-cell} ipython3
 def plot_cross_sectional_unemployment(
-    model: Model,
+    model: Model,
     t_snapshot: int = 200, # Time of cross-sectional snapshot
     n_agents: int = 20_000 # Number of agents to simulate
 ):
@@ -851,8 +853,12 @@ def plot_cross_sectional_unemployment(
     initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
     initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)

-    final_status = _simulate_cross_section_compiled(
-        key, model, w_bar, initial_wage_indices, initial_status_vec, t_snapshot
+    # Generate n_agents random keys
+    agent_keys = jax.random.split(key, n_agents)
+
+    # Simulate each agent forward T steps (vectorized)
+    final_status, _ = sim_agents_vmap(
+        agent_keys, initial_status_vec, initial_wage_indices, model, w_bar, t_snapshot
     )

     # Calculate unemployment rate
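
The explanatory text added in this commit contrasts recording a full history with tracking only the final state. A rough sketch of that trade-off is below, assuming a hypothetical two-state employment chain (`toy_update`, `P_FIND`, and `P_SEP` are illustrative names, not part of the lectures). It uses `lax.scan` for the history version, which is a substitution made here for illustration rather than a description of how `simulate_employment_path` is written.

```python
import jax
import jax.numpy as jnp
from jax import lax

P_FIND = 0.3  # hypothetical job-finding probability (assumption)
P_SEP = 0.1   # hypothetical separation probability (assumption)


def toy_update(key, status):
    """Hypothetical two-state employment chain (0 = unemployed, 1 = employed)."""
    u = jax.random.uniform(key)
    # Employed agents keep the job w.p. 1 - P_SEP; unemployed find one w.p. P_FIND.
    employed_next = jnp.where(status == 1, u >= P_SEP, u < P_FIND)
    return employed_next.astype(jnp.int32)


def final_state_only(key, status0, T):
    """Carry a single state through the loop; memory use does not grow with T."""
    def body(t, status):
        return toy_update(jax.random.fold_in(key, t), status)
    return lax.fori_loop(0, T, body, status0)


def full_history(key, status0, T):
    """Record the state in every period; the output has length T."""
    def step(status, t):
        new_status = toy_update(jax.random.fold_in(key, t), status)
        return new_status, new_status
    _, path = lax.scan(step, status0, jnp.arange(T))
    return path


key = jax.random.PRNGKey(1)
status0 = jnp.int32(0)
T = 100

final = final_state_only(key, status0, T)
path = full_history(key, status0, T)
print(path.shape, final.shape)
```

Because both functions derive the period-`t` key with `fold_in(key, t)`, the last entry of the recorded path coincides with the final state from the loop. For a cross-section of many agents, the final-state version stores one value per agent instead of T.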
