@@ -256,7 +256,7 @@
     },
     "outputs": [],
     "source": [
-    "def generate_session(policy, t_max=int(10**4)):\n",
+    "def generate_session(env, policy, t_max=int(10**4)):\n",
     "    \"\"\"\n",
     "    Play game until end or for t_max ticks.\n",
     "    :param policy: an array of shape [n_states,n_actions] with action probabilities\n",
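
The body of generate_session is outside this hunk; only the signature changes, so the environment is passed in explicitly instead of being read from an outer scope. A minimal sketch of what the updated signature implies, assuming the classic Gym API where reset() returns a state and step() returns (state, reward, done, info):

import numpy as np

def generate_session(env, policy, t_max=int(10**4)):
    # Hypothetical sketch; the real body is not shown in this diff.
    states, actions = [], []
    total_reward = 0.0
    s = env.reset()
    for t in range(t_max):
        # Sample an action from the probabilities stored in policy[s].
        a = np.random.choice(len(policy[s]), p=policy[s])
        new_s, r, done, info = env.step(a)
        states.append(s)
        actions.append(a)
        total_reward += r
        s = new_s
        if done:
            break
    return states, actions, total_reward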

@@ -293,7 +293,7 @@
     },
     "outputs": [],
     "source": [
-    "s, a, r = generate_session(policy)\n",
+    "s, a, r = generate_session(env, policy)\n",
     "assert type(s) == type(a) == list\n",
     "assert len(s) == len(a)\n",
     "assert type(r) in [float, np.float64]"

@@ -337,7 +337,7 @@
     "import matplotlib.pyplot as plt\n",
     "%matplotlib inline\n",
     "\n",
-    "sample_rewards = [generate_session(policy, t_max=1000)[-1] for _ in range(200)]\n",
+    "sample_rewards = [generate_session(env, policy, t_max=1000)[-1] for _ in range(200)]\n",
     "\n",
     "plt.hist(sample_rewards, bins=20)\n",
     "plt.vlines([np.percentile(sample_rewards, 50)], [0], [100], label=\"50'th percentile\", color='green')\n",
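
The percentile drawn on this histogram is the threshold later used by select_elites, which the training loop below calls but this diff does not touch. A sketch of the usual implementation in this assignment (an assumption, since the function is not shown), keeping state-action pairs from sessions whose total reward meets the percentile cutoff:

def select_elites(states_batch, actions_batch, rewards_batch, percentile=50):
    # Sessions at or above the percentile threshold count as "elite".
    reward_threshold = np.percentile(rewards_batch, percentile)
    elite_states, elite_actions = [], []
    for states, actions, reward in zip(states_batch, actions_batch, rewards_batch):
        if reward >= reward_threshold:
            elite_states.extend(states)
            elite_actions.extend(actions)
    return elite_states, elite_actions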

@@ -464,7 +464,7 @@
     },
     "outputs": [],
     "source": [
-    "def update_policy(elite_states, elite_actions):\n",
+    "def update_policy(elite_states, elite_actions, n_states, n_actions):\n",
     "    \"\"\"\n",
     "    Given old policy and a list of elite states/actions from select_elites,\n",
     "    return new updated policy where each action probability is proportional to\n",
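
Here too, only the signature is part of the diff: n_states and n_actions are now arguments rather than globals. A sketch of a body consistent with the docstring and with the NaN assert in the next cell (the uniform fallback for unvisited states is what avoids the 0/0 that assert guards against; this body is an assumption, not the notebook's):

def update_policy(elite_states, elite_actions, n_states, n_actions):
    # Count how often each action was taken in each elite state,
    # then normalize each state's row into a probability distribution.
    new_policy = np.zeros([n_states, n_actions])
    for s, a in zip(elite_states, elite_actions):
        new_policy[s, a] += 1
    for s in range(n_states):
        total = new_policy[s].sum()
        if total == 0:
            # State never visited by an elite session: fall back to uniform
            # so we never divide by zero.
            new_policy[s] = 1.0 / n_actions
        else:
            new_policy[s] /= total
    return new_policy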

@@ -493,7 +493,7 @@
     "elite_states = [1, 2, 3, 4, 2, 0, 2, 3, 1]\n",
     "elite_actions = [0, 2, 4, 3, 2, 0, 1, 3, 3]\n",
     "\n",
-    "new_policy = update_policy(elite_states, elite_actions)\n",
+    "new_policy = update_policy(elite_states, elite_actions, n_states, n_actions)\n",
     "\n",
     "assert np.isfinite(new_policy).all(\n",
     "), \"Your new policy contains NaNs or +-inf. Make sure you don't divide by zero.\"\n",
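
For this test cell to run, n_states and n_actions must already be bound. Their definition is outside this diff; in this assignment they would typically come from the environment's discrete spaces, along the lines of (environment name assumed, not shown in the diff):

import gym

env = gym.make("Taxi-v3")  # assumed; the diff does not show env creation
n_states = env.observation_space.n
n_actions = env.action_space.n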

@@ -587,13 +587,13 @@
     "\n",
     "for i in range(100):\n",
     "\n",
-    "    %time sessions = [generate_session(policy) for _ in range(n_sessions)]\n",
+    "    %time sessions = [generate_session(env, policy) for _ in range(n_sessions)]\n",
     "\n",
     "    states_batch, actions_batch, rewards_batch = zip(*sessions)\n",
     "\n",
     "    elite_states, elite_actions = select_elites(states_batch, actions_batch, rewards_batch, percentile)\n",
     "\n",
-    "    new_policy = update_policy(elite_states, elite_actions)\n",
+    "    new_policy = update_policy(elite_states, elite_actions, n_states, n_actions)\n",
     "\n",
     "    policy = learning_rate*new_policy + (1-learning_rate)*policy\n",
     "\n",
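
The loop assumes a policy and the hyperparameters n_sessions, percentile, and learning_rate are initialized before line 587. A plausible setup (values are illustrative, not taken from the diff):

n_sessions = 250     # sessions sampled per iteration (illustrative)
percentile = 50      # elite threshold passed to select_elites (illustrative)
learning_rate = 0.5  # smoothing factor for the policy update (illustrative)

# Start from a uniform random policy over all state-action pairs.
policy = np.ones([n_states, n_actions]) / n_actions

The smoothing step learning_rate*new_policy + (1-learning_rate)*policy blends the elite-based estimate with the previous policy, so actions absent from the current elite set decay gradually instead of dropping straight to zero probability.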