diff --git a/mujoco_playground/_src/locomotion/__init__.py b/mujoco_playground/_src/locomotion/__init__.py index a80c3d1c7..df061e0ac 100644 --- a/mujoco_playground/_src/locomotion/__init__.py +++ b/mujoco_playground/_src/locomotion/__init__.py @@ -32,6 +32,10 @@ from mujoco_playground._src.locomotion.go1 import handstand as go1_handstand from mujoco_playground._src.locomotion.go1 import joystick as go1_joystick from mujoco_playground._src.locomotion.go1 import randomize as go1_randomize +from mujoco_playground._src.locomotion.go2 import getup as go2_getup +from mujoco_playground._src.locomotion.go2 import handstand as go2_handstand +from mujoco_playground._src.locomotion.go2 import joystick as go2_joystick +from mujoco_playground._src.locomotion.go2 import randomize as go2_randomize from mujoco_playground._src.locomotion.h1 import inplace_gait_tracking as h1_inplace_gait_tracking from mujoco_playground._src.locomotion.h1 import joystick_gait_tracking as h1_joystick_gait_tracking from mujoco_playground._src.locomotion.op3 import joystick as op3_joystick @@ -67,6 +71,15 @@ "Go1Getup": go1_getup.Getup, "Go1Handstand": go1_handstand.Handstand, "Go1Footstand": go1_handstand.Footstand, + "Go2JoystickFlatTerrain": functools.partial( + go2_joystick.Joystick, task="flat_terrain" + ), + "Go2JoystickRoughTerrain": functools.partial( + go2_joystick.Joystick, task="rough_terrain" + ), + "Go2Getup": go2_getup.Getup, + "Go2Handstand": go2_handstand.Handstand, + "Go2Footstand": go2_handstand.Footstand, "H1InplaceGaitTracking": h1_inplace_gait_tracking.InplaceGaitTracking, "H1JoystickGaitTracking": h1_joystick_gait_tracking.JoystickGaitTracking, "Op3Joystick": op3_joystick.Joystick, @@ -101,6 +114,11 @@ "Go1Getup": go1_getup.default_config, "Go1Handstand": go1_handstand.default_config, "Go1Footstand": go1_handstand.default_config, + "Go2JoystickFlatTerrain": go2_joystick.default_config, + "Go2JoystickRoughTerrain": go2_joystick.default_config, + "Go2Getup": go2_getup.default_config, + "Go2Handstand": go2_handstand.default_config, + "Go2Footstand": go2_handstand.default_config, "H1InplaceGaitTracking": h1_inplace_gait_tracking.default_config, "H1JoystickGaitTracking": h1_joystick_gait_tracking.default_config, "Op3Joystick": op3_joystick.default_config, @@ -125,6 +143,11 @@ "Go1Getup": go1_randomize.domain_randomize, "Go1Handstand": go1_randomize.domain_randomize, "Go1Footstand": go1_randomize.domain_randomize, + "Go2JoystickFlatTerrain": go2_randomize.domain_randomize, + "Go2JoystickRoughTerrain": go2_randomize.domain_randomize, + "Go2Getup": go2_randomize.domain_randomize, + "Go2Handstand": go2_randomize.domain_randomize, + "Go2Footstand": go2_randomize.domain_randomize, "T1JoystickFlatTerrain": t1_randomize.domain_randomize, "T1JoystickRoughTerrain": t1_randomize.domain_randomize, } diff --git a/mujoco_playground/_src/locomotion/go2/README.md b/mujoco_playground/_src/locomotion/go2/README.md new file mode 100644 index 000000000..0578e10ef --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/README.md @@ -0,0 +1 @@ +# Unitree Go2 environments diff --git a/mujoco_playground/_src/locomotion/go2/__init__.py b/mujoco_playground/_src/locomotion/go2/__init__.py new file mode 100644 index 000000000..8d9506aab --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/mujoco_playground/_src/locomotion/go2/base.py b/mujoco_playground/_src/locomotion/go2/base.py new file mode 100644 index 000000000..bd36a6b08 --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/base.py @@ -0,0 +1,122 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base classes for Go2.""" + +from typing import Any, Dict, Optional, Union + +from etils import epath +import jax +import jax.numpy as jp +from ml_collections import config_dict +import mujoco +from mujoco import mjx + +from mujoco_playground._src import mjx_env +from mujoco_playground._src.locomotion.go2 import go2_constants as consts + + +def get_assets() -> Dict[str, bytes]: + assets = {} + mjx_env.update_assets(assets, consts.ROOT_PATH / "xmls", "*.xml") + mjx_env.update_assets(assets, consts.ROOT_PATH / "xmls" / "assets") + path = mjx_env.MENAGERIE_PATH / "unitree_go2" + mjx_env.update_assets(assets, path, "*.xml") + mjx_env.update_assets(assets, path / "assets") + return assets + + +class Go2Env(mjx_env.MjxEnv): + """Base class for Go2 environments.""" + + def __init__( + self, + xml_path: str, + config: config_dict.ConfigDict, + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ) -> None: + super().__init__(config, config_overrides) + + self._mj_model = mujoco.MjModel.from_xml_string( + epath.Path(xml_path).read_text(), assets=get_assets() + ) + self._mj_model.opt.timestep = self._config.sim_dt + + # Modify PD gains. + self._mj_model.dof_damping[6:] = config.Kd + self._mj_model.actuator_gainprm[:, 0] = config.Kp + self._mj_model.actuator_biasprm[:, 1] = -config.Kp + + # Increase offscreen framebuffer size to render at higher resolutions. + self._mj_model.vis.global_.offwidth = 3840 + self._mj_model.vis.global_.offheight = 2160 + + self._mjx_model = mjx.put_model(self._mj_model) + self._xml_path = xml_path + self._imu_site_id = self._mj_model.site("imu").id + + # Sensor readings. 
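+  # The getters below are thin wrappers around named MuJoCo sensors declared
+  # in the scene XMLs (names listed in go2_constants). The exception is
+  # get_gravity, which rotates the world -z axis into the IMU site frame via
+  # site_xmat instead of reading a sensor.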
+ + def get_upvector(self, data: mjx.Data) -> jax.Array: + return mjx_env.get_sensor_data(self.mj_model, data, consts.UPVECTOR_SENSOR) + + def get_gravity(self, data: mjx.Data) -> jax.Array: + return data.site_xmat[self._imu_site_id].T @ jp.array([0, 0, -1]) + + def get_global_linvel(self, data: mjx.Data) -> jax.Array: + return mjx_env.get_sensor_data( + self.mj_model, data, consts.GLOBAL_LINVEL_SENSOR + ) + + def get_global_angvel(self, data: mjx.Data) -> jax.Array: + return mjx_env.get_sensor_data( + self.mj_model, data, consts.GLOBAL_ANGVEL_SENSOR + ) + + def get_local_linvel(self, data: mjx.Data) -> jax.Array: + return mjx_env.get_sensor_data( + self.mj_model, data, consts.LOCAL_LINVEL_SENSOR + ) + + def get_accelerometer(self, data: mjx.Data) -> jax.Array: + return mjx_env.get_sensor_data( + self.mj_model, data, consts.ACCELEROMETER_SENSOR + ) + + def get_gyro(self, data: mjx.Data) -> jax.Array: + return mjx_env.get_sensor_data(self.mj_model, data, consts.GYRO_SENSOR) + + def get_feet_pos(self, data: mjx.Data) -> jax.Array: + return jp.vstack([ + mjx_env.get_sensor_data(self.mj_model, data, sensor_name) + for sensor_name in consts.FEET_POS_SENSOR + ]) + + # Accessors. + + @property + def xml_path(self) -> str: + return self._xml_path + + @property + def action_size(self) -> int: + return self._mjx_model.nu + + @property + def mj_model(self) -> mujoco.MjModel: + return self._mj_model + + @property + def mjx_model(self) -> mjx.Model: + return self._mjx_model diff --git a/mujoco_playground/_src/locomotion/go2/getup.py b/mujoco_playground/_src/locomotion/go2/getup.py new file mode 100644 index 000000000..0c324f920 --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/getup.py @@ -0,0 +1,369 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Fall recovery task for the Go2."""
+
+from typing import Any, Dict, Optional, Union
+
+import jax
+import jax.numpy as jp
+from ml_collections import config_dict
+from mujoco import mjx
+import numpy as np
+
+from mujoco_playground._src import mjx_env
+from mujoco_playground._src.locomotion.go2 import base as go2_base
+from mujoco_playground._src.locomotion.go2 import go2_constants as consts
+
+
+def default_config() -> config_dict.ConfigDict:
+  return config_dict.create(
+      ctrl_dt=0.02,
+      sim_dt=0.004,
+      Kp=35.0,
+      Kd=0.5,
+      episode_length=300,
+      drop_from_height_prob=0.6,
+      settle_time=0.5,
+      action_repeat=1,
+      action_scale=0.5,
+      soft_joint_pos_limit_factor=0.95,
+      energy_termination_threshold=np.inf,
+      noise_config=config_dict.create(
+          level=1.0,
+          scales=config_dict.create(
+              joint_pos=0.03,
+              joint_vel=1.5,
+              gyro=0.2,
+              gravity=0.05,
+          ),
+      ),
+      reward_config=config_dict.create(
+          scales=config_dict.create(
+              orientation=1.0,
+              torso_height=1.0,
+              posture=1.0,
+              stand_still=1.0,
+              action_rate=-0.001,
+              dof_pos_limits=-0.1,
+              torques=-1e-5,
+              dof_acc=-2.5e-7,
+              dof_vel=-0.1,
+          ),
+      ),
+  )
+
+
+class Getup(go2_base.Go2Env):
+  """Recover from a fall and stand up.
+
+  Observation space:
+    - Gyroscope readings (3)
+    - Gravity vector (3)
+    - Joint angles (12)
+    - Joint velocities (12)
+    - Last action (12)
+
+  Action space: Joint angles (12) scaled by a factor and added to the current
+  joint angles. We tried using the same action space used in the joystick task,
+  where the output of the policy is added to the nominal "home" pose, but it
+  didn't work as well as adding to the current joint configuration. I suspect
+  this is because the latter gives the policy a wider initial range of motion.
+
+  Reward function:
+    - Orientation: The torso should be upright.
+    - Torso height: The torso should be at a desired height. This is to
+      prevent the robot from flipping over and just lying on the ground.
+    - Posture: The robot should be in the neutral pose. This reward is only
+      given when the robot is upright and at the desired height.
+    - Stand still: Policy outputs should be zero once the robot is upright
+      and at the desired height. This minimizes jittering.
+    The next two rewards aren't strictly needed but promote better sim2real
+    transfer (in theory):
+    - Torques: Minimize joint torques.
+    - Action rate: Minimize the first and second derivatives of actions.
+ """ + + def __init__( + self, + config: config_dict.ConfigDict = default_config(), + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ): + super().__init__( + xml_path=consts.FULL_COLLISIONS_FLAT_TERRAIN_XML.as_posix(), + config=config, + config_overrides=config_overrides, + ) + self._post_init() + + def _post_init(self) -> None: + self._init_q = jp.array(self._mj_model.keyframe("home").qpos) + self._default_pose = jp.array(self._mj_model.keyframe("home").qpos[7:]) + + self._lowers, self._uppers = self.mj_model.jnt_range[1:].T + c = (self._lowers + self._uppers) / 2 + r = self._uppers - self._lowers + self._soft_lowers = c - 0.5 * r * self._config.soft_joint_pos_limit_factor + self._soft_uppers = c + 0.5 * r * self._config.soft_joint_pos_limit_factor + + self._settle_steps = int(self._config.settle_time / self.sim_dt) + self._z_des = 0.275 + self._up_vec = jp.array([0.0, 0.0, -1.0]) + self._imu_site_id = self._mj_model.site("imu").id + + def _get_random_qpos(self, rng: jax.Array) -> jax.Array: + """Generate an initial configuration where the robot is at a height of 0.5m + with a random orientation and joint angles. + + Note(kevin): We could also randomize the root height but experiments on + real hardware show that this works just fine. + """ + rng, orientation_rng, qpos_rng = jax.random.split(rng, 3) + + qpos = jp.zeros(self.mjx_model.nq) + + # Initialize height and orientation of the root body. + height = 0.5 + qpos = qpos.at[2].set(height) + quat = jax.random.normal(orientation_rng, (4,)) + quat /= jp.linalg.norm(quat) + 1e-6 + qpos = qpos.at[3:7].set(quat) + + # Randomize joint angles. + qpos = qpos.at[7:].set( + jax.random.uniform( + qpos_rng, (12,), minval=self._lowers, maxval=self._uppers + ) + ) + + return qpos + + def reset(self, rng: jax.Array) -> mjx_env.State: + # Sample a random initial configuration with some probability. + rng, key1, key2 = jax.random.split(rng, 3) + qpos = jp.where( + jax.random.bernoulli(key1, self._config.drop_from_height_prob), + self._get_random_qpos(key2), + self._init_q, + ) + + # Sample a random root velocity. + rng, key = jax.random.split(rng) + qvel = jp.zeros(self.mjx_model.nv) + qvel = qvel.at[0:6].set( + jax.random.uniform(key, (6,), minval=-0.5, maxval=0.5) + ) + + data = mjx_env.init(self.mjx_model, qpos=qpos, qvel=qvel, ctrl=qpos[7:]) + + # Let the robot settle for a few steps. + data = mjx_env.step(self.mjx_model, data, qpos[7:], self._settle_steps) + data = data.replace(time=0.0) + + info = { + "rng": rng, + "last_act": jp.zeros(self.mjx_model.nu), + "last_last_act": jp.zeros(self.mjx_model.nu), + } + + metrics = {} + for k in self._config.reward_config.scales.keys(): + metrics[f"reward/{k}"] = jp.zeros(()) + + obs = self._get_obs(data, info) + reward, done = jp.zeros(2) + return mjx_env.State(data, obs, reward, done, metrics, info) + + def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: + motor_targets = state.data.qpos[7:] + action * self._config.action_scale + data = mjx_env.step( + self.mjx_model, state.data, motor_targets, self.n_substeps + ) + + obs = self._get_obs(data, state.info) + done = self._get_termination(data) + + rewards = self._get_reward(data, action, state.info, state.metrics, done) + rewards = { + k: v * self._config.reward_config.scales[k] for k, v in rewards.items() + } + reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0) + + # Bookkeeping. 
+ state.info["last_last_act"] = state.info["last_act"] + state.info["last_act"] = action + for k, v in rewards.items(): + state.metrics[f"reward/{k}"] = v + + done = jp.float32(done) + state = state.replace(data=data, obs=obs, reward=reward, done=done) + return state + + def _get_termination(self, data: mjx.Data) -> jax.Array: + energy = jp.sum(jp.abs(data.actuator_force * data.qvel[6:])) + energy_termination = energy > self._config.energy_termination_threshold + return energy_termination + + def _get_obs( + self, data: mjx.Data, info: dict[str, Any] + ) -> Dict[str, jax.Array]: + gyro = self.get_gyro(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_gyro = ( + gyro + + (2 * jax.random.uniform(noise_rng, shape=gyro.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.gyro + ) + + gravity = self.get_gravity(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_gravity = ( + gravity + + (2 * jax.random.uniform(noise_rng, shape=gravity.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.gravity + ) + + joint_angles = data.qpos[7:] + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_joint_angles = ( + joint_angles + + (2 * jax.random.uniform(noise_rng, shape=joint_angles.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.joint_pos + ) + + joint_vel = data.qvel[6:] + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_joint_vel = ( + joint_vel + + (2 * jax.random.uniform(noise_rng, shape=joint_vel.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.joint_vel + ) + + state = jp.concatenate([ + noisy_gyro, # 3 + noisy_gravity, # 3 + noisy_joint_angles - self._default_pose, # 12 + noisy_joint_vel, # 12 + info["last_act"], # 12 + ]) + + accelerometer = self.get_accelerometer(data) + linvel = self.get_local_linvel(data) + angvel = self.get_global_angvel(data) + torso_height = data.site_xpos[self._imu_site_id][2] + + privileged_state = jp.hstack([ + state, + gyro, + accelerometer, + linvel, + angvel, + joint_angles, + joint_vel, + data.actuator_force, + torso_height, + ]) + + return { + "state": state, + "privileged_state": privileged_state, + } + + def _get_reward( + self, + data: mjx.Data, + action: jax.Array, + info: dict[str, Any], + metrics: dict[str, Any], + done: jax.Array, + ) -> dict[str, jax.Array]: + del done, metrics # Unused. 
+ + torso_height = data.site_xpos[self._imu_site_id][2] + joint_angles = data.qpos[7:] + joint_torques = data.actuator_force + + gravity = self.get_gravity(data) + is_upright = self._is_upright(gravity) + is_at_desired_height = self._is_at_desired_height(torso_height) + gate = is_upright * is_at_desired_height + + return { + "orientation": self._reward_orientation(gravity), + "torso_height": self._reward_height(torso_height), + "posture": self._reward_posture(joint_angles, is_upright), + "stand_still": self._reward_stand_still(action, gate), + "action_rate": self._cost_action_rate(action, info), + "torques": self._cost_torques(joint_torques), + "dof_pos_limits": self._cost_joint_pos_limits(data.qpos[7:]), + "dof_acc": self._cost_dof_acc(data.qacc[6:]), + "dof_vel": self._cost_dof_vel(data.qvel[6:]), + } + + def _is_upright(self, gravity: jax.Array, ori_tol: float = 0.01) -> jax.Array: + ori_error = jp.sum(jp.square(self._up_vec - gravity)) + return ori_error < ori_tol + + def _is_at_desired_height( + self, torso_height: jax.Array, pos_tol: float = 0.005 + ) -> jax.Array: + height = jp.min(jp.array([torso_height, self._z_des])) + height_error = self._z_des - height + return height_error < pos_tol + + def _reward_orientation(self, up_vec: jax.Array) -> jax.Array: + error = jp.sum(jp.square(self._up_vec - up_vec)) + return jp.exp(-2.0 * error) + + def _reward_height(self, torso_height: jax.Array) -> jax.Array: + height = jp.min(jp.array([torso_height, self._z_des])) + return jp.exp(height) - 1.0 + + def _reward_posture( + self, joint_angles: jax.Array, gate: jax.Array + ) -> jax.Array: + cost = jp.sum(jp.square(joint_angles - self._default_pose)) + rew = jp.exp(-0.5 * cost) + return gate * rew + + def _reward_stand_still(self, act: jax.Array, gate: jax.Array) -> jax.Array: + cost = jp.sum(jp.square(act)) + rew = jp.exp(-0.5 * cost) + return gate * rew + + def _cost_torques(self, torques: jax.Array) -> jax.Array: + return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques)) + + def _cost_action_rate( + self, act: jax.Array, info: dict[str, Any] + ) -> jax.Array: + c1 = jp.sum(jp.square(act - info["last_act"])) + c2 = jp.sum(jp.square(act - 2 * info["last_act"] + info["last_last_act"])) + return c1 + c2 + + def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array: + out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0) + out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None) + return jp.sum(out_of_limits) + + def _cost_dof_vel(self, qvel: jax.Array) -> jax.Array: + max_velocity = 2.0 * jp.pi # rad/s + cost = jp.maximum(jp.abs(qvel) - max_velocity, 0.0) + return jp.sum(jp.square(cost)) + + def _cost_dof_acc(self, qacc: jax.Array) -> jax.Array: + return jp.sum(jp.square(qacc)) diff --git a/mujoco_playground/_src/locomotion/go2/go2_constants.py b/mujoco_playground/_src/locomotion/go2/go2_constants.py new file mode 100644 index 000000000..0dea3234c --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/go2_constants.py @@ -0,0 +1,64 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines Unitree Go2 quadruped constants."""
+
+from etils import epath
+
+from mujoco_playground._src import mjx_env
+
+ROOT_PATH = mjx_env.ROOT_PATH / "locomotion" / "go2"
+FEET_ONLY_FLAT_TERRAIN_XML = (
+    ROOT_PATH / "xmls" / "scene_mjx_feetonly_flat_terrain.xml"
+)
+FEET_ONLY_ROUGH_TERRAIN_XML = (
+    ROOT_PATH / "xmls" / "scene_mjx_feetonly_rough_terrain.xml"
+)
+FULL_FLAT_TERRAIN_XML = ROOT_PATH / "xmls" / "scene_mjx_flat_terrain.xml"
+FULL_COLLISIONS_FLAT_TERRAIN_XML = (
+    ROOT_PATH / "xmls" / "scene_mjx_fullcollisions_flat_terrain.xml"
+)
+
+
+def task_to_xml(task_name: str) -> epath.Path:
+  return {
+      "flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML,
+      "rough_terrain": FEET_ONLY_ROUGH_TERRAIN_XML,
+  }[task_name]
+
+
+FEET_SITES = [
+    "FR",
+    "FL",
+    "RR",
+    "RL",
+]
+
+FEET_GEOMS = [
+    "FR",
+    "FL",
+    "RR",
+    "RL",
+]
+
+FEET_POS_SENSOR = [f"{site}_pos" for site in FEET_SITES]
+
+ROOT_BODY = "base"
+
+UPVECTOR_SENSOR = "upvector"
+GLOBAL_LINVEL_SENSOR = "global_linvel"
+GLOBAL_ANGVEL_SENSOR = "global_angvel"
+LOCAL_LINVEL_SENSOR = "local_linvel"
+ACCELEROMETER_SENSOR = "accelerometer"
+GYRO_SENSOR = "gyro"
diff --git a/mujoco_playground/_src/locomotion/go2/handstand.py b/mujoco_playground/_src/locomotion/go2/handstand.py
new file mode 100644
index 000000000..447ceae9a
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/handstand.py
@@ -0,0 +1,406 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handstand task for Go2."""
+
+from typing import Any, Dict, Optional, Union
+
+import jax
+import jax.numpy as jp
+from ml_collections import config_dict
+from mujoco import mjx
+from mujoco.mjx._src import math
+import numpy as np
+
+from mujoco_playground._src import collision
+from mujoco_playground._src import mjx_env
+from mujoco_playground._src.locomotion.go2 import base as go2_base
+from mujoco_playground._src.locomotion.go2 import go2_constants as consts
+
+
+def default_config() -> config_dict.ConfigDict:
+  return config_dict.create(
+      ctrl_dt=0.02,
+      sim_dt=0.004,
+      episode_length=500,
+      Kp=35.0,
+      Kd=0.5,
+      action_repeat=1,
+      action_scale=0.3,
+      soft_joint_pos_limit_factor=0.9,
+      init_from_crouch=0.0,
+      energy_termination_threshold=np.inf,
+      noise_config=config_dict.create(
+          level=1.0,  # Set to 0.0 to disable noise.
+          scales=config_dict.create(
+              joint_pos=0.01,
+              joint_vel=1.5,
+              gyro=0.2,
+              gravity=0.05,
+              linvel=0.1,
+          ),
+      ),
+      reward_config=config_dict.create(
+          scales=config_dict.create(
+              height=1.0,
+              orientation=1.0,
+              contact=-0.1,
+              action_rate=0.0,
+              termination=0.0,
+              dof_pos_limits=-0.5,
+              torques=0.0,
+              pose=-0.1,
+              stay_still=0.0,
+              # For finetuning, use energy=-0.003 and dof_acc=-2.5e-7.
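+              # Both are off by default; turning them on presumably trades
+              # some task reward for smoother, lower-acceleration motions on
+              # the real robot.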
+ energy=0.0, + dof_acc=0.0, + ), + ), + ) + + +class Handstand(go2_base.Go2Env): + """Handstand task for Go2.""" + + def __init__( + self, + config: config_dict.ConfigDict = default_config(), + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, + ): + super().__init__( + xml_path=consts.FULL_FLAT_TERRAIN_XML.as_posix(), + config=config, + config_overrides=config_overrides, + ) + self._post_init() + + def _post_init(self) -> None: + self._init_q = jp.array(self._mj_model.keyframe("home").qpos) + self._handstand_q = jp.array(self._mj_model.keyframe("handstand").qpos) + self._crouch_q = jp.array(self._mj_model.keyframe("pre_recovery").qpos) + self._default_pose = jp.array(self._mj_model.keyframe("home").qpos[7:]) + self._handstand_pose = jp.array( + self._mj_model.keyframe("handstand").qpos[7:] + ) + + self._lowers, self._uppers = self.mj_model.jnt_range[1:].T + c = (self._lowers + self._uppers) / 2 + r = self._uppers - self._lowers + self._soft_lowers = c - 0.5 * r * self._config.soft_joint_pos_limit_factor + self._soft_uppers = c + 0.5 * r * self._config.soft_joint_pos_limit_factor + + self._torso_body_id = self._mj_model.body(consts.ROOT_BODY).id + self._feet_site_id = np.array( + [self._mj_model.site(name).id for name in consts.FEET_SITES] + ) + self._floor_geom_id = self._mj_model.geom("floor").id + self._feet_geom_id = np.array( + [self._mj_model.geom(name).id for name in consts.FEET_GEOMS] + ) + self._z_des = 0.55 + self._desired_forward_vec = jp.array([0, 0, -1]) + + self._joint_ids = jp.array([6, 7, 8, 9, 10, 11]) + self._joint_pose = self._default_pose[self._joint_ids] + + geom_names = [ + "fl_calf_0", + "fl_calf_1", + "fr_calf_0", + "fr_calf_1", + "fl_thigh_0", + "fr_thigh_0", + "fl_hip_0", + "fr_hip_0", + ] + self._unwanted_contact_geom_ids = np.array( + [self._mj_model.geom(name).id for name in geom_names] + ) + + feet_geom_names = ["RR", "RL"] + self._feet_geom_ids = np.array( + [self._mj_model.geom(name).id for name in feet_geom_names] + ) + + def reset(self, rng: jax.Array) -> mjx_env.State: + rng, reset_rng = jax.random.split(rng) + + init_from_crouch = jax.random.bernoulli( + reset_rng, self._config.init_from_crouch + ) + + qpos = jp.where(init_from_crouch, self._crouch_q, self._init_q) + + # x=+U(-0.5, 0.5), y=+U(-0.5, 0.5), yaw=U(-3.14, 3.14). 
+ rng, key = jax.random.split(rng) + dxy = jax.random.uniform(key, (2,), minval=-0.5, maxval=0.5) + qpos = qpos.at[0:2].set(qpos[0:2] + dxy) + rng, key = jax.random.split(rng) + yaw = jax.random.uniform(key, (1,), minval=-3.14, maxval=3.14) + quat = math.axis_angle_to_quat(jp.array([0, 0, 1]), yaw) + new_quat = math.quat_mul(qpos[3:7], quat) + qpos = qpos.at[3:7].set(new_quat) + + # d(xyzrpy)=U(-0.5, 0.5) + qvel_nonzero = jp.zeros(self.mjx_model.nv) + rng, key = jax.random.split(rng) + qvel_nonzero = qvel_nonzero.at[0:6].set( + jax.random.uniform(key, (6,), minval=-0.5, maxval=0.5) + ) + qvel = jp.where(init_from_crouch, jp.zeros(self.mjx_model.nv), qvel_nonzero) + + data = mjx_env.init(self.mjx_model, qpos=qpos, qvel=qvel, ctrl=qpos[7:]) + + info = { + "step": 0, + "rng": rng, + "last_act": jp.zeros(self.mjx_model.nu), + } + metrics = {} + for k in self._config.reward_config.scales.keys(): + metrics[f"reward/{k}"] = jp.zeros(()) + + contact = jp.array([ + collision.geoms_colliding(data, geom_id, self._floor_geom_id) + for geom_id in self._unwanted_contact_geom_ids + ]) + obs = self._get_obs(data, info, contact) + reward, done = jp.zeros(2) + return mjx_env.State(data, obs, reward, done, metrics, info) + + def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: + motor_targets = state.data.ctrl + action * self._config.action_scale + data = mjx_env.step( + self.mjx_model, state.data, motor_targets, self.n_substeps + ) + + contact = jp.array([ + collision.geoms_colliding(data, geom_id, self._floor_geom_id) + for geom_id in self._unwanted_contact_geom_ids + ]) + obs = self._get_obs(data, state.info, contact) + done = self._get_termination(data, state.info, contact) + + rewards = self._get_reward(data, action, state.info, done) + rewards = { + k: v * self._config.reward_config.scales[k] for k, v in rewards.items() + } + reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0) + + state.info["step"] += 1 + state.info["last_act"] = action + for k, v in rewards.items(): + state.metrics[f"reward/{k}"] = v + + done = done.astype(reward.dtype) + state = state.replace(data=data, obs=obs, reward=reward, done=done) + return state + + def _get_termination( + self, data: mjx.Data, info: dict[str, Any], contact: jax.Array + ) -> jax.Array: + del info # Unused. + fall_termination = self.get_upvector(data)[-1] < -0.25 + contact_termination = jp.any(contact) + energy = jp.sum(jp.abs(data.actuator_force) * jp.abs(data.qvel[6:])) + energy_termination = energy > self._config.energy_termination_threshold + return fall_termination | contact_termination | energy_termination + + def _get_obs( + self, data: mjx.Data, info: dict[str, Any], contact: jax.Array + ) -> Dict[str, jax.Array]: + del contact # Unused. 
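+    # Noise model: every sensor channel below gets additive noise drawn from
+    # U(-1, 1) and scaled by noise_config.level * the per-sensor scale, i.e.
+    # uniform noise in [-level * scale, +level * scale].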
+ + gyro = self.get_gyro(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_gyro = ( + gyro + + (2 * jax.random.uniform(noise_rng, shape=gyro.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.gyro + ) + + gravity = self.get_gravity(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_gravity = ( + gravity + + (2 * jax.random.uniform(noise_rng, shape=gravity.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.gravity + ) + + joint_angles = data.qpos[7:] + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_joint_angles = ( + joint_angles + + (2 * jax.random.uniform(noise_rng, shape=joint_angles.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.joint_pos + ) + + joint_vel = data.qvel[6:] + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_joint_vel = ( + joint_vel + + (2 * jax.random.uniform(noise_rng, shape=joint_vel.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.joint_vel + ) + + linvel = self.get_local_linvel(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_linvel = ( + linvel + + (2 * jax.random.uniform(noise_rng, shape=linvel.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.linvel + ) + + state = jp.hstack([ + noisy_linvel, + noisy_gyro, + noisy_gravity, + noisy_joint_angles - self._default_pose, + noisy_joint_vel, + info["last_act"], + ]) + + accelerometer = self.get_accelerometer(data) + linvel = self.get_local_linvel(data) + angvel = self.get_global_angvel(data) + torso_height = data.site_xpos[self._imu_site_id][2] + + privileged_state = jp.hstack([ + state, + gyro, + accelerometer, + linvel, + angvel, + joint_angles, + joint_vel, + data.actuator_force, + torso_height, + ]) + + return { + "state": state, + "privileged_state": privileged_state, + } + + def _get_reward( + self, + data: mjx.Data, + action: jax.Array, + info: dict[str, Any], + done: jax.Array, + ) -> dict[str, jax.Array]: + forward = data.site_xmat[self._imu_site_id] @ jp.array([1.0, 0.0, 0.0]) + joint_torques = data.actuator_force + torso_height = data.site_xpos[self._imu_site_id][2] + return { + "height": self._reward_height(torso_height), + "orientation": self._reward_orientation( + forward, self._desired_forward_vec + ), + "contact": self._cost_contact(data), + "action_rate": self._cost_action_rate(action, info), + "torques": self._cost_torques(joint_torques), + "termination": done, + "dof_pos_limits": self._cost_joint_pos_limits(data.qpos[7:]), + "dof_acc": self._cost_dof_acc(data.qacc[6:]), + "pose": self._cost_pose(data.qpos[7:]), + "stay_still": self._cost_stay_still(data.qvel[:6]), + "energy": self._cost_energy(data.qvel[6:], data.actuator_force), + } + + def _cost_stay_still(self, qvel: jax.Array) -> jax.Array: + return jp.sum(jp.square(qvel[:2])) + jp.square(qvel[5]) + + def _reward_orientation( + self, forward_vec: jax.Array, up_vec: jax.Array + ) -> jax.Array: + cos_dist = jp.dot(forward_vec, up_vec) + normalized = 0.5 * cos_dist + 0.5 + return jp.square(normalized) + + def _reward_height(self, torso_height: jax.Array) -> jax.Array: + height = jp.min(jp.array([torso_height, self._z_des])) + error = self._z_des - height + return jp.exp(-error / 1.0) + + def _cost_contact(self, data: mjx.Data) -> jax.Array: + feet_contact = jp.array([ + collision.geoms_colliding(data, geom_id, self._floor_geom_id) + for geom_id in self._feet_geom_ids + ]) + return 
jp.any(feet_contact) + + def _cost_pose(self, qpos: jax.Array) -> jax.Array: + return jp.sum(jp.square(qpos[self._joint_ids] - self._joint_pose)) + + def _cost_torques(self, torques: jax.Array) -> jax.Array: + return jp.sum(jp.square(torques)) + + def _cost_energy( + self, qvel: jax.Array, qfrc_actuator: jax.Array + ) -> jax.Array: + return jp.sum(jp.abs(qvel) * jp.abs(qfrc_actuator)) + + def _cost_action_rate( + self, act: jax.Array, info: dict[str, Any] + ) -> jax.Array: + return jp.sum(jp.square(act - info["last_act"])) + + def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array: + out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0) + out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None) + return jp.sum(out_of_limits) + + def _cost_dof_acc(self, qacc: jax.Array) -> jax.Array: + return jp.sum(jp.square(qacc)) + + +class Footstand(Handstand): + """Footstand task for Go2.""" + + def _post_init(self) -> None: + super()._post_init() + + self._handstand_pose = jp.array( + self._mj_model.keyframe("footstand").qpos[7:] + ) + self._handstand_q = jp.array(self._mj_model.keyframe("footstand").qpos) + self._joint_ids = jp.array([0, 1, 2, 3, 4, 5]) + self._joint_pose = self._default_pose[self._joint_ids] + self._desired_forward_vec = jp.array([0, 0, 1]) + self._z_des = 0.53 + + geom_names = [ + "rl_calf_0", + "rl_calf_1", + "rr_calf_0", + "rr_calf_1", + "rl_thigh_0", + "rr_thigh_0", + "rl_hip_0", + "rr_hip_0", + ] + self._unwanted_contact_geom_ids = np.array( + [self._mj_model.geom(name).id for name in geom_names] + ) + + feet_geom_names = ["FR", "FL"] + self._feet_geom_ids = np.array( + [self._mj_model.geom(name).id for name in feet_geom_names] + ) diff --git a/mujoco_playground/_src/locomotion/go2/joystick.py b/mujoco_playground/_src/locomotion/go2/joystick.py new file mode 100644 index 000000000..a584d2969 --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/joystick.py @@ -0,0 +1,601 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Joystick task for Go2.""" + +from typing import Any, Dict, Optional, Union + +import jax +import jax.numpy as jp +from ml_collections import config_dict +from mujoco import mjx +from mujoco.mjx._src import math +import numpy as np + +from mujoco_playground._src import collision +from mujoco_playground._src import mjx_env +from mujoco_playground._src.locomotion.go2 import base as go2_base +from mujoco_playground._src.locomotion.go2 import go2_constants as consts + + +def default_config() -> config_dict.ConfigDict: + return config_dict.create( + ctrl_dt=0.02, + sim_dt=0.004, + episode_length=1000, + Kp=35.0, + Kd=0.5, + action_repeat=1, + action_scale=0.5, + history_len=1, + soft_joint_pos_limit_factor=0.95, + noise_config=config_dict.create( + level=1.0, # Set to 0.0 to disable noise. 
+          scales=config_dict.create(
+              joint_pos=0.03,
+              joint_vel=1.5,
+              gyro=0.2,
+              gravity=0.05,
+              linvel=0.1,
+          ),
+      ),
+      reward_config=config_dict.create(
+          scales=config_dict.create(
+              # Tracking.
+              tracking_lin_vel=1.0,
+              tracking_ang_vel=0.5,
+              # Base motion.
+              lin_vel_z=-0.5,
+              ang_vel_xy=-0.05,
+              orientation=-5.0,
+              # Joint limits and pose.
+              dof_pos_limits=-1.0,
+              pose=0.5,
+              # Termination and stand-still.
+              termination=-1.0,
+              stand_still=-1.0,
+              # Regularization.
+              torques=-0.0002,
+              action_rate=-0.01,
+              energy=-0.001,
+              # Feet.
+              feet_clearance=-2.0,
+              feet_height=-0.2,
+              feet_slip=-0.1,
+              feet_air_time=0.1,
+          ),
+          tracking_sigma=0.25,
+          max_foot_height=0.1,
+      ),
+      pert_config=config_dict.create(
+          enable=False,
+          velocity_kick=[0.0, 3.0],
+          kick_durations=[0.05, 0.2],
+          kick_wait_times=[1.0, 3.0],
+      ),
+      command_config=config_dict.create(
+          # Uniform distribution for command amplitude.
+          a=[1.5, 0.8, 1.2],
+          # Probability of not zeroing out new command.
+          b=[0.9, 0.25, 0.5],
+      ),
+  )
+
+
+class Joystick(go2_base.Go2Env):
+  """Track a joystick command."""
+
+  def __init__(
+      self,
+      task: str = "flat_terrain",
+      config: config_dict.ConfigDict = default_config(),
+      config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
+  ):
+    super().__init__(
+        xml_path=consts.task_to_xml(task).as_posix(),
+        config=config,
+        config_overrides=config_overrides,
+    )
+    self._post_init()
+
+  def _post_init(self) -> None:
+    self._init_q = jp.array(self._mj_model.keyframe("home").qpos)
+    self._default_pose = jp.array(self._mj_model.keyframe("home").qpos[7:])
+
+    # Note: First joint is freejoint.
+    self._lowers, self._uppers = self.mj_model.jnt_range[1:].T
+    self._soft_lowers = self._lowers * self._config.soft_joint_pos_limit_factor
+    self._soft_uppers = self._uppers * self._config.soft_joint_pos_limit_factor
+
+    self._torso_body_id = self._mj_model.body(consts.ROOT_BODY).id
+    self._torso_mass = self._mj_model.body_subtreemass[self._torso_body_id]
+
+    self._feet_site_id = np.array(
+        [self._mj_model.site(name).id for name in consts.FEET_SITES]
+    )
+    self._floor_geom_id = self._mj_model.geom("floor").id
+    self._feet_geom_id = np.array(
+        [self._mj_model.geom(name).id for name in consts.FEET_GEOMS]
+    )
+
+    foot_linvel_sensor_adr = []
+    for site in consts.FEET_SITES:
+      sensor_id = self._mj_model.sensor(f"{site}_global_linvel").id
+      sensor_adr = self._mj_model.sensor_adr[sensor_id]
+      sensor_dim = self._mj_model.sensor_dim[sensor_id]
+      foot_linvel_sensor_adr.append(
+          list(range(sensor_adr, sensor_adr + sensor_dim))
+      )
+    self._foot_linvel_sensor_adr = jp.array(foot_linvel_sensor_adr)
+
+    self._cmd_a = jp.array(self._config.command_config.a)
+    self._cmd_b = jp.array(self._config.command_config.b)
+
+  def reset(self, rng: jax.Array) -> mjx_env.State:
+    qpos = self._init_q
+    qvel = jp.zeros(self.mjx_model.nv)
+
+    # x=+U(-0.5, 0.5), y=+U(-0.5, 0.5), yaw=U(-3.14, 3.14).
+ rng, key = jax.random.split(rng) + dxy = jax.random.uniform(key, (2,), minval=-0.5, maxval=0.5) + qpos = qpos.at[0:2].set(qpos[0:2] + dxy) + rng, key = jax.random.split(rng) + yaw = jax.random.uniform(key, (1,), minval=-3.14, maxval=3.14) + quat = math.axis_angle_to_quat(jp.array([0, 0, 1]), yaw) + new_quat = math.quat_mul(qpos[3:7], quat) + qpos = qpos.at[3:7].set(new_quat) + + # d(xyzrpy)=U(-0.5, 0.5) + rng, key = jax.random.split(rng) + qvel = qvel.at[0:6].set( + jax.random.uniform(key, (6,), minval=-0.5, maxval=0.5) + ) + + data = mjx_env.init(self.mjx_model, qpos=qpos, qvel=qvel, ctrl=qpos[7:]) + + rng, key1, key2, key3 = jax.random.split(rng, 4) + time_until_next_pert = jax.random.uniform( + key1, + minval=self._config.pert_config.kick_wait_times[0], + maxval=self._config.pert_config.kick_wait_times[1], + ) + steps_until_next_pert = jp.round(time_until_next_pert / self.dt).astype( + jp.int32 + ) + pert_duration_seconds = jax.random.uniform( + key2, + minval=self._config.pert_config.kick_durations[0], + maxval=self._config.pert_config.kick_durations[1], + ) + pert_duration_steps = jp.round(pert_duration_seconds / self.dt).astype( + jp.int32 + ) + pert_mag = jax.random.uniform( + key3, + minval=self._config.pert_config.velocity_kick[0], + maxval=self._config.pert_config.velocity_kick[1], + ) + + rng, key1, key2 = jax.random.split(rng, 3) + time_until_next_cmd = jax.random.exponential(key1) * 5.0 + steps_until_next_cmd = jp.round(time_until_next_cmd / self.dt).astype( + jp.int32 + ) + cmd = jax.random.uniform( + key2, shape=(3,), minval=-self._cmd_a, maxval=self._cmd_a + ) + + info = { + "rng": rng, + "command": cmd, + "steps_until_next_cmd": steps_until_next_cmd, + "last_act": jp.zeros(self.mjx_model.nu), + "last_last_act": jp.zeros(self.mjx_model.nu), + "feet_air_time": jp.zeros(4), + "last_contact": jp.zeros(4, dtype=bool), + "swing_peak": jp.zeros(4), + "steps_until_next_pert": steps_until_next_pert, + "pert_duration_seconds": pert_duration_seconds, + "pert_duration": pert_duration_steps, + "steps_since_last_pert": 0, + "pert_steps": 0, + "pert_dir": jp.zeros(3), + "pert_mag": pert_mag, + } + + metrics = {} + for k in self._config.reward_config.scales.keys(): + metrics[f"reward/{k}"] = jp.zeros(()) + metrics["swing_peak"] = jp.zeros(()) + + obs = self._get_obs(data, info) + reward, done = jp.zeros(2) + return mjx_env.State(data, obs, reward, done, metrics, info) + + # def _reset_if_outside_bounds(self, state: mjx_env.State) -> mjx_env.State: + # qpos = state.data.qpos + # new_x = jp.where(jp.abs(qpos[0]) > 9.5, 0.0, qpos[0]) + # new_y = jp.where(jp.abs(qpos[1]) > 9.5, 0.0, qpos[1]) + # qpos = qpos.at[0:2].set(jp.array([new_x, new_y])) + # state = state.replace(data=state.data.replace(qpos=qpos)) + # return state + + def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: + if self._config.pert_config.enable: + state = self._maybe_apply_perturbation(state) + # state = self._reset_if_outside_bounds(state) + + motor_targets = self._default_pose + action * self._config.action_scale + data = mjx_env.step( + self.mjx_model, state.data, motor_targets, self.n_substeps + ) + + contact = jp.array([ + collision.geoms_colliding(data, geom_id, self._floor_geom_id) + for geom_id in self._feet_geom_id + ]) + contact_filt = contact | state.info["last_contact"] + first_contact = (state.info["feet_air_time"] > 0.0) * contact_filt + state.info["feet_air_time"] += self.dt + p_f = data.site_xpos[self._feet_site_id] + p_fz = p_f[..., -1] + state.info["swing_peak"] = 
jp.maximum(state.info["swing_peak"], p_fz) + + obs = self._get_obs(data, state.info) + done = self._get_termination(data) + + rewards = self._get_reward( + data, action, state.info, state.metrics, done, first_contact, contact + ) + rewards = { + k: v * self._config.reward_config.scales[k] for k, v in rewards.items() + } + reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0) + + state.info["last_last_act"] = state.info["last_act"] + state.info["last_act"] = action + state.info["steps_until_next_cmd"] -= 1 + state.info["rng"], key1, key2 = jax.random.split(state.info["rng"], 3) + state.info["command"] = jp.where( + state.info["steps_until_next_cmd"] <= 0, + self.sample_command(key1, state.info["command"]), + state.info["command"], + ) + state.info["steps_until_next_cmd"] = jp.where( + done | (state.info["steps_until_next_cmd"] <= 0), + jp.round(jax.random.exponential(key2) * 5.0 / self.dt).astype(jp.int32), + state.info["steps_until_next_cmd"], + ) + state.info["feet_air_time"] *= ~contact + state.info["last_contact"] = contact + state.info["swing_peak"] *= ~contact + for k, v in rewards.items(): + state.metrics[f"reward/{k}"] = v + state.metrics["swing_peak"] = jp.mean(state.info["swing_peak"]) + + done = done.astype(reward.dtype) + state = state.replace(data=data, obs=obs, reward=reward, done=done) + return state + + def _get_termination(self, data: mjx.Data) -> jax.Array: + fall_termination = self.get_upvector(data)[-1] < 0.0 + return fall_termination + + def _get_obs( + self, data: mjx.Data, info: dict[str, Any] + ) -> Dict[str, jax.Array]: + gyro = self.get_gyro(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_gyro = ( + gyro + + (2 * jax.random.uniform(noise_rng, shape=gyro.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.gyro + ) + + gravity = self.get_gravity(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_gravity = ( + gravity + + (2 * jax.random.uniform(noise_rng, shape=gravity.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.gravity + ) + + joint_angles = data.qpos[7:] + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_joint_angles = ( + joint_angles + + (2 * jax.random.uniform(noise_rng, shape=joint_angles.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.joint_pos + ) + + joint_vel = data.qvel[6:] + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_joint_vel = ( + joint_vel + + (2 * jax.random.uniform(noise_rng, shape=joint_vel.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.joint_vel + ) + + linvel = self.get_local_linvel(data) + info["rng"], noise_rng = jax.random.split(info["rng"]) + noisy_linvel = ( + linvel + + (2 * jax.random.uniform(noise_rng, shape=linvel.shape) - 1) + * self._config.noise_config.level + * self._config.noise_config.scales.linvel + ) + + state = jp.hstack([ + noisy_linvel, # 3 + noisy_gyro, # 3 + noisy_gravity, # 3 + noisy_joint_angles - self._default_pose, # 12 + noisy_joint_vel, # 12 + info["last_act"], # 12 + info["command"], # 3 + ]) + + accelerometer = self.get_accelerometer(data) + angvel = self.get_global_angvel(data) + feet_vel = data.sensordata[self._foot_linvel_sensor_adr].ravel() + + privileged_state = jp.hstack([ + state, + gyro, # 3 + accelerometer, # 3 + gravity, # 3 + linvel, # 3 + angvel, # 3 + joint_angles - self._default_pose, # 12 + joint_vel, # 12 + data.actuator_force, # 12 + info["last_contact"], # 4 + 
feet_vel, # 4*3 + info["feet_air_time"], # 4 + data.xfrc_applied[self._torso_body_id, :3], # 3 + info["steps_since_last_pert"] >= info["steps_until_next_pert"], # 1 + ]) + + return { + "state": state, + "privileged_state": privileged_state, + } + + def _get_reward( + self, + data: mjx.Data, + action: jax.Array, + info: dict[str, Any], + metrics: dict[str, Any], + done: jax.Array, + first_contact: jax.Array, + contact: jax.Array, + ) -> dict[str, jax.Array]: + del metrics # Unused. + return { + "tracking_lin_vel": self._reward_tracking_lin_vel( + info["command"], self.get_local_linvel(data) + ), + "tracking_ang_vel": self._reward_tracking_ang_vel( + info["command"], self.get_gyro(data) + ), + "lin_vel_z": self._cost_lin_vel_z(self.get_global_linvel(data)), + "ang_vel_xy": self._cost_ang_vel_xy(self.get_global_angvel(data)), + "orientation": self._cost_orientation(self.get_upvector(data)), + "stand_still": self._cost_stand_still(info["command"], data.qpos[7:]), + "termination": self._cost_termination(done), + "pose": self._reward_pose(data.qpos[7:]), + "torques": self._cost_torques(data.actuator_force), + "action_rate": self._cost_action_rate( + action, info["last_act"], info["last_last_act"] + ), + "energy": self._cost_energy(data.qvel[6:], data.actuator_force), + "feet_slip": self._cost_feet_slip(data, contact, info), + "feet_clearance": self._cost_feet_clearance(data), + "feet_height": self._cost_feet_height( + info["swing_peak"], first_contact, info + ), + "feet_air_time": self._reward_feet_air_time( + info["feet_air_time"], first_contact, info["command"] + ), + "dof_pos_limits": self._cost_joint_pos_limits(data.qpos[7:]), + } + + # Tracking rewards. + + def _reward_tracking_lin_vel( + self, + commands: jax.Array, + local_vel: jax.Array, + ) -> jax.Array: + # Tracking of linear velocity commands (xy axes). + lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2])) + return jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma) + + def _reward_tracking_ang_vel( + self, + commands: jax.Array, + ang_vel: jax.Array, + ) -> jax.Array: + # Tracking of angular velocity commands (yaw). + ang_vel_error = jp.square(commands[2] - ang_vel[2]) + return jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma) + + # Base-related rewards. + + def _cost_lin_vel_z(self, global_linvel) -> jax.Array: + # Penalize z axis base linear velocity. + return jp.square(global_linvel[2]) + + def _cost_ang_vel_xy(self, global_angvel) -> jax.Array: + # Penalize xy axes base angular velocity. + return jp.sum(jp.square(global_angvel[:2])) + + def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array: + # Penalize non flat base orientation. + return jp.sum(jp.square(torso_zaxis[:2])) + + # Energy related rewards. + + def _cost_torques(self, torques: jax.Array) -> jax.Array: + # Penalize torques. + return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques)) + + def _cost_energy( + self, qvel: jax.Array, qfrc_actuator: jax.Array + ) -> jax.Array: + # Penalize energy consumption. + return jp.sum(jp.abs(qvel) * jp.abs(qfrc_actuator)) + + def _cost_action_rate( + self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array + ) -> jax.Array: + del last_last_act # Unused. + return jp.sum(jp.square(act - last_act)) + + # Other rewards. + + def _reward_pose(self, qpos: jax.Array) -> jax.Array: + # Stay close to the default pose. 
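+    # Per-joint weights, tiled over the four legs; assuming the usual Unitree
+    # joint ordering (hip abduction, hip flexion, knee), knee deviations are
+    # weighted an order of magnitude less than the other two joints.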
+    weight = jp.array([1.0, 1.0, 0.1] * 4)
+    return jp.exp(-jp.sum(jp.square(qpos - self._default_pose) * weight))
+
+  def _cost_stand_still(
+      self,
+      commands: jax.Array,
+      qpos: jax.Array,
+  ) -> jax.Array:
+    cmd_norm = jp.linalg.norm(commands)
+    return jp.sum(jp.abs(qpos - self._default_pose)) * (cmd_norm < 0.01)
+
+  def _cost_termination(self, done: jax.Array) -> jax.Array:
+    # Penalize early termination.
+    return done
+
+  def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array:
+    # Penalize joints if they cross soft limits.
+    out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0)
+    out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None)
+    return jp.sum(out_of_limits)
+
+  # Feet related rewards.
+
+  def _cost_feet_slip(
+      self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]
+  ) -> jax.Array:
+    cmd_norm = jp.linalg.norm(info["command"])
+    feet_vel = data.sensordata[self._foot_linvel_sensor_adr]
+    vel_xy = feet_vel[..., :2]
+    vel_xy_norm_sq = jp.sum(jp.square(vel_xy), axis=-1)
+    return jp.sum(vel_xy_norm_sq * contact) * (cmd_norm > 0.01)
+
+  def _cost_feet_clearance(self, data: mjx.Data) -> jax.Array:
+    feet_vel = data.sensordata[self._foot_linvel_sensor_adr]
+    vel_xy = feet_vel[..., :2]
+    vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))
+    foot_pos = data.site_xpos[self._feet_site_id]
+    foot_z = foot_pos[..., -1]
+    delta = jp.abs(foot_z - self._config.reward_config.max_foot_height)
+    return jp.sum(delta * vel_norm)
+
+  def _cost_feet_height(
+      self,
+      swing_peak: jax.Array,
+      first_contact: jax.Array,
+      info: dict[str, Any],
+  ) -> jax.Array:
+    cmd_norm = jp.linalg.norm(info["command"])
+    error = swing_peak / self._config.reward_config.max_foot_height - 1.0
+    return jp.sum(jp.square(error) * first_contact) * (cmd_norm > 0.01)
+
+  def _reward_feet_air_time(
+      self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
+  ) -> jax.Array:
+    # Reward air time.
+    cmd_norm = jp.linalg.norm(commands)
+    rew_air_time = jp.sum((air_time - 0.1) * first_contact)
+    rew_air_time *= cmd_norm > 0.01  # No reward for zero commands.
+    return rew_air_time
+
+  # Perturbation and command sampling.
+
+  def _maybe_apply_perturbation(self, state: mjx_env.State) -> mjx_env.State:
+    def gen_dir(rng: jax.Array) -> jax.Array:
+      angle = jax.random.uniform(rng, minval=0.0, maxval=jp.pi * 2)
+      return jp.array([jp.cos(angle), jp.sin(angle), 0.0])
+
+    def apply_pert(state: mjx_env.State) -> mjx_env.State:
+      t = state.info["pert_steps"] * self.dt
+      u_t = 0.5 * jp.sin(jp.pi * t / state.info["pert_duration_seconds"])
+      # kg * (m/s) / s = kg * m/s^2 = N.
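+      # Half-sine kick profile: peak force is ~0.5 * m * v_kick / T, and the
+      # envelope integrates to T / pi over the kick duration, so the delivered
+      # impulse is roughly m * v_kick / pi (a velocity change of pert_mag / pi).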
+ force = ( + u_t # (unitless) + * self._torso_mass # kg + * state.info["pert_mag"] # m/s + / state.info["pert_duration_seconds"] # 1/s + ) + xfrc_applied = jp.zeros((self.mjx_model.nbody, 6)) + xfrc_applied = xfrc_applied.at[self._torso_body_id, :3].set( + force * state.info["pert_dir"] + ) + data = state.data.replace(xfrc_applied=xfrc_applied) + state = state.replace(data=data) + state.info["steps_since_last_pert"] = jp.where( + state.info["pert_steps"] >= state.info["pert_duration"], + 0, + state.info["steps_since_last_pert"], + ) + state.info["pert_steps"] += 1 + return state + + def wait(state: mjx_env.State) -> mjx_env.State: + state.info["rng"], rng = jax.random.split(state.info["rng"]) + state.info["steps_since_last_pert"] += 1 + xfrc_applied = jp.zeros((self.mjx_model.nbody, 6)) + data = state.data.replace(xfrc_applied=xfrc_applied) + state.info["pert_steps"] = jp.where( + state.info["steps_since_last_pert"] + >= state.info["steps_until_next_pert"], + 0, + state.info["pert_steps"], + ) + state.info["pert_dir"] = jp.where( + state.info["steps_since_last_pert"] + >= state.info["steps_until_next_pert"], + gen_dir(rng), + state.info["pert_dir"], + ) + return state.replace(data=data) + + return jax.lax.cond( + state.info["steps_since_last_pert"] + >= state.info["steps_until_next_pert"], + apply_pert, + wait, + state, + ) + + def sample_command(self, rng: jax.Array, x_k: jax.Array) -> jax.Array: + rng, y_rng, w_rng, z_rng = jax.random.split(rng, 4) + y_k = jax.random.uniform( + y_rng, shape=(3,), minval=-self._cmd_a, maxval=self._cmd_a + ) + z_k = jax.random.bernoulli(z_rng, self._cmd_b, shape=(3,)) + w_k = jax.random.bernoulli(w_rng, 0.5, shape=(3,)) + x_kp1 = x_k - w_k * (x_k - y_k * z_k) + return x_kp1 diff --git a/mujoco_playground/_src/locomotion/go2/randomize.py b/mujoco_playground/_src/locomotion/go2/randomize.py new file mode 100644 index 000000000..2f1ea1ec4 --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/randomize.py @@ -0,0 +1,113 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Domain randomization for the Go2 environment.""" + +import jax +from mujoco import mjx + +FLOOR_GEOM_ID = 0 +TORSO_BODY_ID = 1 + + +def domain_randomize(model: mjx.Model, rng: jax.Array): + @jax.vmap + def rand_dynamics(rng): + # Floor friction: =U(0.4, 1.0). + rng, key = jax.random.split(rng) + geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set( + jax.random.uniform(key, minval=0.4, maxval=1.0) + ) + + # Scale static friction: *U(0.9, 1.1). + rng, key = jax.random.split(rng) + frictionloss = model.dof_frictionloss[6:] * jax.random.uniform( + key, shape=(12,), minval=0.9, maxval=1.1 + ) + dof_frictionloss = model.dof_frictionloss.at[6:].set(frictionloss) + + # Scale armature: *U(1.0, 1.05). 
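+    # dof_armature models reflected rotor inertia seen at each joint; the
+    # multiplicative range only ever increases it slightly (by up to 5%).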
+ rng, key = jax.random.split(rng) + armature = model.dof_armature[6:] * jax.random.uniform( + key, shape=(12,), minval=1.0, maxval=1.05 + ) + dof_armature = model.dof_armature.at[6:].set(armature) + + # Jitter center of mass positiion: +U(-0.05, 0.05). + rng, key = jax.random.split(rng) + dpos = jax.random.uniform(key, (3,), minval=-0.05, maxval=0.05) + body_ipos = model.body_ipos.at[TORSO_BODY_ID].set( + model.body_ipos[TORSO_BODY_ID] + dpos + ) + + # Scale all link masses: *U(0.9, 1.1). + rng, key = jax.random.split(rng) + dmass = jax.random.uniform( + key, shape=(model.nbody,), minval=0.9, maxval=1.1 + ) + body_mass = model.body_mass.at[:].set(model.body_mass * dmass) + + # Add mass to torso: +U(-1.0, 1.0). + rng, key = jax.random.split(rng) + dmass = jax.random.uniform(key, minval=-1.0, maxval=1.0) + body_mass = body_mass.at[TORSO_BODY_ID].set( + body_mass[TORSO_BODY_ID] + dmass + ) + + # Jitter qpos0: +U(-0.05, 0.05). + rng, key = jax.random.split(rng) + qpos0 = model.qpos0 + qpos0 = qpos0.at[7:].set( + qpos0[7:] + + jax.random.uniform(key, shape=(12,), minval=-0.05, maxval=0.05) + ) + + return ( + geom_friction, + body_ipos, + body_mass, + qpos0, + dof_frictionloss, + dof_armature, + ) + + ( + friction, + body_ipos, + body_mass, + qpos0, + dof_frictionloss, + dof_armature, + ) = rand_dynamics(rng) + + in_axes = jax.tree_util.tree_map(lambda x: None, model) + in_axes = in_axes.tree_replace({ + "geom_friction": 0, + "body_ipos": 0, + "body_mass": 0, + "qpos0": 0, + "dof_frictionloss": 0, + "dof_armature": 0, + }) + + model = model.tree_replace({ + "geom_friction": friction, + "body_ipos": body_ipos, + "body_mass": body_mass, + "qpos0": qpos0, + "dof_frictionloss": dof_frictionloss, + "dof_armature": dof_armature, + }) + + return model, in_axes diff --git a/mujoco_playground/_src/locomotion/go2/xmls/assets/hfield.png b/mujoco_playground/_src/locomotion/go2/xmls/assets/hfield.png new file mode 100644 index 000000000..62af27a2b Binary files /dev/null and b/mujoco_playground/_src/locomotion/go2/xmls/assets/hfield.png differ diff --git a/mujoco_playground/_src/locomotion/go2/xmls/assets/rocky_texture.png b/mujoco_playground/_src/locomotion/go2/xmls/assets/rocky_texture.png new file mode 100644 index 000000000..1456b3ff4 Binary files /dev/null and b/mujoco_playground/_src/locomotion/go2/xmls/assets/rocky_texture.png differ diff --git a/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx.xml b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx.xml new file mode 100644 index 000000000..4eb314baa --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx.xml @@ -0,0 +1,257 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_feetonly.xml b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_feetonly.xml new file mode 100644 index 000000000..8458b08d2 --- /dev/null +++ b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_feetonly.xml @@ -0,0 +1,259 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/assets/hfield.png b/mujoco_playground/_src/locomotion/go2/xmls/assets/hfield.png
new file mode 100644
index 000000000..62af27a2b
Binary files /dev/null and b/mujoco_playground/_src/locomotion/go2/xmls/assets/hfield.png differ
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/assets/rocky_texture.png b/mujoco_playground/_src/locomotion/go2/xmls/assets/rocky_texture.png
new file mode 100644
index 000000000..1456b3ff4
Binary files /dev/null and b/mujoco_playground/_src/locomotion/go2/xmls/assets/rocky_texture.png differ
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx.xml b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx.xml
new file mode 100644
index 000000000..4eb314baa
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx.xml
@@ -0,0 +1,257 @@
[257 lines of MJCF for the Go2 model; XML markup stripped during extraction and not recoverable here.]
\ No newline at end of file
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_feetonly.xml b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_feetonly.xml
new file mode 100644
index 000000000..8458b08d2
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_feetonly.xml
@@ -0,0 +1,259 @@
[259 lines of MJCF for the feet-only-collision Go2 variant; XML markup stripped during extraction.]
\ No newline at end of file
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_fullcollisions.xml b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_fullcollisions.xml
new file mode 100644
index 000000000..88f46e3ad
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/go2_mjx_fullcollisions.xml
@@ -0,0 +1,261 @@
[261 lines of MJCF for the full-collision Go2 variant; XML markup stripped during extraction.]
\ No newline at end of file
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_feetonly_flat_terrain.xml b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_feetonly_flat_terrain.xml
new file mode 100644
index 000000000..b0f1ca0a8
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_feetonly_flat_terrain.xml
@@ -0,0 +1,47 @@
[47-line MJCF scene (flat terrain, feet-only collisions); XML markup stripped during extraction.]
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_feetonly_rough_terrain.xml b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_feetonly_rough_terrain.xml
new file mode 100644
index 000000000..17060ba05
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_feetonly_rough_terrain.xml
@@ -0,0 +1,38 @@
[38-line MJCF scene (rough terrain, feet-only collisions); XML markup stripped during extraction.]
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_flat_terrain.xml b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_flat_terrain.xml
new file mode 100644
index 000000000..fee8c58a0
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_flat_terrain.xml
@@ -0,0 +1,47 @@
[47-line MJCF scene (flat terrain); XML markup stripped during extraction.]
diff --git a/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_fullcollisions_flat_terrain.xml b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_fullcollisions_flat_terrain.xml
new file mode 100644
index 000000000..fee8c58a0
--- /dev/null
+++ b/mujoco_playground/_src/locomotion/go2/xmls/scene_mjx_fullcollisions_flat_terrain.xml
@@ -0,0 +1,47 @@
[47-line MJCF scene (flat terrain, full collisions; same blob hash as scene_mjx_flat_terrain.xml); XML markup stripped during extraction.]
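Since the MJCF payloads above did not survive extraction, a quick local sanity check is to compile one of the new scenes against the bundled assets. A minimal sketch, assuming `consts.ROOT_PATH` and `get_assets()` behave as in `base.py`:

```python
import mujoco

from mujoco_playground._src.locomotion.go2 import base
from mujoco_playground._src.locomotion.go2 import go2_constants as consts

# Compile one of the new scenes; get_assets() resolves the local XMLs plus
# hfield.png / rocky_texture.png and the Menagerie unitree_go2 files.
scene = consts.ROOT_PATH / "xmls" / "scene_mjx_feetonly_flat_terrain.xml"
model = mujoco.MjModel.from_xml_string(
    scene.read_text(), assets=base.get_assets()
)
print(model.nbody, model.ngeom)  # non-zero counts => the scene compiled
```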
diff --git a/mujoco_playground/_src/locomotion/t1/randomize.py b/mujoco_playground/_src/locomotion/t1/randomize.py
index 85e0dc112..ea132d793 100644
--- a/mujoco_playground/_src/locomotion/t1/randomize.py
+++ b/mujoco_playground/_src/locomotion/t1/randomize.py
@@ -18,7 +18,6 @@
 from mujoco import mjx
 import numpy as np
 
-
 FLOOR_GEOM_ID = 0
 TORSO_BODY_ID = 1
 ANKLE_JOINT_IDS = np.array([[21, 22, 27, 28]])
@@ -30,7 +29,7 @@ def rand_dynamics(rng):
     # Floor friction: =U(0.2, 0.6).
     rng, key = jax.random.split(rng)
     geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set(
-        jax.random.uniform(key, minval=0.2, maxval=.6)
+        jax.random.uniform(key, minval=0.2, maxval=0.6)
     )
 
     rng, key = jax.random.split(rng)
diff --git a/mujoco_playground/_src/registry.py b/mujoco_playground/_src/registry.py
index e7f0294d3..a2d988279 100644
--- a/mujoco_playground/_src/registry.py
+++ b/mujoco_playground/_src/registry.py
@@ -31,9 +31,7 @@
 
 # A tuple containing all available environment names across all suites.
 ALL_ENVS = (
-    dm_control_suite.ALL_ENVS
-    + locomotion.ALL_ENVS
-    + manipulation.ALL_ENVS
+    dm_control_suite.ALL_ENVS + locomotion.ALL_ENVS + manipulation.ALL_ENVS
 )
 
diff --git a/mujoco_playground/config/locomotion_params.py b/mujoco_playground/config/locomotion_params.py
index cc777a4e1..b0dce8552 100644
--- a/mujoco_playground/config/locomotion_params.py
+++ b/mujoco_playground/config/locomotion_params.py
@@ -89,12 +89,30 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
         value_obs_key="privileged_state",
     )
 
-  elif env_name in ("G1JoystickFlatTerrain", "G1JoystickRoughTerrain"):
+  elif env_name in ("Go2JoystickFlatTerrain", "Go2JoystickRoughTerrain"):
     rl_config.num_timesteps = 200_000_000
-    rl_config.num_evals = 20
-    rl_config.clipping_epsilon = 0.2
+    rl_config.num_evals = 10
     rl_config.num_resets_per_eval = 1
-    rl_config.entropy_cost = 0.005
+    rl_config.network_factory = config_dict.create(
+        policy_hidden_layer_sizes=(512, 256, 128),
+        value_hidden_layer_sizes=(512, 256, 128),
+        policy_obs_key="state",
+        value_obs_key="privileged_state",
+    )
+
+  elif env_name in ("Go2Handstand", "Go2Footstand"):
+    rl_config.num_timesteps = 100_000_000
+    rl_config.num_evals = 5
+    rl_config.network_factory = config_dict.create(
+        policy_hidden_layer_sizes=(512, 256, 128),
+        value_hidden_layer_sizes=(512, 256, 128),
+        policy_obs_key="state",
+        value_obs_key="privileged_state",
+    )
+
+  elif env_name == "Go2Getup":
+    rl_config.num_timesteps = 50_000_000
+    rl_config.num_evals = 5
     rl_config.network_factory = config_dict.create(
         policy_hidden_layer_sizes=(512, 256, 128),
         value_hidden_layer_sizes=(512, 256, 128),
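The Go2 branches above plug into the same training path as the existing locomotion configs. A hypothetical end-to-end sketch of how `brax_ppo_config` is consumed (brax PPO API and the playground's `wrapper.wrap_for_brax_training` assumed; not part of this change):

```python
import functools

from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

from mujoco_playground import registry
from mujoco_playground import wrapper
from mujoco_playground.config import locomotion_params

env_name = "Go2JoystickFlatTerrain"
env = registry.load(env_name)
ppo_params = locomotion_params.brax_ppo_config(env_name)

# Split the network_factory sub-config out of the flat hyperparameter dict
# so it can be bound to brax's network constructor.
training_params = dict(ppo_params)
del training_params["network_factory"]
network_factory = functools.partial(
    ppo_networks.make_ppo_networks, **ppo_params.network_factory
)

make_inference_fn, params, _ = ppo.train(
    environment=env,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    network_factory=network_factory,
    **training_params,
)
```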