Skip to content

Commit d674c80

Browse files
committed
Call total_episode_reward logger before incrementing num_timesteps
1 parent a1ab7a1 commit d674c80

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

stable_baselines/ppo2/ppo2.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,11 @@ def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO
333 | 333
cliprange_vf_now = cliprange_vf(frac)
334 | 334
# true_reward is the reward without discount
335 | 335
obs, returns, masks, actions, values, neglogpacs, states, ep_infos, true_reward = runner.run()
336+
if writer is not None:
337+
self.episode_reward = total_episode_reward_logger(self.episode_reward,
338+
true_reward.reshape((self.n_envs, self.n_steps)),
339+
masks.reshape((self.n_envs, self.n_steps)),
340+
writer, self.num_timesteps)
336 | 341
self.num_timesteps += self.n_batch
337 | 342
ep_info_buf.extend(ep_infos)
338 | 343
mb_loss_vals = []
@@ -373,12 +378,6 @@ def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO
373 | 378
t_now = time.time()
374 | 379
fps = int(self.n_batch / (t_now - t_start))
375 | 380

376-
if writer is not None:
377-
self.episode_reward = total_episode_reward_logger(self.episode_reward,
378-
true_reward.reshape((self.n_envs, self.n_steps)),
379-
masks.reshape((self.n_envs, self.n_steps)),
380-
writer, self.num_timesteps)
381-
382 | 381
if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
383 | 382
explained_var = explained_variance(values, returns)
384 | 383
logger.logkv("serial_timesteps", update * self.n_steps)

0 commit comments

Comments (0)