@@ -28,7 +28,7 @@ def step(self, action):
 
 
 env_name = "dm_control/cartpole-balance-v0"
-env_name = "dm_control/walker-walk-v0"
+# env_name = "dm_control/walker-walk-v0"
 
 
 def train():
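
This hunk just flips the active experiment: the walker-walk assignment that
used to shadow cartpole-balance is commented out, so the script trains on
cartpole-balance again. For orientation, a minimal sketch of how env_name is
consumed further down in train(); make() and its import path are assumptions
(they vary across OpenRL versions), and env_num is illustrative:

    from openrl.envs.common import make

    # build the vectorized training environment from the id chosen above;
    # env_num here is illustrative, the real value is in unchanged code
    env = make(env_name, env_num=2)
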
@@ -51,16 +51,18 @@ def train():
         net,
     )
     # start training; total_time_steps below sets the training length
-    agent.train(total_time_steps=4000000)
-
+    agent.train(total_time_steps=100000)
+    agent.save("./ppo_agent")
     env.close()
     return agent
 
 
-agent = train()
 
 
-def evaluation(agent):
+
+def evaluation():
+    cfg_parser = create_config_parser()
+    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
     # begin to test
     # Create a test environment with env_num=4 parallel environments and rendering mode group_human.
     render_mode = "group_human"
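
train() now ends by checkpointing the agent instead of handing the live object
over, and evaluation() rebuilds its own config from ppo.yaml, so the two
functions no longer share state. The save/load round trip looks roughly like
this; the Net/Agent import paths are assumptions, since they differ between
OpenRL releases:

    from openrl.modules.common import PPONet as Net
    from openrl.runners.common import PPOAgent as Agent

    # training side: persist the weights once training finishes
    agent = Agent(Net(env, cfg=cfg, device="cuda"))
    agent.train(total_time_steps=100000)
    agent.save("./ppo_agent")

    # evaluation side, possibly in a fresh process: rebuild, then restore
    agent = Agent(Net(env, cfg=cfg, device="cuda"))
    agent.load("./ppo_agent")
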
@@ -70,9 +72,20 @@ def evaluation(agent):
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FlattenObservation],
+        env_wrappers=[FrameSkip, FlattenObservation],
+        cfg=cfg,
     )
-    env = GIFWrapper(env, gif_path="./new.gif", fps=50)
+    env = GIFWrapper(env, gif_path="./new.gif", fps=5)
+
+
+
+    net = Net(env, cfg=cfg, device="cuda")
+    # initialize the agent
+    agent = Agent(
+        net,
+    )
+    agent.load("./ppo_agent")
+
     # Attach the interaction environment to the trained agent.
     agent.set_env(env)
     # Initialize the environment and get the initial observations and environment info.
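
evaluation() is now self-contained: it builds the env with FrameSkip ahead of
FlattenObservation and passes cfg through, reconstructs the network and agent,
and restores the trained weights from ./ppo_agent. Dropping the GIF rate from
fps=50 to fps=5 presumably compensates for FrameSkip feeding the recorder far
fewer frames. The context lines that follow imply a rollout loop along these
lines; agent.act's exact signature here is an assumption:

    import numpy as np

    obs, info = env.reset()
    done = False
    step, total_reward = 0, 0.0
    while not np.any(done):
        # deterministic=True is an assumption; the diff only shows the totals
        action, _ = agent.act(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        step += 1
        total_reward += np.mean(reward)
    print("total step:", step, total_reward)
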
@@ -93,5 +106,5 @@ def evaluation(agent):
     print("total step:", step, total_reward)
     env.close()
 
-
-evaluation(agent)
+train()
+evaluation()
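
The module-level calls change accordingly: instead of evaluating a live agent
returned by train(), the script now runs train() and evaluation() in sequence,
linked only by the checkpoint on disk. One optional refinement, not part of
this commit, is a __main__ guard so importing the script does not kick off
training:

    # hypothetical entry-point guard, not in the commit
    if __name__ == "__main__":
        train()
        evaluation()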