Skip to content

Commit 2d16aa6

Browse files
committed
Refactor contact models to use shape transforms
1 parent e3113f2 commit 2d16aa6

File tree

3 files changed

+117
-49
lines changed

3 files changed

+117
-49
lines changed

src/jaxsim/api/contact.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,28 @@ def contact_point_kinematics(
3636
the linear component of the mixed 6D frame velocity.
3737
"""
3838

39-
_, _, _, W_p_Ci, W_ṗ_Ci = jax.vmap(
40-
jaxsim.rbda.contacts.common.compute_penetration_data, in_axes=(None,)
39+
_, _, _, W_p_Ci, W_ṗ_Ci = jax.vmap(
40+
lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: jaxsim.rbda.contacts.common.compute_penetration_data(
41+
model,
42+
shape_transform=shape_transform,
43+
shape_type=shape_type,
44+
shape_size=shape_size,
45+
link_transforms=link_transform,
46+
link_velocities=link_velocity,
47+
)
4148
)(
42-
model,
43-
shape_offset=model.kin_dyn_parameters.contact_parameters.center,
44-
shape_type=model.kin_dyn_parameters.contact_parameters.shape_type,
45-
shape_size=model.kin_dyn_parameters.contact_parameters.shape_size,
46-
link_transforms=data._link_transforms,
47-
link_velocities=data._link_velocities,
49+
model.kin_dyn_parameters.contact_parameters.transform,
50+
model.kin_dyn_parameters.contact_parameters.shape_type,
51+
model.kin_dyn_parameters.contact_parameters.shape_size,
52+
data._link_transforms[
53+
jnp.array(model.kin_dyn_parameters.contact_parameters.body)
54+
],
55+
data._link_velocities[
56+
jnp.array(model.kin_dyn_parameters.contact_parameters.body)
57+
],
4858
)
4959

50-
return W_p_Ci, W_ṗ_Ci
60+
return W_p_Ci, W_ṗ_Ci
5161

5262

5363
@jax.jit
@@ -241,21 +251,29 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
241251
# Get the transforms of the parent link of all collidable points.
242252
W_H_L = data._link_transforms
243253

244-
def _process_single_shape(shape_type, shape_size, W_H_Li):
254+
# Index transforms by the body (parent link) of each collision shape
255+
body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
256+
W_H_L_indexed = W_H_L[body_indices]
257+
258+
def _process_single_shape(shape_type, shape_size, shape_transform, W_H_Li):
259+
# Apply the collision shape transform to get W_H_S
260+
W_H_S = W_H_Li @ shape_transform
261+
245262
_, W_H_C = jax.lax.switch(
246263
shape_type,
247264
(detection.box_plane, detection.cylinder_plane, detection.sphere_plane),
248265
model.terrain,
249266
shape_size,
250-
W_H_Li,
267+
W_H_S,
251268
)
252269

253270
return W_H_C
254271

255272
return jax.vmap(_process_single_shape)(
256273
model.kin_dyn_parameters.contact_parameters.shape_type,
257274
model.kin_dyn_parameters.contact_parameters.shape_size,
258-
W_H_L,
275+
model.kin_dyn_parameters.contact_parameters.transform,
276+
W_H_L_indexed,
259277
)
260278

261279

@@ -294,13 +312,17 @@ def jacobian(
294312
model=model, data=data, output_vel_repr=VelRepr.Inertial
295313
)
296314

297-
# Compute contact transforms (n_links, n_contacts, 4, 4)
315+
# Compute contact transforms (n_shapes, n_contacts_per_shape, 4, 4)
298316
W_H_C = transforms(model=model, data=data)
299317

300-
# Flatten link × contact axes for single-batch processing (n_links*n_contacts, 6, 6+n)
301-
W_J_WC_flat = jnp.repeat(W_J_WL, 3, axis=0)
318+
# Index Jacobians by the body (parent link) of each collision shape
319+
body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
320+
W_J_WL_indexed = W_J_WL[body_indices] # (n_shapes, 6, 6+n)
321+
322+
# Repeat for each contact point per shape: (n_shapes*n_contacts_per_shape, 6, 6+n)
323+
W_J_WC_flat = jnp.repeat(W_J_WL_indexed, 3, axis=0)
302324

303-
# Flatten contact transforms (n_links*n_contacts, 4, 4)
325+
# Flatten contact transforms (n_shapes*n_contacts_per_shape, 4, 4)
304326
W_H_C_flat = W_H_C.reshape(-1, 4, 4)
305327

306328
# Transform Jacobian based on velocity representation
@@ -357,7 +379,11 @@ def jacobian_derivative(
357379
# Get the link velocities.
358380
W_v_WL = data._link_velocities
359381

360-
# Compute the contact transforms (n_links, n_contacts, 4, 4)
382+
# Index link velocities by body (parent link) of each collision shape
383+
body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
384+
W_v_WL_indexed = W_v_WL[body_indices] # (n_shapes, 6)
385+
386+
# Compute the contact transforms (n_shapes, n_contacts, 4, 4)
361387
W_H_C = transforms(model=model, data=data)
362388

363389
# =====================================================
@@ -408,6 +434,10 @@ def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix:
408434
model=model, data=data
409435
)
410436

437+
# Index Jacobians by body (parent link) of each collision shape
438+
W_J_WL_W_indexed = W_J_WL_W[body_indices] # (n_shapes, 6, 6+n)
439+
W_J̇_WL_W_indexed = W_J̇_WL_W[body_indices] # (n_shapes, 6, 6+n)
440+
411441
def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix:
412442
match output_vel_repr:
413443
case VelRepr.Inertial:
@@ -430,15 +460,15 @@ def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix:
430460

431461
return O_J̇_WC_I
432462

433-
O_J̇_per_link = jax.vmap(
434-
lambda H_C_link, v_WL_link, J_WL_link, J̇_WL_link: jax.vmap(
463+
O_J̇_per_shape = jax.vmap(
464+
lambda H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape: jax.vmap(
435465
compute_O_J̇_WC_I,
436466
in_axes=(0, None, None, None), # Map over contacts for W_H_C only
437-
)(H_C_link, v_WL_link, J_WL_link, J̇_WL_link),
438-
in_axes=(0, 0, 0, 0), # Map over links
439-
)(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W)
467+
)(H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape),
468+
in_axes=(0, 0, 0, 0), # Map over shapes
469+
)(W_H_C, W_v_WL_indexed, W_J_WL_W_indexed, W_J̇_WL_W_indexed)
440470

441-
O_J̇_WC = O_J̇_per_link.reshape(-1, 6, 6 + model.dofs())
471+
O_J̇_WC = O_J̇_per_shape.reshape(-1, 6, 6 + model.dofs())
442472

443473
return O_J̇_WC
444474

src/jaxsim/rbda/contacts/relaxed_rigid.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,25 @@ def compute_contact_forces(
328328

329329
# Compute the penetration depth and velocity of the collidable points.
330330
# Note that this function considers the penetration in the normal direction.
331-
δ, δ̇, , W_p_C, CW_ṗ_C = jax.vmap(
332-
common.compute_penetration_data, in_axes=(None,)
331+
δ, δ̇, , W_p_C, CW_ṗ_C = jax.vmap(
332+
lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data(
333+
model,
334+
shape_transform=shape_transform,
335+
shape_type=shape_type,
336+
shape_size=shape_size,
337+
link_transforms=link_transform,
338+
link_velocities=link_velocity,
339+
)
333340
)(
334-
model,
335-
shape_offset=model.kin_dyn_parameters.contact_parameters.center,
336-
shape_type=model.kin_dyn_parameters.contact_parameters.shape_type,
337-
shape_size=model.kin_dyn_parameters.contact_parameters.shape_size,
338-
link_transforms=data._link_transforms,
339-
link_velocities=data._link_velocities,
341+
model.kin_dyn_parameters.contact_parameters.transform,
342+
model.kin_dyn_parameters.contact_parameters.shape_type,
343+
model.kin_dyn_parameters.contact_parameters.shape_size,
344+
data._link_transforms[
345+
jnp.array(model.kin_dyn_parameters.contact_parameters.body)
346+
],
347+
data._link_velocities[
348+
jnp.array(model.kin_dyn_parameters.contact_parameters.body)
349+
],
340350
)
341351

342352
# Compute the position in the constraint frame.
@@ -346,7 +356,7 @@ def compute_contact_forces(
346356
a_ref, r, *_ = self._regularizers(
347357
model=model,
348358
position_constraint=position_constraint,
349-
velocity_constraint=CW_ṗ_C,
359+
velocity_constraint=CW_ṗ_C,
350360
parameters=model.contact_params,
351361
)
352362

@@ -529,13 +539,21 @@ def to_inertial(force, H_C):
529539

530540
# Compute the contact forces in inertial representation for
531541
# each link and contact point.
532-
# Nested vmap: inner over contacts, outer over links
533-
W_f_C = jax.vmap(lambda f_link, H_link: jax.vmap(to_inertial)(f_link, H_link))(
534-
CW_fl_per_link, W_H_C
542+
# Nested vmap: inner over contacts, outer over shapes
543+
W_f_C = jax.vmap(
544+
lambda f_shape, H_shape: jax.vmap(to_inertial)(f_shape, H_shape)
545+
)(CW_fl_per_link, W_H_C)
546+
547+
# Sum over contacts for each shape: (n_shapes, 6)
548+
W_f_per_shape = W_f_C.sum(axis=1)
549+
550+
# Accumulate forces by parent link using segment_sum
551+
body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
552+
W_f_per_link = jax.ops.segment_sum(
553+
W_f_per_shape, body_indices, num_segments=model.number_of_links()
535554
)
536555

537-
# Sum over contacts for each link
538-
return W_f_C.sum(axis=1), {}
556+
return W_f_per_link, {}
539557

540558
@staticmethod
541559
def _regularizers(
@@ -576,7 +594,11 @@ def _regularizers(
576594
)
577595

578596
# Compute the 6D inertia matrices of all links.
579-
M_L = js.model.link_spatial_inertia_matrices(model=model)[:, :3, :3]
597+
M_L_all = js.model.link_spatial_inertia_matrices(model=model)[:, :3, :3]
598+
599+
# Index M_L by the body (parent link) of each collision shape
600+
body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
601+
M_L = M_L_all[body_indices]
580602

581603
def imp_aref(
582604
pos: jtp.Vector,

src/jaxsim/rbda/contacts/soft.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -419,15 +419,25 @@ def compute_contact_forces(
419419
# Compute the position and linear velocities (mixed representation) of
420420
# all the collidable shapes belonging to the robot and extract the ones
421421
# for the enabled collidable shapes.
422-
δ, δ̇, , W_p_C, CW_ṗ_C = jax.vmap(
423-
common.compute_penetration_data, in_axes=(None,)
422+
δ, δ̇, , W_p_C, CW_ṗ_C = jax.vmap(
423+
lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data(
424+
model,
425+
shape_transform=shape_transform,
426+
shape_type=shape_type,
427+
shape_size=shape_size,
428+
link_transforms=link_transform,
429+
link_velocities=link_velocity,
430+
)
424431
)(
425-
model,
426-
shape_offset=model.kin_dyn_parameters.contact_parameters.center,
427-
shape_type=model.kin_dyn_parameters.contact_parameters.shape_type,
428-
shape_size=model.kin_dyn_parameters.contact_parameters.shape_size,
429-
link_transforms=data._link_transforms,
430-
link_velocities=data._link_velocities,
432+
model.kin_dyn_parameters.contact_parameters.transform,
433+
model.kin_dyn_parameters.contact_parameters.shape_type,
434+
model.kin_dyn_parameters.contact_parameters.shape_size,
435+
data._link_transforms[
436+
jnp.array(model.kin_dyn_parameters.contact_parameters.body)
437+
],
438+
data._link_velocities[
439+
jnp.array(model.kin_dyn_parameters.contact_parameters.body)
440+
],
431441
)
432442

433443
# Extract the material deformation corresponding to the collidable shapes.
@@ -441,9 +451,15 @@ def compute_contact_forces(
441451
# We exploit two levels of vmap to vectorize over both the shapes and the points.
442452
# The outer vmap vectorizes over the shapes, while the inner vmap vectorizes
443453
# over the maximum points (3) belonging to each shape.
444-
W_f, = jax.vmap(
454+
W_f_per_shape, = jax.vmap(
445455
SoftContacts.compute_contact_force,
446456
in_axes=(0, 0, 0, 0, 0, 0, None), # vectorize over shapes
447-
)(δ, δ̇, W_p_C, CW_ṗ_C, , m, model.contact_params)
457+
)(δ, δ̇, W_p_C, CW_ṗ_C, , m, model.contact_params)
458+
459+
# Accumulate forces by parent link using segment_sum
460+
body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
461+
W_f = jax.ops.segment_sum(
462+
W_f_per_shape, body_indices, num_segments=model.number_of_links()
463+
)
448464

449-
return W_f, {"m_dot": }
465+
return W_f, {"m_dot": }

0 commit comments

Comments
 (0)