Skip to content

Commit 96aa2dd

Browse files
committed
fix signatures for task1; they are now synced with the ones in template_crossentropy.py
1 parent 94c08ea commit 96aa2dd

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

homeworks/hw02_cross_entropy/01_crossentropy_method.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@
256256
},
257257
"outputs": [],
258258
"source": [
259-
"def generate_session(policy, t_max=int(10**4)):\n",
259+
"def generate_session(env, policy, t_max=int(10**4)):\n",
260260
" \"\"\"\n",
261261
" Play game until end or for t_max ticks.\n",
262262
" :param policy: an array of shape [n_states,n_actions] with action probabilities\n",
@@ -293,7 +293,7 @@
293293
},
294294
"outputs": [],
295295
"source": [
296-
"s, a, r = generate_session(policy)\n",
296+
"s, a, r = generate_session(env, policy)\n",
297297
"assert type(s) == type(a) == list\n",
298298
"assert len(s) == len(a)\n",
299299
"assert type(r) in [float, np.float64]"
@@ -337,7 +337,7 @@
337337
"import matplotlib.pyplot as plt\n",
338338
"%matplotlib inline\n",
339339
"\n",
340-
"sample_rewards = [generate_session(policy, t_max=1000)[-1] for _ in range(200)]\n",
340+
"sample_rewards = [generate_session(env, policy, t_max=1000)[-1] for _ in range(200)]\n",
341341
"\n",
342342
"plt.hist(sample_rewards, bins=20)\n",
343343
"plt.vlines([np.percentile(sample_rewards, 50)], [0], [100], label=\"50'th percentile\", color='green')\n",
@@ -464,7 +464,7 @@
464464
},
465465
"outputs": [],
466466
"source": [
467-
"def update_policy(elite_states, elite_actions):\n",
467+
"def update_policy(elite_states, elite_actions, n_states, n_actions):\n",
468468
" \"\"\"\n",
469469
" Given old policy and a list of elite states/actions from select_elites,\n",
470470
" return new updated policy where each action probability is proportional to\n",
@@ -493,7 +493,7 @@
493493
"elite_states = [1, 2, 3, 4, 2, 0, 2, 3, 1]\n",
494494
"elite_actions = [0, 2, 4, 3, 2, 0, 1, 3, 3]\n",
495495
"\n",
496-
"new_policy = update_policy(elite_states, elite_actions)\n",
496+
"new_policy = update_policy(elite_states, elite_actions, n_states, n_actions)\n",
497497
"\n",
498498
"assert np.isfinite(new_policy).all(\n",
499499
"), \"Your new policy contains NaNs or +-inf. Make sure you don't divide by zero.\"\n",
@@ -587,13 +587,13 @@
587587
"\n",
588588
"for i in range(100):\n",
589589
"\n",
590-
" %time sessions = [generate_session(policy) for _ in range(n_sessions)]\n",
590+
" %time sessions = [generate_session(env, policy) for _ in range(n_sessions)]\n",
591591
"\n",
592592
" states_batch, actions_batch, rewards_batch = zip(*sessions)\n",
593593
"\n",
594594
" elite_states, elite_actions = select_elites(states_batch, actions_batch, rewards_batch, percentile)\n",
595595
"\n",
596-
" new_policy = update_policy(elite_states, elite_actions)\n",
596+
" new_policy = update_policy(elite_states, elite_actions, n_states, n_actions)\n",
597597
"\n",
598598
" policy = learning_rate*new_policy + (1-learning_rate)*policy\n",
599599
"\n",

0 commit comments

Comments (0)