@@ -36,18 +36,28 @@ def contact_point_kinematics(
3636 the linear component of the mixed 6D frame velocity.
3737 """
3838
39- _ , _ , _ , W_p_Ci , W_ṗ_Ci = jax .vmap (
40- jaxsim .rbda .contacts .common .compute_penetration_data , in_axes = (None ,)
39+ _ , _ , _ , W_p_Ci , W_ṗ_Ci = jax .vmap (
40+ lambda shape_transform , shape_type , shape_size , link_transform , link_velocity : jaxsim .rbda .contacts .common .compute_penetration_data (
41+ model ,
42+ shape_transform = shape_transform ,
43+ shape_type = shape_type ,
44+ shape_size = shape_size ,
45+ link_transforms = link_transform ,
46+ link_velocities = link_velocity ,
47+ )
4148 )(
42- model ,
43- shape_offset = model .kin_dyn_parameters .contact_parameters .center ,
44- shape_type = model .kin_dyn_parameters .contact_parameters .shape_type ,
45- shape_size = model .kin_dyn_parameters .contact_parameters .shape_size ,
46- link_transforms = data ._link_transforms ,
47- link_velocities = data ._link_velocities ,
49+ model .kin_dyn_parameters .contact_parameters .transform ,
50+ model .kin_dyn_parameters .contact_parameters .shape_type ,
51+ model .kin_dyn_parameters .contact_parameters .shape_size ,
52+ data ._link_transforms [
53+ jnp .array (model .kin_dyn_parameters .contact_parameters .body )
54+ ],
55+ data ._link_velocities [
56+ jnp .array (model .kin_dyn_parameters .contact_parameters .body )
57+ ],
4858 )
4959
50- return W_p_Ci , W_ṗ_Ci
60+ return W_p_Ci , W_ṗ_Ci
5161
5262
5363@jax .jit
@@ -241,21 +251,29 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
241251 # Get the transforms of the parent link of all collidable points.
242252 W_H_L = data ._link_transforms
243253
244- def _process_single_shape (shape_type , shape_size , W_H_Li ):
254+ # Index transforms by the body (parent link) of each collision shape
255+ body_indices = jnp .array (model .kin_dyn_parameters .contact_parameters .body )
256+ W_H_L_indexed = W_H_L [body_indices ]
257+
258+ def _process_single_shape (shape_type , shape_size , shape_transform , W_H_Li ):
259+ # Apply the collision shape transform to get W_H_S
260+ W_H_S = W_H_Li @ shape_transform
261+
245262 _ , W_H_C = jax .lax .switch (
246263 shape_type ,
247264 (detection .box_plane , detection .cylinder_plane , detection .sphere_plane ),
248265 model .terrain ,
249266 shape_size ,
250- W_H_Li ,
267+ W_H_S ,
251268 )
252269
253270 return W_H_C
254271
255272 return jax .vmap (_process_single_shape )(
256273 model .kin_dyn_parameters .contact_parameters .shape_type ,
257274 model .kin_dyn_parameters .contact_parameters .shape_size ,
258- W_H_L ,
275+ model .kin_dyn_parameters .contact_parameters .transform ,
276+ W_H_L_indexed ,
259277 )
260278
261279
@@ -294,13 +312,17 @@ def jacobian(
294312 model = model , data = data , output_vel_repr = VelRepr .Inertial
295313 )
296314
297- # Compute contact transforms (n_links, n_contacts , 4, 4)
315+ # Compute contact transforms (n_shapes, n_contacts_per_shape , 4, 4)
298316 W_H_C = transforms (model = model , data = data )
299317
300- # Flatten link × contact axes for single-batch processing (n_links*n_contacts, 6, 6+n)
301- W_J_WC_flat = jnp .repeat (W_J_WL , 3 , axis = 0 )
318+ # Index Jacobians by the body (parent link) of each collision shape
319+ body_indices = jnp .array (model .kin_dyn_parameters .contact_parameters .body )
320+ W_J_WL_indexed = W_J_WL [body_indices ] # (n_shapes, 6, 6+n)
321+
322+ # Repeat for each contact point per shape: (n_shapes*n_contacts_per_shape, 6, 6+n)
323+ W_J_WC_flat = jnp .repeat (W_J_WL_indexed , 3 , axis = 0 )
302324
303- # Flatten contact transforms (n_links*n_contacts , 4, 4)
325+ # Flatten contact transforms (n_shapes*n_contacts_per_shape , 4, 4)
304326 W_H_C_flat = W_H_C .reshape (- 1 , 4 , 4 )
305327
306328 # Transform Jacobian based on velocity representation
@@ -357,7 +379,11 @@ def jacobian_derivative(
357379 # Get the link velocities.
358380 W_v_WL = data ._link_velocities
359381
360- # Compute the contact transforms (n_links, n_contacts, 4, 4)
382+ # Index link velocities by body (parent link) of each collision shape
383+ body_indices = jnp .array (model .kin_dyn_parameters .contact_parameters .body )
384+ W_v_WL_indexed = W_v_WL [body_indices ] # (n_shapes, 6)
385+
386+ # Compute the contact transforms (n_shapes, n_contacts, 4, 4)
361387 W_H_C = transforms (model = model , data = data )
362388
363389 # =====================================================
@@ -408,6 +434,10 @@ def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix:
408434 model = model , data = data
409435 )
410436
437+ # Index Jacobians by body (parent link) of each collision shape
438+ W_J_WL_W_indexed = W_J_WL_W [body_indices ] # (n_shapes, 6, 6+n)
439+ W_J̇_WL_W_indexed = W_J̇_WL_W [body_indices ] # (n_shapes, 6, 6+n)
440+
411441 def compute_O_J̇_WC_I (W_H_C , W_v_WL , W_J_WL_W , W_J̇_WL_W ) -> jtp .Matrix :
412442 match output_vel_repr :
413443 case VelRepr .Inertial :
@@ -430,15 +460,15 @@ def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix:
430460
431461 return O_J̇_WC_I
432462
433- O_J̇_per_link = jax .vmap (
434- lambda H_C_link , v_WL_link , J_WL_link , J̇_WL_link : jax .vmap (
463+ O_J̇_per_shape = jax .vmap (
464+ lambda H_C_shape , v_WL_shape , J_WL_shape , J̇_WL_shape : jax .vmap (
435465 compute_O_J̇_WC_I ,
436466 in_axes = (0 , None , None , None ), # Map over contacts for W_H_C only
437- )(H_C_link , v_WL_link , J_WL_link , J̇_WL_link ),
438- in_axes = (0 , 0 , 0 , 0 ), # Map over links
439- )(W_H_C , W_v_WL , W_J_WL_W , W_J̇_WL_W )
467+ )(H_C_shape , v_WL_shape , J_WL_shape , J̇_WL_shape ),
468+ in_axes = (0 , 0 , 0 , 0 ), # Map over shapes
469+ )(W_H_C , W_v_WL_indexed , W_J_WL_W_indexed , W_J̇_WL_W_indexed )
440470
441- O_J̇_WC = O_J̇_per_link .reshape (- 1 , 6 , 6 + model .dofs ())
471+ O_J̇_WC = O_J̇_per_shape .reshape (- 1 , 6 , 6 + model .dofs ())
442472
443473 return O_J̇_WC
444474
0 commit comments