Skip to content

Commit e3113f2

Browse files
committed
Save collision shape to link transform
1 parent af5c206 commit e3113f2

File tree

6 files changed

+88
-81
lines changed

6 files changed

+88
-81
lines changed

src/jaxsim/api/kin_dyn_parameters.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -783,24 +783,41 @@ class ContactParameters(JaxsimDataclass):
783783
784784
Attributes:
785785
body:
786-
A tuple of integers representing, for each collidable point, the index of
787-
the body (link) to which it is rigidly attached to.
788-
point:
789-
The translations between the link frame and the collidable point, expressed
790-
in the coordinates of the parent link frame.
791-
enabled:
792-
A tuple of booleans representing, for each collidable point, whether it is
793-
enabled or not in contact models.
786+
A tuple of integers representing, for each collision shape, the index of
787+
the link to which it is rigidly attached to.
788+
transform:
789+
The 4x4 homogeneous transformation matrices representing the pose of each
790+
collision shape with respect to the parent link frame.
791+
shape_size:
792+
The size parameters of each collidable shape.
793+
shape_type:
794+
The type of each collidable shape (sphere, box, cylinder, etc.).
794795
795796
Note:
796797
Contrarily to LinkParameters and JointParameters, this class is not meant
797798
to be created with vmap. This is because the `body` attribute must be `Static`.
798799
"""
799800

800-
center: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([]))
801+
body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)
802+
803+
transform: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))
801804
shape_size: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([]))
802805
shape_type: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([]))
803806

807+
@property
808+
def center(self) -> jtp.Array:
809+
"""Extract translation vectors from transformation matrices."""
810+
if self.transform.size == 0:
811+
return jnp.array([])
812+
return self.transform[:, :3, 3]
813+
814+
@property
815+
def orientation(self) -> jtp.Array:
816+
"""Extract rotation matrices from transformation matrices."""
817+
if self.transform.size == 0:
818+
return jnp.array([])
819+
return self.transform[:, :3, :3]
820+
804821
@staticmethod
805822
def build_from(model_description: ModelDescription) -> ContactParameters:
806823
"""
@@ -816,7 +833,12 @@ def build_from(model_description: ModelDescription) -> ContactParameters:
816833
if len(model_description.collision_shapes) == 0:
817834
return ContactParameters()
818835

819-
shape_types, shape_sizes, centers = [], [], []
836+
shape_types, shape_sizes, transforms, parent_link_indices = (
837+
[],
838+
[],
839+
[],
840+
[],
841+
)
820842

821843
# Assume the link_parameters and the collision_shapes are in the same order.
822844
for collision in model_description.collision_shapes:
@@ -828,11 +850,17 @@ def build_from(model_description: ModelDescription) -> ContactParameters:
828850

829851
shape_sizes.append(collision.size.squeeze())
830852

831-
centers.append(collision.center)
853+
transforms.append(collision.transform)
854+
855+
# Get the parent link index for this collision shape.
856+
parent_link_indices.append(
857+
model_description.links_dict[collision.parent_link].index
858+
)
832859

833860
# Build the ContactParameters object.
834861
return ContactParameters(
835-
center=jnp.array(centers, dtype=float),
862+
body=tuple(parent_link_indices),
863+
transform=jnp.array(transforms, dtype=float),
836864
shape_type=jnp.array(shape_types, dtype=int),
837865
shape_size=jnp.array(shape_sizes, dtype=float),
838866
)

src/jaxsim/parsers/descriptions/collision.py

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import dataclasses
44
from abc import ABC
55

6+
import numpy as np
7+
68
import jaxsim.typing as jtp
79

810

