44from collections .abc import Iterable
55from dataclasses import asdict , dataclass
66from pathlib import Path
7- from typing import Literal
7+ from typing import Literal , Optional
88
9+ import cv2
910import numpy as np
1011import pandas as pd
1112import torch
12- from profilehooks import profile
1313from tabulate import tabulate
1414from torch .utils .data import DataLoader , Dataset
15+ from torchvision .transforms import v2
1516
1617from ddlitlab2024 import DB_PATH
1718from ddlitlab2024 .dataset import logger
@@ -40,15 +41,15 @@ class DDLITLab2024Dataset(Dataset):
4041 @dataclass
4142 class Result :
4243 joint_command : torch .Tensor
43- joint_command_history : torch .Tensor
44- joint_state : torch .Tensor
45- rotation : torch .Tensor
46- game_state : torch .Tensor
47- image_data : torch .Tensor
48- image_stamps : torch .Tensor
44+ joint_command_history : Optional [ torch .Tensor ]
45+ joint_state : Optional [ torch .Tensor ]
46+ rotation : Optional [ torch .Tensor ]
47+ game_state : Optional [ torch .Tensor ]
48+ image_data : Optional [ torch .Tensor ]
49+ image_stamps : Optional [ torch .Tensor ]
4950
5051 def shapes (self ) -> dict [str , tuple [int , ...]]:
51- return {k : v .shape for k , v in asdict (self ).items ()}
52+ return {k : v .shape for k , v in asdict (self ).items () if v is not None }
5253
5354 def __init__ (
5455 self ,
@@ -61,8 +62,14 @@ def __init__(
6162 sampling_rate : int = 100 ,
6263 max_fps_video : int = 10 ,
6364 num_frames_video : int = 50 ,
64- trajectory_stride : int = 10 ,
65+ image_resolution : int = 480 ,
66+ trajectory_stride : int = 1 ,
6567 num_joints : int = 20 ,
68+ use_images : bool = True ,
69+ use_imu : bool = True ,
70+ use_joint_states : bool = True ,
71+ use_action_history : bool = True ,
72+ use_game_state : bool = True ,
6673 ):
6774 # Initialize the database connection
6875 self .db_connection : sqlite3 .Connection = db_connection if db_connection else connect_to_db ()
@@ -76,8 +83,15 @@ def __init__(
7683 self .sampling_rate = sampling_rate
7784 self .max_fps_video = max_fps_video
7885 self .num_frames_video = num_frames_video
86+ self .image_resolution = image_resolution
7987 self .trajectory_stride = trajectory_stride
8088 self .num_joints = num_joints
89+ self .joint_names = JointStates .get_ordered_joint_names ()
90+ self .use_images = use_images
91+ self .use_imu = use_imu
92+ self .use_joint_states = use_joint_states
93+ self .use_action_history = use_action_history
94+ self .use_game_state = use_game_state
8195
8296 # Print out metadata
8397 cursor = self .db_connection .cursor ()
@@ -100,7 +114,9 @@ def __init__(
100114 assert num_data_points > 0 , "Recording length is negative or zero"
101115 total_samples_before = self .num_samples
102116 # Calculate the number of batches that can be build from the recording including the stride
103- self .num_samples += int (num_data_points / self .trajectory_stride )
117+ self .num_samples += int (
118+ (num_data_points - self .num_samples_joint_trajectory_future ) / self .trajectory_stride
119+ )
104120 # Store the boundaries of the samples for later retrieval
105121 self .sample_boundaries .append ((total_samples_before , self .num_samples , recording_id ))
106122
@@ -119,7 +135,7 @@ def query_joint_data(
119135 )
120136
121137 # Convert to numpy array, keep only the joint angle columns in alphabetical order
122- raw_joint_data = raw_joint_data [JointStates . get_ordered_joint_names () ].to_numpy (dtype = np .float32 )
138+ raw_joint_data = raw_joint_data [self . joint_names ].to_numpy (dtype = np .float32 )
123139
124140 assert raw_joint_data .shape [1 ] == self .num_joints , "The number of joints is not correct"
125141
@@ -155,7 +171,7 @@ def query_joint_data_history(
155171 return raw_joint_data
156172
157173 def query_image_data (
158- self , recording_id : int , end_time_stamp : float , context_len : float , num_frames : int
174+ self , recording_id : int , end_time_stamp : float , context_len : float , num_frames : int , resolution : int
159175 ) -> tuple [torch .Tensor , torch .Tensor ]:
160176 # Get the image data
161177 cursor = self .db_connection .cursor ()
@@ -178,25 +194,36 @@ def query_image_data(
178194 stamps = []
179195 image_data = []
180196
197+ # Define the preprocessing pipeline
198+ preprocessing = v2 .Compose (
199+ [
200+ v2 .ToImage (),
201+ v2 .ToDtype (torch .float32 , scale = True ),
202+ v2 .Normalize ((0.485 , 0.456 , 0.406 ), (0.229 , 0.224 , 0.225 )),
203+ ]
204+ )
205+
181206 # Get the raw image data
182207 for stamp , data in response :
183208 # Deserialize the image data
184209 image = np .frombuffer (data , dtype = np .uint8 ).reshape (480 , 480 , 3 )
185- # Make chw from hwc
186- image = np .moveaxis (image , - 1 , 0 )
210+ # Resize the image
211+ image = cv2 .resize (image , (resolution , resolution ), interpolation = cv2 .INTER_AREA )
212+ # Apply the preprocessing pipeline
213+ image = preprocessing (image )
187214 # Append to the list
188215 image_data .append (image )
189216 stamps .append (stamp )
190217
191218 # Apply zero padding if necessary
192219 if len (image_data ) < num_frames :
193220 image_data = [
194- np .zeros ((3 , 480 , 480 ), dtype = np .uint8 ) for _ in range (num_frames - len (image_data ))
221+ torch .zeros ((3 , resolution , resolution ), dtype = torch .float32 )
222+ for _ in range (num_frames - len (image_data ))
195223 ] + image_data
196224 stamps = [end_time_stamp - context_len for _ in range (num_frames - len (stamps ))] + stamps
197225
198- # Convert to tensor
199- image_data = torch .from_numpy (np .stack (image_data , axis = 0 )).float ()
226+ image_data = torch .stack (image_data , axis = 0 )
200227 stamps = torch .tensor (stamps )
201228
202229 return stamps , image_data
@@ -244,7 +271,7 @@ def query_imu_data(self, recording_id: int, end_sample: int, num_samples: int) -
244271 case rep :
245272 raise NotImplementedError (f"Unknown IMU representation { rep } " )
246273
247- return torch .from_numpy (imu_data )
274+ return torch .from_numpy (imu_data ). float ()
248275
249276 def query_current_game_state (self , recording_id : int , stamp : float ) -> torch .Tensor :
250277 cursor = self .db_connection .cursor ()
@@ -265,7 +292,6 @@ def query_current_game_state(self, recording_id: int, stamp: float) -> torch.Ten
265292
266293 return torch .tensor (int (game_state ))
267294
268- @profile
269295 def __getitem__ (self , idx : int ) -> Result :
270296 # Find the recording that contains the sample
271297 for start_sample , end_sample , recording_id in self .sample_boundaries :
@@ -288,20 +314,30 @@ def __getitem__(self, idx: int) -> Result:
288314 stamp = sample_joint_command_index / self .sampling_rate
289315
290316 # Get the image data
291- image_stamps , image_data = self .query_image_data (
292- recording_id ,
293- stamp ,
294- # The duration is used to narrow down the query for a faster retrieval, so we consider it as an upper bound
295- (self .num_frames_video + 1 ) / self .max_fps_video ,
296- self .num_frames_video ,
297- )
298- # Some sanity checks
299- assert all ([stamp >= image_stamp for image_stamp in image_stamps ]), "The image data is not synchronized"
300- assert len (image_stamps ) == self .num_frames_video , "The image data is not the correct length"
301- assert image_data .shape == (self .num_frames_video , 3 , 480 , 480 ), "The image data has the wrong shape"
302- assert (
303- image_stamps [0 ] >= stamp - (self .num_frames_video + 1 ) / self .max_fps_video
304- ), "The image data is not synchronized"
317+ if self .use_images :
318+ image_stamps , image_data = self .query_image_data (
319+ recording_id ,
320+ stamp ,
321+ # The duration is used to narrow down the query for a faster retrieval,
322+ # so we consider it as an upper bound
323+ (self .num_frames_video + 1 ) / self .max_fps_video ,
324+ self .num_frames_video ,
325+ self .image_resolution ,
326+ )
327+ # Some sanity checks
328+ assert all ([stamp >= image_stamp for image_stamp in image_stamps ]), "The image data is not synchronized"
329+ assert len (image_stamps ) == self .num_frames_video , "The image data is not the correct length"
330+ assert image_data .shape == (
331+ self .num_frames_video ,
332+ 3 ,
333+ self .image_resolution ,
334+ self .image_resolution ,
335+ ), "The image data has the wrong shape"
336+ assert (
337+ image_stamps [0 ] >= stamp - (self .num_frames_video + 1 ) / self .max_fps_video
338+ ), "The image data is not synchronized"
339+ else :
340+ image_stamps , image_data = None , None
305341
306342 # Get the joint command target (future)
307343 joint_command = self .query_joint_data (
@@ -310,20 +346,32 @@ def __getitem__(self, idx: int) -> Result:
310346 assert len (joint_command ) == self .num_samples_joint_trajectory_future , "The joint command has the wrong length"
311347
312348 # Get the joint command history
313- joint_command_history = self .query_joint_data_history (
314- recording_id , sample_joint_command_index , self .num_samples_joint_trajectory , "JointCommands"
315- )
349+ if self .use_action_history :
350+ joint_command_history = self .query_joint_data_history (
351+ recording_id , sample_joint_command_index , self .num_samples_joint_trajectory , "JointCommands"
352+ )
353+ else :
354+ joint_command_history = None
316355
317356 # Get the joint state
318- joint_state = self .query_joint_data_history (
319- recording_id , sample_joint_command_index , self .num_samples_joint_states , "JointStates"
320- )
357+ if self .use_joint_states :
358+ joint_state = self .query_joint_data_history (
359+ recording_id , sample_joint_command_index , self .num_samples_joint_states , "JointStates"
360+ )
361+ else :
362+ joint_state = None
321363
322364 # Get the robot rotation (IMU data)
323- robot_rotation = self .query_imu_data (recording_id , sample_joint_command_index , self .num_samples_imu )
365+ if self .use_imu :
366+ robot_rotation = self .query_imu_data (recording_id , sample_joint_command_index , self .num_samples_imu )
367+ else :
368+ robot_rotation = None
324369
325370 # Get the game state
326- game_state = self .query_current_game_state (recording_id , stamp )
371+ if self .use_game_state :
372+ game_state = self .query_current_game_state (recording_id , stamp )
373+ else :
374+ game_state = None
327375
328376 return self .Result (
329377 joint_command = joint_command ,
@@ -339,12 +387,14 @@ def __getitem__(self, idx: int) -> Result:
339387 def collate_fn (batch : Iterable [Result ]) -> Result :
340388 return DDLITLab2024Dataset .Result (
341389 joint_command = torch .stack ([x .joint_command for x in batch ]),
342- joint_command_history = torch .stack ([x .joint_command_history for x in batch ]),
343- joint_state = torch .stack ([x .joint_state for x in batch ]),
344- image_data = torch .stack ([x .image_data for x in batch ]),
345- image_stamps = torch .stack ([x .image_stamps for x in batch ]),
346- rotation = torch .stack ([x .rotation for x in batch ]),
347- game_state = torch .tensor ([x .game_state for x in batch ]),
390+ joint_command_history = torch .stack ([x .joint_command_history for x in batch ])
391+ if batch [0 ].joint_command_history is not None
392+ else None ,
393+ joint_state = torch .stack ([x .joint_state for x in batch ]) if batch [0 ].joint_state is not None else None ,
394+ image_data = torch .stack ([x .image_data for x in batch ]) if batch [0 ].image_data is not None else None ,
395+ image_stamps = torch .stack ([x .image_stamps for x in batch ]) if batch [0 ].image_stamps is not None else None ,
396+ rotation = torch .stack ([x .rotation for x in batch ]) if batch [0 ].rotation is not None else None ,
397+ game_state = torch .tensor ([x .game_state for x in batch ]) if batch [0 ].game_state is not None else None ,
348398 )
349399
350400
0 commit comments