2 changes: 1 addition & 1 deletion pyproject.toml
@@ -56,5 +56,5 @@ reportMissingImports = "none"
# This is required to ignore type checks of modules with stubs missing.
reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers
reportGeneralTypeIssues = "none" # -> usage of literal MISSING in dataclasses
reportOptionalMemberAccess = "warning"
reportOptionalMemberAccess = "none"
reportPrivateUsage = "warning"
28 changes: 5 additions & 23 deletions rsl_rl/algorithms/distillation.py
@@ -21,6 +21,7 @@ class Distillation:
def __init__(
self,
policy: StudentTeacher | StudentTeacherRecurrent,
storage: RolloutStorage,
num_learning_epochs: int = 1,
gradient_length: int = 15,
learning_rate: float = 1e-3,
@@ -46,12 +47,12 @@ def __init__(
# Distillation components
self.policy = policy
self.policy.to(self.device)
self.storage = None # Initialized later

# Initialize the optimizer
# Create the optimizer
self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate)

# Initialize the transition
# Add storage
self.storage = storage
self.transition = RolloutStorage.Transition()
self.last_hidden_states = (None, None)

@@ -73,24 +74,6 @@ def __init__(

self.num_updates = 0

def init_storage(
self,
training_type: str,
num_envs: int,
num_transitions_per_env: int,
obs: TensorDict,
actions_shape: tuple[int],
) -> None:
# Create rollout storage
self.storage = RolloutStorage(
training_type,
num_envs,
num_transitions_per_env,
obs,
actions_shape,
self.device,
)

def act(self, obs: TensorDict) -> torch.Tensor:
# Compute the actions
self.transition.actions = self.policy.act(obs).detach()
@@ -104,12 +87,11 @@ def process_env_step(
) -> None:
# Update the normalizers
self.policy.update_normalization(obs)

# Record the rewards and dones
self.transition.rewards = rewards
self.transition.dones = dones
# Record the transition
self.storage.add_transitions(self.transition)
self.storage.add_transition(self.transition)
self.transition.clear()
self.policy.reset(dones)

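With this change, the rollout storage is created by the caller and injected through the constructor, replacing the removed init_storage() call. A minimal sketch of the new pattern, assuming the runner-side values (student_teacher, obs, num_envs, num_steps, num_actions, device) are already defined and that Distillation is exported from rsl_rl.algorithms in the same way PPO is:

    from rsl_rl.algorithms import Distillation
    from rsl_rl.storage import RolloutStorage

    # The caller now owns storage creation; previously this happened inside alg.init_storage(...).
    storage = RolloutStorage("distillation", num_envs, num_steps, obs, [num_actions], device)
    alg = Distillation(student_teacher, storage, device=device)

The same pattern applies to PPO below, with training type "rl" instead of "distillation".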
51 changes: 25 additions & 26 deletions rsl_rl/algorithms/ppo.py
@@ -26,6 +26,7 @@ class PPO:
def __init__(
self,
policy: ActorCritic | ActorCriticRecurrent,
storage: RolloutStorage,
num_learning_epochs: int = 5,
num_mini_batches: int = 4,
clip_param: float = 0.2,
@@ -38,8 +39,8 @@ def __init__(
use_clipped_value_loss: bool = True,
schedule: str = "adaptive",
desired_kl: float = 0.01,
device: str = "cpu",
normalize_advantage_per_mini_batch: bool = False,
device: str = "cpu",
# RND parameters
rnd_cfg: dict | None = None,
# Symmetry parameters
@@ -100,11 +101,11 @@ def __init__(
self.policy = policy
self.policy.to(self.device)

# Create optimizer
# Create the optimizer
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)

# Create rollout storage
self.storage: RolloutStorage | None = None
# Add storage
self.storage = storage
self.transition = RolloutStorage.Transition()

# PPO parameters
@@ -122,24 +123,6 @@ def __init__(
self.learning_rate = learning_rate
self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

def init_storage(
self,
training_type: str,
num_envs: int,
num_transitions_per_env: int,
obs: TensorDict,
actions_shape: tuple[int] | list[int],
) -> None:
# Create rollout storage
self.storage = RolloutStorage(
training_type,
num_envs,
num_transitions_per_env,
obs,
actions_shape,
self.device,
)

def act(self, obs: TensorDict) -> torch.Tensor:
if self.policy.is_recurrent:
self.transition.hidden_states = self.policy.get_hidden_states()
@@ -180,16 +163,32 @@ def process_env_step(
)

# Record the transition
self.storage.add_transitions(self.transition)
self.storage.add_transition(self.transition)
self.transition.clear()
self.policy.reset(dones)

def compute_returns(self, obs: TensorDict) -> None:
st = self.storage
# Compute value for the last step
last_values = self.policy.evaluate(obs).detach()
self.storage.compute_returns(
last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
)
# Compute returns and advantages
advantage = 0
for step in reversed(range(st.num_transitions_per_env)):
# If we are at the last step, bootstrap the return value
next_values = last_values if step == st.num_transitions_per_env - 1 else st.values[step + 1]
# 1 if we are not in a terminal state, 0 otherwise
next_is_not_terminal = 1.0 - st.dones[step].float()
# TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
delta = st.rewards[step] + next_is_not_terminal * self.gamma * next_values - st.values[step]
# Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
advantage = delta + next_is_not_terminal * self.gamma * self.lam * advantage
# Return: R_t = A(s_t, a_t) + V(s_t)
st.returns[step] = advantage + st.values[step]
# Compute the advantages
st.advantages = st.returns - st.values
# Normalize the advantages if per minibatch normalization is not used
if not self.normalize_advantage_per_mini_batch:
st.advantages = (st.advantages - st.advantages.mean()) / (st.advantages.std() + 1e-8)

def update(self) -> dict[str, float]:
mean_value_loss = 0
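For reference, the loop above inlines into PPO.compute_returns the GAE(lambda) recursion that previously lived in RolloutStorage.compute_returns. Written out, with d_t the done flag and the value at the final step bootstrapped from last_values:

    \delta_t = r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t)
    \hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_t)\,\hat{A}_{t+1}
    R_t = \hat{A}_t + V(s_t)

When normalize_advantage_per_mini_batch is False, the advantages are normalized once over the whole rollout as \hat{A} \leftarrow (\hat{A} - \mathrm{mean}(\hat{A})) / (\mathrm{std}(\hat{A}) + 10^{-8}); otherwise normalization happens per mini-batch during the update, which avoids normalizing twice.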
17 changes: 7 additions & 10 deletions rsl_rl/runners/distillation_runner.py
@@ -16,6 +16,7 @@
from rsl_rl.env import VecEnv
from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
from rsl_rl.runners import OnPolicyRunner
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -158,19 +159,15 @@ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)

# Initialize the storage
storage = RolloutStorage(
"distillation", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
)

# Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg: Distillation = alg_class(
student_teacher, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)

# Initialize the storage
alg.init_storage(
"distillation",
self.env.num_envs,
self.num_steps_per_env,
obs,
[self.env.num_actions],
student_teacher, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)

return alg
17 changes: 8 additions & 9 deletions rsl_rl/runners/on_policy_runner.py
@@ -17,6 +17,7 @@
from rsl_rl.algorithms import PPO
from rsl_rl.env import VecEnv
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_obs_groups, store_code_state


@@ -418,17 +419,15 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO:
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)

# Initialize the storage
storage = RolloutStorage(
"rl", self.env.num_envs, self.num_steps_per_env, obs, [self.env.num_actions], self.device
)

# Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg)

# Initialize the storage
alg.init_storage(
"rl",
self.env.num_envs,
self.num_steps_per_env,
obs,
[self.env.num_actions],
alg: PPO = alg_class(
actor_critic, storage, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)

return alg
75 changes: 30 additions & 45 deletions rsl_rl/storage/rollout_storage.py
@@ -14,7 +14,15 @@


class RolloutStorage:
"""Storage for the data collected during a rollout.

The rollout storage is populated by adding transitions during the rollout phase. During learning, it provides a
generator over the collected data, whose form depends on the algorithm and the policy architecture.
"""

class Transition:
"""Storage for a single state transition."""

def __init__(self) -> None:
self.observations: TensorDict | None = None
self.actions: torch.Tensor | None = None
Expand Down Expand Up @@ -75,7 +83,7 @@ def __init__(
# Counter for the number of transitions stored
self.step = 0

def add_transitions(self, transition: Transition) -> None:
def add_transition(self, transition: Transition) -> None:
# Check if the transition is valid
if self.step >= self.num_transitions_per_env:
raise OverflowError("Rollout buffer overflow! You should call clear() before adding new transitions.")
@@ -103,53 +111,9 @@ def add_transitions(self, transition: Transition) -> None:
# Increment the counter
self.step += 1

def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
if hidden_states == (None, None):
return
# Make a tuple out of GRU hidden states to match the LSTM format
hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
# Initialize hidden states if needed
if self.saved_hidden_state_a is None:
self.saved_hidden_state_a = [
torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
for i in range(len(hidden_state_a))
]
self.saved_hidden_state_c = [
torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
for i in range(len(hidden_state_c))
]
# Copy the states
for i in range(len(hidden_state_a)):
self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])

def clear(self) -> None:
self.step = 0

def compute_returns(
self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True
) -> None:
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
# If we are at the last step, bootstrap the return value
next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
# 1 if we are not in a terminal state, 0 otherwise
next_is_not_terminal = 1.0 - self.dones[step].float()
# TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
# Advantage: A(s_t, a_t) = delta_t + gamma * lambda * A(s_{t+1}, a_{t+1})
advantage = delta + next_is_not_terminal * gamma * lam * advantage
# Return: R_t = A(s_t, a_t) + V(s_t)
self.returns[step] = advantage + self.values[step]

# Compute the advantages
self.advantages = self.returns - self.values
# Normalize the advantages if flag is set
# Note: This is to prevent double normalization (i.e. if per minibatch normalization is used)
if normalize_advantage:
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

# For distillation
def generator(self) -> Generator:
if self.training_type != "distillation":
@@ -289,3 +253,24 @@ def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int
)

first_traj = last_traj

def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
if hidden_states == (None, None):
return
# Make a tuple out of GRU hidden states to match the LSTM format
hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
# Initialize hidden states if needed
if self.saved_hidden_state_a is None:
self.saved_hidden_state_a = [
torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
for i in range(len(hidden_state_a))
]
self.saved_hidden_state_c = [
torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
for i in range(len(hidden_state_c))
]
# Copy the states
for i in range(len(hidden_state_a)):
self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])
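
As a usage note on the renamed add_transition: it is called once per environment step, stores the transition at the current step index, and raises OverflowError once num_transitions_per_env entries have been recorded. A minimal sketch of the buffer lifecycle, with obs, the per-step tensors, and the size arguments as placeholders:

    from rsl_rl.storage import RolloutStorage

    # Build the buffer once; the algorithm reuses it across rollouts.
    storage = RolloutStorage("rl", num_envs, num_steps, obs, [num_actions], device)
    transition = RolloutStorage.Transition()
    for _ in range(storage.num_transitions_per_env):
        # ... fill transition.observations, transition.actions, transition.rewards, transition.dones ...
        storage.add_transition(transition)  # stores at index storage.step, then increments the counter
        transition.clear()                  # reuse the same Transition object for the next step
    # A further add_transition would raise OverflowError; reset the counter before the next rollout.
    storage.clear()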