@@ -15,16 +17,16 @@ class CollisionShape(ABC):
1517
It is not intended to be instantiated directly.
1618
"""
1719

18-
center: jtp.VectorLike
1920
size: jtp.VectorLike
2021
parent_link: str
22+
transform: jtp.MatrixLike = dataclasses.field(default_factory=lambda: np.eye(4))
2123

2224
def __hash__(self) -> int:
2325
return hash(
2426
(
25-
hash(tuple(self.center.tolist())),
2627
hash(tuple(self.size.tolist())),
2728
hash(self.parent_link),
29+
hash(tuple(self.transform.flatten().tolist())),
2830
)
2931
)
3032

@@ -35,59 +37,33 @@ def __eq__(self, other: CollisionShape) -> bool:
3537

3638
return hash(self) == hash(other)
3739

40+
@property
41+
def center(self) -> jtp.Vector:
42+
"""Extract the translation from the transformation matrix."""
43+
return self.transform[:3, 3]
44+
45+
@property
46+
def orientation(self) -> jtp.Matrix:
47+
"""Extract the rotation matrix from the transformation matrix."""
48+
return self.transform[:3, :3]
49+
3850

3951
@dataclasses.dataclass
4052
class BoxCollision(CollisionShape):
4153
"""
4254
Represents a box-shaped collision shape.
4355
"""
4456

45-
@property
46-
def x(self) -> float:
47-
return self.size[0]
48-
49-
@property
50-
def y(self) -> float:
51-
return self.size[1]
52-
53-
@property
54-
def z(self) -> float:
55-
return self.size[2]
56-
57-
@x.setter
58-
def x(self, value: float) -> None:
59-
self.size[0] = value
60-
61-
@y.setter
62-
def y(self, value: float) -> None:
63-
self.size[1] = value
64-
65-
@z.setter
66-
def z(self, value: float) -> None:
67-
self.size[2] = value
68-
6957

7058
@dataclasses.dataclass
7159
class SphereCollision(CollisionShape):
7260
"""
7361
Represents a spherical collision shape.
7462
"""
7563

76-
@property
77-
def radius(self) -> float:
78-
return self.size[0]
79-
8064

8165
@dataclasses.dataclass
8266
class CylinderCollision(CollisionShape):
8367
"""
8468
Represents a cylindrical collision shape.
8569
"""
86-
87-
@property
88-
def radius(self) -> float:
89-
return self.size[0]
90-
91-
@property
92-
def height(self) -> float:
93-
return self.size[1]

src/jaxsim/parsers/rod/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def extract_model_data(
346346
# Fill with unsupported collision shape
347347
collisions.append(
348348
descriptions.collision.CollisionShape(
349-
center=jnp.array([0.0, 0.0, 0.0]),
349+
transform=jnp.eye(4),
350350
size=jnp.array([0.0, 0.0, 0.0]),
351351
parent_link=link.name,
352352
)

src/jaxsim/parsers/rod/utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,9 @@ def create_box_collision(
105105

106106
H = collision.pose.transform() if collision.pose is not None else np.eye(4)
107107

108-
center = H[:3, 3]
109-
110108
return descriptions.BoxCollision(
111109
size=np.array([x, y, z]),
112-
center=center,
110+
transform=H,
113111
parent_link=link_description.name,
114112
)
115113

@@ -132,11 +130,9 @@ def create_sphere_collision(
132130

133131
H = collision.pose.transform() if collision.pose is not None else np.eye(4)
134132

135-
center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1]
136-
137133
return descriptions.SphereCollision(
138134
size=np.array([r] * 3),
139-
center=center_wrt_link,
135+
transform=H,
140136
parent_link=link_description.name,
141137
)
142138

@@ -160,10 +156,8 @@ def create_cylinder_collision(
160156

161157
H = collision.pose.transform() if collision.pose is not None else np.eye(4)
162158

163-
center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1]
164-
165159
return descriptions.CylinderCollision(
166160
size=np.array([r, l, 0]),
167-
center=center_wrt_link,
161+
transform=H,
168162
parent_link=link_description.name,
169163
)

src/jaxsim/rbda/contacts/common.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
def compute_penetration_data(
3333
model: js.model.JaxSimModel,
3434
*,
35-
shape_offset: jtp.Vector,
35+
shape_transform: jtp.Matrix,
3636
shape_type: CollidableShapeType,
3737
shape_size: jtp.Vector,
3838
link_transforms: jtp.Matrix,
@@ -43,7 +43,7 @@ def compute_penetration_data(
4343
4444
Args:
4545
model: The model to consider.
46-
shape_offset: The offset of the collidable shape with respect to the link frame.
46+
shape_transform: The 4x4 transform of the collidable shape with respect to the link frame.
4747
shape_type: The type of the collidable shape.
4848
shape_size: The size parameters of the collidable shape.
4949
link_transforms: The transforms from the world frame to each link.
@@ -55,10 +55,11 @@ def compute_penetration_data(
5555
expressed in mixed representation.
5656
"""
5757

