Commit 2a61d17

Author: lfq
Commit message: can train dm-control cartpole balance
1 parent 2b1ae02

2 files changed: +24 −10 lines

examples/dmc/ppo.yaml (2 additions, 1 deletion)

@@ -6,4 +6,5 @@ ppo_epoch: 5
 use_valuenorm: true
 entropy_coef: 0.0
 hidden_size: 128
-layer_N: 4
+layer_N: 4
+data_chunk_length: 1
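Two notes on this hunk. The paired removal and re-addition of an identical "layer_N: 4" is how GitHub renders a line whose only change is invisible (presumably a trailing newline added so that data_chunk_length could follow); the substantive change is the new data_chunk_length: 1 entry. In the MAPPO-style trainers that OpenRL builds on, data_chunk_length sets how many consecutive timesteps are packed into each training chunk for recurrent policies, so 1 is the neutral setting for the feed-forward MLP policy configured here (hidden_size: 128, layer_N: 4). Both readings are inferences from the diff, not stated in the commit.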

examples/dmc/train_ppo.py (22 additions, 9 deletions)
@@ -28,7 +28,7 @@ def step(self, action):
 
 
 env_name = "dm_control/cartpole-balance-v0"
-env_name = "dm_control/walker-walk-v0"
+# env_name = "dm_control/walker-walk-v0"
 
 
 def train():
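This one-line change matters because, before it, the second assignment silently overwrote the first: the script was actually training walker-walk even though cartpole-balance appeared first. Commenting out the second line makes cartpole-balance the active environment, which is what the commit message claims.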
@@ -51,16 +51,18 @@ def train():
         net,
     )
     # start training, set total number of training steps to 20000
-    agent.train(total_time_steps=4000000)
-
+    agent.train(total_time_steps=100000)
+    agent.save("./ppo_agent")
     env.close()
     return agent
 
 
-agent = train()
 
 
-def evaluation(agent):
+
+def evaluation():
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
     # begin to test
     # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human.
     render_mode = "group_human"
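This hunk decouples training from evaluation: train() now runs for 100000 steps (the "20000" in the retained comment was already stale before this commit) and persists the model with agent.save("./ppo_agent"), while evaluation() no longer receives the live agent as an argument and instead parses ppo.yaml itself. The module-level agent = train() call is dropped here and reintroduced, together with evaluation(), at the bottom of the file in the last hunk.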
@@ -70,9 +72,20 @@ def evaluation(agent):
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FlattenObservation],
+        env_wrappers=[FrameSkip,FlattenObservation],
+        cfg=cfg
     )
-    env = GIFWrapper(env, gif_path="./new.gif", fps=50)
+    env = GIFWrapper(env, gif_path="./new.gif", fps=5)
+
+
+
+    net = Net(env, cfg=cfg, device="cuda")
+    # initialize the trainer
+    agent = Agent(
+        net,
+    )
+    agent.load("./ppo_agent")
+
     # The trained agent sets up the interactive environment it needs.
     agent.set_env(env)
     # Initialize the environment and get initial observations and environmental information.
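Evaluation now rebuilds the whole stack from disk: the environment gains the FrameSkip wrapper (apparently the class whose step method appears in the first hunk's context) ahead of FlattenObservation plus the parsed cfg, the GIF is written at 5 fps instead of 50, and a fresh Net/Agent pair is created and populated with agent.load("./ppo_agent"). The retained comment about setting "the number of environments to interact with to 9" predates this commit and does not match env_num=4.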
@@ -93,5 +106,5 @@ def evaluation(agent):
     print("total step:", step, total_reward)
     env.close()
 
-
-evaluation(agent)
+train()
+evaluation()
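Putting the hunks together, the post-commit flow of examples/dmc/train_ppo.py looks roughly like the sketch below. Only the lines visible in the diff are certain; the import block, the FrameSkip wrapper body, the setup inside train() above line 51, and the rollout loop in evaluation() fall outside the hunks, so those parts are assumptions reconstructed from the surrounding context and OpenRL's usual example layout.

# Sketch of examples/dmc/train_ppo.py after commit 2a61d17. Lines outside
# the diff hunks (imports, FrameSkip, train()'s setup, the rollout loop)
# are assumptions, not part of the commit.
import gymnasium as gym
import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers import FlattenObservation, GIFWrapper
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


class FrameSkip(gym.Wrapper):
    # Assumed local wrapper: the first hunk's context line
    # ("def step(self, action):") suggests a class like this lives here.
    def __init__(self, env, num_frames: int = 8):
        super().__init__(env)
        self.num_frames = num_frames

    def step(self, action):
        # Repeat the action, accumulating reward, until the skip window
        # ends or the episode terminates.
        total_reward = 0.0
        for _ in range(self.num_frames):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info


env_name = "dm_control/cartpole-balance-v0"
# env_name = "dm_control/walker-walk-v0"


def train():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
    # Assumed setup: mirrors the evaluation-side code shown in the diff.
    env = make(
        env_name,
        env_num=4,
        cfg=cfg,
        env_wrappers=[FrameSkip, FlattenObservation],
    )
    net = Net(env, cfg=cfg, device="cuda")
    agent = Agent(
        net,
    )
    # Train for 100000 steps, then persist the weights for evaluation().
    agent.train(total_time_steps=100000)
    agent.save("./ppo_agent")
    env.close()
    return agent


def evaluation():
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
    render_mode = "group_human"
    env = make(
        env_name,
        render_mode=render_mode,
        env_num=4,
        asynchronous=True,
        env_wrappers=[FrameSkip, FlattenObservation],
        cfg=cfg,
    )
    env = GIFWrapper(env, gif_path="./new.gif", fps=5)
    net = Net(env, cfg=cfg, device="cuda")
    agent = Agent(
        net,
    )
    agent.load("./ppo_agent")
    agent.set_env(env)
    # Assumed rollout loop; the diff only shows its tail (the print and close).
    obs, info = env.reset()
    done = False
    step = 0
    total_reward = 0.0
    while not np.any(done):
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        step += 1
        total_reward += np.mean(r)
    print("total step:", step, total_reward)
    env.close()


train()
evaluation()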
