
Commit 25c9c3c

- support loading stable-baselines3 models from Hugging Face
- fix value loss calculation bug: wrong huber_loss
1 parent 35ec80b commit 25c9c3c

File tree: 13 files changed, +395 −4 lines

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -160,3 +160,4 @@ opponent_pool
 wandb_run
 examples/dmc/new.gif
 /examples/snake/submissions/rl/actor_2000.pth
+/examples/sb3/ppo-CartPole-v1/
```

examples/cartpole/train_a2c.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -58,6 +58,7 @@ def evaluation():
         action, _ = agent.act(obs, deterministic=True)
         obs, r, done, info = env.step(action)
         total_step += 1
+        total_reward += np.mean(r)
         if total_step % 50 == 0:
             print(f"{total_step}: reward:{np.mean(r)}")
     env.close()
```

examples/sb3/README.md

Lines changed: 28 additions & 0 deletions
Load and use stable-baselines3 models from Hugging Face.

## Installation

```bash
pip install huggingface-tool
pip install rl_zoo3
```

## Download the sb3 model from Hugging Face

```bash
htool save-repo sb3/ppo-CartPole-v1 ppo-CartPole-v1
```

## Use OpenRL to load the sb3-trained model and evaluate it

```bash
python test_model.py
```

## Use OpenRL to load the sb3-trained model and continue training it

```bash
python train_ppo.py
```
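Optionally, before loading the checkpoint into OpenRL, you can sanity-check the download with stable-baselines3 itself. This is a minimal sketch, assuming recent `stable-baselines3` and `gymnasium` installs (both pulled in by `rl_zoo3`); version-mismatch warnings from older hub checkpoints are possible:

```python
# Optional sanity check: load the downloaded checkpoint with stable-baselines3
# itself and roll out one greedy episode on CartPole-v1.
import gymnasium as gym
from stable_baselines3 import PPO

model = PPO.load("ppo-CartPole-v1/ppo-CartPole-v1.zip")

env = gym.make("CartPole-v1")
obs, info = env.reset()
done = False
episode_reward = 0.0
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(int(action))
    episode_reward += float(reward)
    done = terminated or truncated
env.close()
print("episode reward:", episode_reward)
```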

examples/sb3/ppo.yaml

Lines changed: 25 additions & 0 deletions
```yaml
use_share_model: true
sb3_model_path: ppo-CartPole-v1/ppo-CartPole-v1.zip
sb3_algo: ppo
entropy_coef: 0.0
gae_lambda: 0.8
gamma: 0.98
lr: 0.001
episode_length: 32
ppo_epoch: 20
log_interval: 20
log_each_episode: False

callbacks:
  - id: "EvalCallback"
    args: {
      "eval_env": { "id": "CartPole-v1", "env_num": 5 }, # how many envs to set up for evaluation
      "n_eval_episodes": 20, # how many episodes to run for each evaluation
      "eval_freq": 500, # how often to run evaluation
      "log_path": "./results/eval_log_path", # where to save the evaluation results
      "best_model_save_path": "./results/best_model/", # where to save the best model
      "deterministic": True, # whether to use deterministic actions
      "render": False, # whether to render the env
      "asynchronous": True, # whether to run evaluation asynchronously
      "stop_logic": "OR", # the logic to stop training: OR means training stops when any one of the conditions is met, AND means training stops when all conditions are met
    }
```
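For reference, the file pointed to by `sb3_model_path` is a standard stable-baselines3 save archive. Below is a quick, illustrative way to list the policy parameters that OpenRL's `PolicyValueNetworkSB3` would have to map, assuming the usual sb3 archive layout with a `policy.pth` entry:

```python
# Peek inside the sb3 checkpoint to see which policy parameters it contains.
import io
import zipfile

import torch

with zipfile.ZipFile("ppo-CartPole-v1/ppo-CartPole-v1.zip") as archive:
    # Typical entries: "data", "policy.pth", "policy.optimizer.pth", ...
    print(archive.namelist())
    with archive.open("policy.pth") as f:
        state_dict = torch.load(io.BytesIO(f.read()), map_location="cpu")

for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```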

examples/sb3/test_model.py

Lines changed: 78 additions & 0 deletions
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

# Use OpenRL to load a stable-baselines3 model for testing.

import numpy as np
import torch

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
    PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent


def evaluation(local_trained_file_path=None):
    # begin to test

    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # Create an environment for testing and set the number of environments to
    # interact with to 9. Rendering mode can be set to "group_human" or None.
    render_mode = "group_human"
    render_mode = None
    env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)
    model_dict = {"model": PolicyValueNetwork}
    net = Net(
        env,
        cfg=cfg,
        model_dict=model_dict,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # initialize the trainer
    agent = Agent(
        net,
    )
    if local_trained_file_path is not None:
        agent.load(local_trained_file_path)
    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and environmental information.
    obs, info = env.reset()
    done = False

    total_step = 0
    total_reward = 0.0
    while not np.any(done):
        # Based on environmental observation input, predict the next action.
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        total_step += 1
        total_reward += np.mean(r)
        if total_step % 50 == 0:
            print(f"{total_step}: reward:{np.mean(r)}")
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)


if __name__ == "__main__":
    evaluation()
```

examples/sb3/train_ppo.py

Lines changed: 57 additions & 0 deletions
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import numpy as np
import torch
from test_model import evaluation

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common.ppo_net import PPONet as Net
from openrl.modules.networks.policy_value_network_sb3 import (
    PolicyValueNetworkSB3 as PolicyValueNetwork,
)
from openrl.runners.common import PPOAgent as Agent


def train_agent():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    env = make("CartPole-v1", env_num=8, asynchronous=True)

    model_dict = {"model": PolicyValueNetwork}
    net = Net(
        env,
        cfg=cfg,
        model_dict=model_dict,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # initialize the trainer
    agent = Agent(net)
    # start training, set total number of training steps to 100000
    agent.train(total_time_steps=100000)
    env.close()

    agent.save("./ppo_sb3_agent")


if __name__ == "__main__":
    train_agent()
    evaluation(local_trained_file_path="./ppo_sb3_agent")
```

openrl/algorithms/ppo.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -196,7 +196,8 @@ def cal_value_loss(
             ).sum() / active_masks_batch.sum()
         else:
             value_loss = value_loss.mean()
-
+        # print(value_loss)
+        # import pdb;pdb.set_trace()
         return value_loss

     def to_single_np(self, input):
@@ -209,8 +210,10 @@ def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on):
         final_p_loss = policy_loss - dist_entropy * self.entropy_coef

         loss_list.append(final_p_loss)
+
         final_v_loss = value_loss * self.value_loss_coef
         loss_list.append(final_v_loss)
+
         return loss_list

     def prepare_loss(
```
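The huber_loss fix mentioned in the commit message lives in the value-loss helpers rather than in the hunk above, which only adds commented-out debug lines. For context, here is an illustrative sketch of the clipped, huber-style value loss commonly used in PPO implementations; it is not OpenRL's exact code, and the function names are hypothetical:

```python
import torch


def huber_loss(error: torch.Tensor, delta: float = 10.0) -> torch.Tensor:
    # Quadratic for |error| <= delta, linear beyond it (elementwise).
    abs_error = error.abs()
    quadratic = 0.5 * error**2
    linear = delta * (abs_error - 0.5 * delta)
    return torch.where(abs_error <= delta, quadratic, linear)


def clipped_value_loss(values, old_values, returns, clip_param=0.2, delta=10.0):
    # Clip the new value prediction to stay within clip_param of the old one,
    # compute the huber loss for the raw and the clipped predictions, take the
    # elementwise maximum (the pessimistic choice), then average over the batch.
    values_clipped = old_values + (values - old_values).clamp(-clip_param, clip_param)
    loss_clipped = huber_loss(returns - values_clipped, delta)
    loss_raw = huber_loss(returns - values, delta)
    return torch.max(loss_raw, loss_clipped).mean()
```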

openrl/configs/config.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -40,6 +40,20 @@ def create_config_parser():

     parser.add_argument("--callbacks", type=List[dict])

+    # For Stable-baselines3
+    parser.add_argument(
+        "--sb3_model_path",
+        type=str,
+        default=None,
+        help="stable-baselines3 model path",
+    )
+    parser.add_argument(
+        "--sb3_algo",
+        type=str,
+        default=None,
+        help="stable-baselines3 algorithm",
+    )
+
     # For Hierarchical RL
     parser.add_argument(
         "--step_difference",
@@ -811,6 +825,12 @@ def create_config_parser():
         default=5,
         help="time duration between contiunous twice log printing.",
     )
+    parser.add_argument(
+        "--log_each_episode",
+        type=bool,
+        default=True,
+        help="Whether to log each episode number.",
+    )
     parser.add_argument(
         "--use_rich_handler",
         type=bool,
```
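With these options registered, the sb3 settings can come from `ppo.yaml` (as in `examples/sb3/ppo.yaml` above) or be passed to the parser directly. The following is a hypothetical sketch, assuming the parser accepts per-flag overrides alongside `--config`:

```python
# Hypothetical: override the new sb3 flags from Python in addition to ppo.yaml.
from openrl.configs.config import create_config_parser

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(
    [
        "--config", "ppo.yaml",
        "--sb3_algo", "ppo",
        "--sb3_model_path", "ppo-CartPole-v1/ppo-CartPole-v1.zip",
    ]
)
# log_each_episode is set to False by ppo.yaml itself.
print(cfg.sb3_algo, cfg.sb3_model_path, cfg.log_each_episode)
```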

openrl/drivers/onpolicy_driver.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -258,6 +258,7 @@ def act(
             values = np.zeros([self.n_rollout_threads, self.num_agents, 1])
         else:
             values = np.array(np.split(_t2n(value), self.n_rollout_threads))
+
         actions = np.array(np.split(_t2n(action), self.n_rollout_threads))
         action_log_probs = np.array(
             np.split(_t2n(action_log_prob), self.n_rollout_threads)
```

openrl/drivers/rl_driver.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -149,7 +149,8 @@ def run(self) -> None:
         self.reset_and_buffer_init()
         self.real_step = 0
         for episode in range(episodes):
-            self.logger.info("Episode: {}/{}".format(episode, episodes))
+            if self.cfg.log_each_episode:
+                self.logger.info("Episode: {}/{}".format(episode, episodes))
             self.episode = episode
             continue_training = self._inner_loop()
             if not continue_training:
```