58-
W_H_L, W_ṗ_L = link_transforms, link_velocities
58+
W_H_L, W_ṗ_L = link_transforms, link_velocities
5959

60-
# Offset the collision shape origin.
61-
W_H_L = W_H_L.at[:3, 3].set(W_H_L[:3, 3] + shape_offset @ W_H_L[:3, :3].T)
60+
# Apply the collision shape transform.
61+
# This computes W_H_S where S is the collision shape frame.
62+
W_H_S = W_H_L @ shape_transform
6263

6364
# Pre-process the position and the linear velocity of the collidable point.
6465
# Note that we consider 3 candidate contact points also for spherical shapes,
@@ -69,7 +70,7 @@ def compute_penetration_data(
6970
(box_plane, cylinder_plane, sphere_plane),
7071
model.terrain,
7172
shape_size,
72-
W_H_L,
73+
W_H_S,
7374
)
7475

7576
W_p_C = W_H_C[:, :3, 3]

tests/test_api_contact.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,24 @@ def test_contact_jacobian_derivative(
9797
velocity_representation=velocity_representation,
9898
)
9999

100-
W_H_L = data._link_transforms
100+
body_indices = np.array(model.kin_dyn_parameters.contact_parameters.body)
101+
102+
# Get link transforms for each collision shape
103+
W_H_L = data._link_transforms[body_indices]
104+
105+
# Get contact point positions (shape: num_collision_shapes, 3, 3)
101106
W_p_C = js.contact.contact_point_positions(model=model, data=data)
102107

103-
# Vectorize over the 3 points for one link
104-
transform_points = jax.vmap(
105-
lambda H, p: H @ jnp.hstack([p, 1.0]), in_axes=(None, 0)
106-
)
108+
# Transform contact points from world to link frame
109+
# For each collision shape, transform its 3 contact points
110+
def transform_to_link_frame(W_H_L_i, W_p_Ci):
111+
"""Transform 3 contact points from world to link frame."""
107112

108-
# Vectorize over the links
109-
L_p_Ci = jax.vmap(transform_points, in_axes=(0, 0))(W_H_L, W_p_C)[..., :3]
113+
L_H_W = jnp.linalg.inv(W_H_L_i)
114+
return jax.vmap(lambda p: (L_H_W @ jnp.hstack([p, 1.0]))[:3])(W_p_Ci)
115+
116+
# Apply to all collision shapes: shape (num_collision_shapes, 3, 3)
117+
L_p_Ci = jax.vmap(transform_to_link_frame)(W_H_L, W_p_C)
110118

111119
# =====
112120
# Tests
@@ -115,16 +123,15 @@ def test_contact_jacobian_derivative(
115123
# Load the model in ROD.
116124
rod_model = rod.Sdf.load(sdf=model.built_from).model
117125

118-
# Add dummy frames on the contact shapes.
119-
120-
for idx, link_name, points in zip(
121-
np.arange(model.number_of_links()), model.link_names(), L_p_Ci, strict=True
126+
for shape_idx, (link_idx, points) in enumerate(
127+
zip(body_indices, L_p_Ci, strict=True)
122128
):
123-
# points: shape (3, 3) for this link
129+
link_name = model.link_names()[link_idx]
130+
124131
for j, p in enumerate(points):
125132
rod_model.add_frame(
126133
frame=rod.Frame(
127-
name=f"contact_shape_{idx}_{j}",
134+
name=f"contact_shape_{shape_idx}_{j}",
128135
attached_to=link_name,
129136
pose=rod.Pose(
130137
relative_to=link_name,
@@ -154,11 +161,12 @@ def test_contact_jacobian_derivative(
154161
)
155162

156163
# Extract the indexes of the frames attached to the contact shapes.
164+
num_collision_shapes = len(model.kin_dyn_parameters.contact_parameters.body)
157165
frame_idxs = js.frame.names_to_idxs(
158166
model=model_with_frames,
159167
frame_names=(
160-
f"contact_shape_{idx}_{j}"
161-
for idx in np.arange(model.number_of_links())
168+
f"contact_shape_{shape_idx}_{j}"
169+
for shape_idx in range(num_collision_shapes)
162170
for j in range(3)
163171
),
164172
)

0 commit comments

Comments
 (0)