
Commit 71e8c70

Merge pull request #54 from bit-bots/feature/destillation

Add distillation script for faster inference

2 parents e6f6c89 + be4869e

20 files changed: +1681 −625 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -209,5 +209,8 @@ ENV/
 # Torch models
 *.pth
 
+# Wandb Logs
+ddlitlab2024/ml/training/wandb/
+
 # Input data
 input/
```

ddlitlab2024/dataset/models.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -11,8 +11,8 @@
 
 
 class RobotState(str, Enum):
-    POSITIONING = "POSITIONING"
     PLAYING = "PLAYING"
+    POSITIONING = "POSITIONING"
     STOPPED = "STOPPED"
     UNKNOWN = "UNKNOWN"
 
@@ -219,7 +219,8 @@ class JointStates(Base):
         Index(None, "recording_id", asc("stamp")),
     )
 
-    def get_ordered_joint_names(self) -> list[str]:
+    @staticmethod
+    def get_ordered_joint_names() -> list[str]:
         return [
             JointStates.head_pan.name,
             JointStates.head_tilt.name,
```
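
Since `get_ordered_joint_names` is now a `@staticmethod`, it can be called on the class itself, which is how the dataset code below uses it. A minimal usage sketch (the assertion is only illustrative):

```python
from ddlitlab2024.dataset.models import JointStates

# No JointStates instance is required anymore to obtain the canonical column order
joint_names = JointStates.get_ordered_joint_names()
assert isinstance(joint_names, list) and all(isinstance(name, str) for name in joint_names)
```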

ddlitlab2024/dataset/pytorch.py

Lines changed: 98 additions & 48 deletions

```diff
@@ -4,14 +4,15 @@
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import Literal
+from typing import Literal, Optional
 
+import cv2
 import numpy as np
 import pandas as pd
 import torch
-from profilehooks import profile
 from tabulate import tabulate
 from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import v2
 
 from ddlitlab2024 import DB_PATH
 from ddlitlab2024.dataset import logger
```

```diff
@@ -40,15 +41,15 @@ class DDLITLab2024Dataset(Dataset):
     @dataclass
     class Result:
         joint_command: torch.Tensor
-        joint_command_history: torch.Tensor
-        joint_state: torch.Tensor
-        rotation: torch.Tensor
-        game_state: torch.Tensor
-        image_data: torch.Tensor
-        image_stamps: torch.Tensor
+        joint_command_history: Optional[torch.Tensor]
+        joint_state: Optional[torch.Tensor]
+        rotation: Optional[torch.Tensor]
+        game_state: Optional[torch.Tensor]
+        image_data: Optional[torch.Tensor]
+        image_stamps: Optional[torch.Tensor]
 
         def shapes(self) -> dict[str, tuple[int, ...]]:
-            return {k: v.shape for k, v in asdict(self).items()}
+            return {k: v.shape for k, v in asdict(self).items() if v is not None}
 
     def __init__(
         self,
```
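
With the fields typed as `Optional`, `shapes()` now skips modalities that were not loaded. A small illustrative sketch (tensor sizes are made up, not taken from the dataset defaults):

```python
import torch

from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset

result = DDLITLab2024Dataset.Result(
    joint_command=torch.zeros(100, 20),
    joint_command_history=torch.zeros(100, 20),
    joint_state=None,   # e.g. dataset built with use_joint_states=False
    rotation=None,      # e.g. dataset built with use_imu=False
    game_state=torch.tensor(0),
    image_data=None,    # e.g. dataset built with use_images=False
    image_stamps=None,
)
print(result.shapes())  # only the non-None fields are reported
```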

```diff
@@ -61,8 +62,14 @@ def __init__(
         sampling_rate: int = 100,
         max_fps_video: int = 10,
         num_frames_video: int = 50,
-        trajectory_stride: int = 10,
+        image_resolution: int = 480,
+        trajectory_stride: int = 1,
         num_joints: int = 20,
+        use_images: bool = True,
+        use_imu: bool = True,
+        use_joint_states: bool = True,
+        use_action_history: bool = True,
+        use_game_state: bool = True,
     ):
         # Initialize the database connection
         self.db_connection: sqlite3.Connection = db_connection if db_connection else connect_to_db()
```
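
A usage sketch of the new constructor flags, assuming a recording database exists at the default DB_PATH; the argument values are illustrative, the keyword names match the signature above:

```python
from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset

dataset = DDLITLab2024Dataset(
    num_frames_video=50,
    image_resolution=224,     # frames are resized from the stored 480x480 down to 224x224
    trajectory_stride=1,
    use_images=True,
    use_imu=True,
    use_joint_states=True,
    use_action_history=True,
    use_game_state=False,     # skip the game-state query entirely
)
sample = dataset[0]
print(sample.shapes())        # game_state is None, so it is not listed
```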

```diff
@@ -76,8 +83,15 @@
         self.sampling_rate = sampling_rate
         self.max_fps_video = max_fps_video
         self.num_frames_video = num_frames_video
+        self.image_resolution = image_resolution
         self.trajectory_stride = trajectory_stride
         self.num_joints = num_joints
+        self.joint_names = JointStates.get_ordered_joint_names()
+        self.use_images = use_images
+        self.use_imu = use_imu
+        self.use_joint_states = use_joint_states
+        self.use_action_history = use_action_history
+        self.use_game_state = use_game_state
 
         # Print out metadata
         cursor = self.db_connection.cursor()
```

```diff
@@ -100,7 +114,9 @@
             assert num_data_points > 0, "Recording length is negative or zero"
             total_samples_before = self.num_samples
             # Calculate the number of batches that can be build from the recording including the stride
-            self.num_samples += int(num_data_points / self.trajectory_stride)
+            self.num_samples += int(
+                (num_data_points - self.num_samples_joint_trajectory_future) / self.trajectory_stride
+            )
             # Store the boundaries of the samples for later retrieval
             self.sample_boundaries.append((total_samples_before, self.num_samples, recording_id))
 
```
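
The new formula keeps the last samples of each recording out of the index so that the future joint-command window always fits. A worked example with assumed values:

```python
# Assumed values, for illustration only
num_data_points = 1000                      # command samples in one recording
num_samples_joint_trajectory_future = 100   # future window that must still fit at the end
trajectory_stride = 1

# Old formula: int(1000 / 1) = 1000 samples, the last of which would need
# future commands past the end of the recording.
# New formula:
num_samples = int((num_data_points - num_samples_joint_trajectory_future) / trajectory_stride)
print(num_samples)  # 900
```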

```diff
@@ -119,7 +135,7 @@ def query_joint_data(
         )
 
         # Convert to numpy array, keep only the joint angle columns in alphabetical order
-        raw_joint_data = raw_joint_data[JointStates.get_ordered_joint_names()].to_numpy(dtype=np.float32)
+        raw_joint_data = raw_joint_data[self.joint_names].to_numpy(dtype=np.float32)
 
         assert raw_joint_data.shape[1] == self.num_joints, "The number of joints is not correct"
 
```

```diff
@@ -155,7 +171,7 @@
         return raw_joint_data
 
     def query_image_data(
-        self, recording_id: int, end_time_stamp: float, context_len: float, num_frames: int
+        self, recording_id: int, end_time_stamp: float, context_len: float, num_frames: int, resolution: int
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Get the image data
         cursor = self.db_connection.cursor()
```

```diff
@@ -178,25 +194,36 @@
         stamps = []
         image_data = []
 
+        # Define the preprocessing pipeline
+        preprocessing = v2.Compose(
+            [
+                v2.ToImage(),
+                v2.ToDtype(torch.float32, scale=True),
+                v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+            ]
+        )
+
         # Get the raw image data
         for stamp, data in response:
             # Deserialize the image data
             image = np.frombuffer(data, dtype=np.uint8).reshape(480, 480, 3)
-            # Make chw from hwc
-            image = np.moveaxis(image, -1, 0)
+            # Resize the image
+            image = cv2.resize(image, (resolution, resolution), interpolation=cv2.INTER_AREA)
+            # Apply the preprocessing pipeline
+            image = preprocessing(image)
             # Append to the list
             image_data.append(image)
             stamps.append(stamp)
 
         # Apply zero padding if necessary
         if len(image_data) < num_frames:
             image_data = [
-                np.zeros((3, 480, 480), dtype=np.uint8) for _ in range(num_frames - len(image_data))
+                torch.zeros((3, resolution, resolution), dtype=torch.float32)
+                for _ in range(num_frames - len(image_data))
             ] + image_data
             stamps = [end_time_stamp - context_len for _ in range(num_frames - len(stamps))] + stamps
 
-        # Convert to tensor
-        image_data = torch.from_numpy(np.stack(image_data, axis=0)).float()
+        image_data = torch.stack(image_data, axis=0)
         stamps = torch.tensor(stamps)
 
         return stamps, image_data
```
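
The stored 480x480 frames are now resized with OpenCV and pushed through a torchvision v2 pipeline (ImageNet normalization) instead of a plain HWC to CHW move. A self-contained sketch of the same preprocessing applied to a dummy frame, with an assumed target resolution of 224:

```python
import cv2
import numpy as np
import torch
from torchvision.transforms import v2

resolution = 224  # assumed target resolution for this example

preprocessing = v2.Compose(
    [
        v2.ToImage(),                           # HWC uint8 array -> CHW tensor image
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
        v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet statistics
    ]
)

frame = np.zeros((480, 480, 3), dtype=np.uint8)  # stand-in for a stored frame
frame = cv2.resize(frame, (resolution, resolution), interpolation=cv2.INTER_AREA)
tensor = preprocessing(frame)
print(tensor.shape)  # torch.Size([3, 224, 224])
```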

```diff
@@ -244,7 +271,7 @@ def query_imu_data(self, recording_id: int, end_sample: int, num_samples: int) -
             case rep:
                 raise NotImplementedError(f"Unknown IMU representation {rep}")
 
-        return torch.from_numpy(imu_data)
+        return torch.from_numpy(imu_data).float()
 
     def query_current_game_state(self, recording_id: int, stamp: float) -> torch.Tensor:
         cursor = self.db_connection.cursor()
```

```diff
@@ -265,7 +292,6 @@ def query_current_game_state(self, recording_id: int, stamp: float) -> torch.Ten
 
         return torch.tensor(int(game_state))
 
-    @profile
     def __getitem__(self, idx: int) -> Result:
         # Find the recording that contains the sample
         for start_sample, end_sample, recording_id in self.sample_boundaries:
```

```diff
@@ -288,20 +314,30 @@ def __getitem__(self, idx: int) -> Result:
         stamp = sample_joint_command_index / self.sampling_rate
 
         # Get the image data
-        image_stamps, image_data = self.query_image_data(
-            recording_id,
-            stamp,
-            # The duration is used to narrow down the query for a faster retrieval, so we consider it as an upper bound
-            (self.num_frames_video + 1) / self.max_fps_video,
-            self.num_frames_video,
-        )
-        # Some sanity checks
-        assert all([stamp >= image_stamp for image_stamp in image_stamps]), "The image data is not synchronized"
-        assert len(image_stamps) == self.num_frames_video, "The image data is not the correct length"
-        assert image_data.shape == (self.num_frames_video, 3, 480, 480), "The image data has the wrong shape"
-        assert (
-            image_stamps[0] >= stamp - (self.num_frames_video + 1) / self.max_fps_video
-        ), "The image data is not synchronized"
+        if self.use_images:
+            image_stamps, image_data = self.query_image_data(
+                recording_id,
+                stamp,
+                # The duration is used to narrow down the query for a faster retrieval,
+                # so we consider it as an upper bound
+                (self.num_frames_video + 1) / self.max_fps_video,
+                self.num_frames_video,
+                self.image_resolution,
+            )
+            # Some sanity checks
+            assert all([stamp >= image_stamp for image_stamp in image_stamps]), "The image data is not synchronized"
+            assert len(image_stamps) == self.num_frames_video, "The image data is not the correct length"
+            assert image_data.shape == (
+                self.num_frames_video,
+                3,
+                self.image_resolution,
+                self.image_resolution,
+            ), "The image data has the wrong shape"
+            assert (
+                image_stamps[0] >= stamp - (self.num_frames_video + 1) / self.max_fps_video
+            ), "The image data is not synchronized"
+        else:
+            image_stamps, image_data = None, None
 
         # Get the joint command target (future)
         joint_command = self.query_joint_data(
```

```diff
@@ -310,20 +346,32 @@ def __getitem__(self, idx: int) -> Result:
         assert len(joint_command) == self.num_samples_joint_trajectory_future, "The joint command has the wrong length"
 
         # Get the joint command history
-        joint_command_history = self.query_joint_data_history(
-            recording_id, sample_joint_command_index, self.num_samples_joint_trajectory, "JointCommands"
-        )
+        if self.use_action_history:
+            joint_command_history = self.query_joint_data_history(
+                recording_id, sample_joint_command_index, self.num_samples_joint_trajectory, "JointCommands"
+            )
+        else:
+            joint_command_history = None
 
         # Get the joint state
-        joint_state = self.query_joint_data_history(
-            recording_id, sample_joint_command_index, self.num_samples_joint_states, "JointStates"
-        )
+        if self.use_joint_states:
+            joint_state = self.query_joint_data_history(
+                recording_id, sample_joint_command_index, self.num_samples_joint_states, "JointStates"
+            )
+        else:
+            joint_state = None
 
         # Get the robot rotation (IMU data)
-        robot_rotation = self.query_imu_data(recording_id, sample_joint_command_index, self.num_samples_imu)
+        if self.use_imu:
+            robot_rotation = self.query_imu_data(recording_id, sample_joint_command_index, self.num_samples_imu)
+        else:
+            robot_rotation = None
 
         # Get the game state
-        game_state = self.query_current_game_state(recording_id, stamp)
+        if self.use_game_state:
+            game_state = self.query_current_game_state(recording_id, stamp)
+        else:
+            game_state = None
 
         return self.Result(
             joint_command=joint_command,
```

```diff
@@ -339,12 +387,14 @@ def __getitem__(self, idx: int) -> Result:
     def collate_fn(batch: Iterable[Result]) -> Result:
         return DDLITLab2024Dataset.Result(
             joint_command=torch.stack([x.joint_command for x in batch]),
-            joint_command_history=torch.stack([x.joint_command_history for x in batch]),
-            joint_state=torch.stack([x.joint_state for x in batch]),
-            image_data=torch.stack([x.image_data for x in batch]),
-            image_stamps=torch.stack([x.image_stamps for x in batch]),
-            rotation=torch.stack([x.rotation for x in batch]),
-            game_state=torch.tensor([x.game_state for x in batch]),
+            joint_command_history=torch.stack([x.joint_command_history for x in batch])
+            if batch[0].joint_command_history is not None
+            else None,
+            joint_state=torch.stack([x.joint_state for x in batch]) if batch[0].joint_state is not None else None,
+            image_data=torch.stack([x.image_data for x in batch]) if batch[0].image_data is not None else None,
+            image_stamps=torch.stack([x.image_stamps for x in batch]) if batch[0].image_stamps is not None else None,
+            rotation=torch.stack([x.rotation for x in batch]) if batch[0].rotation is not None else None,
+            game_state=torch.tensor([x.game_state for x in batch]) if batch[0].game_state is not None else None,
         )
 
 
```
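
The None-aware collate_fn lets a single DataLoader serve any combination of enabled modalities. A usage sketch, assuming collate_fn is exposed on the dataset class (batch size and printed shapes are illustrative):

```python
from torch.utils.data import DataLoader

from ddlitlab2024.dataset.pytorch import DDLITLab2024Dataset

dataloader = DataLoader(
    dataset,                                   # a DDLITLab2024Dataset instance
    batch_size=16,
    shuffle=True,
    collate_fn=DDLITLab2024Dataset.collate_fn,
)

batch = next(iter(dataloader))
print(batch.joint_command.shape)  # (16, future trajectory length, num_joints)
print(batch.image_data)           # None if the dataset was built with use_images=False
```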

ddlitlab2024/ml/inference/plot.py

Lines changed: 25 additions & 9 deletions

```diff
@@ -25,7 +25,7 @@
 # Parse the command line arguments
 parser = argparse.ArgumentParser(description="Inference Plot")
 parser.add_argument("checkpoint", type=str, help="Path to the checkpoint to load")
-parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps")
+parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps (not used for distilled)")
 parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
 args = parser.parse_args()
 
```

```diff
@@ -55,8 +55,12 @@
     image_encoder_type=ImageEncoderType(params["image_encoder_type"]),
     num_image_sequence_encoder_layers=params["num_image_sequence_encoder_layers"],
     image_context_length=params["image_context_length"],
+    image_use_final_avgpool=params.get("image_use_final_avgpool", True),
+    image_resolution=params.get("image_resolution", 480),
     num_decoder_layers=params["num_decoder_layers"],
     trajectory_prediction_length=params["trajectory_prediction_length"],
+    use_gamestate=params["use_gamestate"],
+    encoder_patch_size=params["encoder_patch_size"],
 ).to(device)
 normalizer = Normalizer(model.mean, model.std)
 model.load_state_dict(checkpoint["model_state_dict"])
```
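
New hyperparameters that have a sensible fallback are read with params.get(...), so checkpoints written before this change still load. A short sketch of that pattern (params is the hyperparameter dict already loaded in the script):

```python
# Older checkpoints lack the new keys; .get() falls back to the pre-change behaviour
image_resolution = params.get("image_resolution", 480)            # models used to see 480x480 input
use_final_avgpool = params.get("image_use_final_avgpool", True)
is_distilled = params.get("distilled_decoder", False)             # selects the single-step path below
```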

```diff
@@ -76,6 +80,13 @@
     num_samples_joint_trajectory=params["action_context_length"],
     num_samples_imu=params["imu_context_length"],
     num_samples_joint_states=params["joint_state_context_length"],
+    imu_representation=IMUEncoder.OrientationEmbeddingMethod(params["imu_orientation_embedding_method"]),
+    use_action_history=params["use_action_history"],
+    use_imu=params["use_imu"],
+    use_joint_states=params["use_joint_states"],
+    use_images=params["use_images"],
+    use_game_state=params["use_gamestate"],
+    image_resolution=params.get("image_resolution", 480),
 )
 
 # Create DataLoader object
```

```diff
@@ -104,15 +115,20 @@
     noisy_trajectory = torch.randn_like(joint_targets).to(device)
     trajectory = noisy_trajectory
 
-    # Perform the denoising process
-    scheduler.set_timesteps(args.steps)
-    for t in scheduler.timesteps:
+    if params.get("distilled_decoder", False):
+        # Directly predict the trajectory based on the noise
         with torch.no_grad():
-            # Predict the noise residual
-            noise_pred = model(batch, trajectory, torch.tensor([t], device=device))
-
-            # Update the trajectory based on the predicted noise and the current step of the denoising process
-            trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
+            trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
+    else:
+        # Perform the denoising process
+        scheduler.set_timesteps(args.steps)
+        for t in scheduler.timesteps:
+            with torch.no_grad():
+                # Predict the noise residual
+                noise_pred = model(batch, trajectory, torch.tensor([t], device=device))
+
+                # Update the trajectory based on the predicted noise and the current step of the denoising process
+                trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
 
     # Undo the normalization
     trajectory = normalizer.denormalize(trajectory)
```
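
The point of the distilled decoder is that it replaces the whole scheduler loop with a single forward pass, which is where the "faster inference" in the PR title comes from. A rough cost comparison under the default settings (real speedups also depend on scheduler and data-loading overhead):

```python
# Forward passes of the model per generated trajectory
steps = 30               # default of --steps for the iterative path
iterative_calls = steps  # one call per scheduler timestep
distilled_calls = 1      # single direct prediction from noise
print(f"expected speedup: roughly {iterative_calls // distilled_calls}x")  # ~30x
```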
