Skip to content

Commit 5b5d505

Browse files
author
The jax3d Authors
committed
Add differentiable rigid body SE3 transforms.
PiperOrigin-RevId: 484325168
1 parent 6090d3a commit 5b5d505

File tree

4 files changed

+792
-0
lines changed

4 files changed

+792
-0
lines changed

jax3d/math/quaternion.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# Copyright 2022 The jax3d Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Quaternion math.
16+
17+
This module assumes the xyzw quaternion format where xyz is the imaginary part
18+
and w is the real part.
19+
20+
Functions in this module support both batched and unbatched quaternions.
21+
"""
22+
from jax import numpy as jnp
23+
from jax.numpy import linalg
24+
25+
26+
def safe_acos(t, eps=1e-7):
27+
"""A safe version of arccos which avoids evaluating at -1 or 1."""
28+
return jnp.arccos(jnp.clip(t, -1.0 + eps, 1.0 - eps))
29+
30+
31+
def im(q):
32+
"""Fetch the imaginary part of the quaternion."""
33+
return q[..., :3]
34+
35+
36+
def re(q):
37+
"""Fetch the real part of the quaternion."""
38+
return q[..., 3:]
39+
40+
41+
def identity():
42+
return jnp.array([0.0, 0.0, 0.0, 1.0])
43+
44+
45+
def conjugate(q):
46+
"""Compute the conjugate of a quaternion."""
47+
return jnp.concatenate([-im(q), re(q)], axis=-1)
48+
49+
50+
def inverse(q):
51+
"""Compute the inverse of a quaternion."""
52+
return normalize(conjugate(q))
53+
54+
55+
def normalize(q):
56+
"""Normalize a quaternion."""
57+
return q / norm(q)
58+
59+
60+
def norm(q):
61+
return linalg.norm(q, axis=-1, keepdims=True)
62+
63+
64+
def multiply(q1, q2):
65+
"""Multiply two quaternions."""
66+
c = (re(q1) * im(q2)
67+
+ re(q2) * im(q1)
68+
+ jnp.cross(im(q1), im(q2)))
69+
w = re(q1) * re(q2) - jnp.dot(im(q1), im(q2))
70+
return jnp.concatenate([c, w], axis=-1)
71+
72+
73+
def rotate(q, v):
74+
"""Rotate a vector using a quaternion."""
75+
# Create the quaternion representation of the vector.
76+
q_v = jnp.concatenate([v, jnp.zeros_like(v[..., :1])], axis=-1)
77+
return im(multiply(multiply(q, q_v), conjugate(q)))
78+
79+
80+
def log(q, eps=1e-8):
81+
"""Computes the quaternion logarithm.
82+
83+
References:
84+
https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
85+
86+
Args:
87+
q: the quaternion in (x,y,z,w) format.
88+
eps: an epsilon value for numerical stability.
89+
90+
Returns:
91+
The logarithm of q.
92+
"""
93+
mag = linalg.norm(q, axis=-1, keepdims=True)
94+
v = im(q)
95+
s = re(q)
96+
w = jnp.log(mag)
97+
denom = jnp.maximum(
98+
linalg.norm(v, axis=-1, keepdims=True), eps * jnp.ones_like(v))
99+
xyz = v / denom * safe_acos(s / eps)
100+
return jnp.concatenate((xyz, w), axis=-1)
101+
102+
103+
def exp(q, eps=1e-8):
104+
"""Computes the quaternion exponential.
105+
106+
References:
107+
https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
108+
109+
Args:
110+
q: the quaternion in (x,y,z,w) format or (x,y,z) if is_pure is True.
111+
eps: an epsilon value for numerical stability.
112+
113+
Returns:
114+
The exponential of q.
115+
"""
116+
is_pure = q.shape[-1] == 3
117+
if is_pure:
118+
s = jnp.zeros_like(q[..., -1:])
119+
v = q
120+
else:
121+
v = im(q)
122+
s = re(q)
123+
124+
norm_v = linalg.norm(v, axis=-1, keepdims=True)
125+
exp_s = jnp.exp(s)
126+
w = jnp.cos(norm_v)
127+
xyz = jnp.sin(norm_v) * v / jnp.maximum(norm_v, eps * jnp.ones_like(norm_v))
128+
return exp_s * jnp.concatenate((xyz, w), axis=-1)
129+
130+
131+
def to_rotation_matrix(q):
132+
"""Constructs a rotation matrix from a quaternion.
133+
134+
Args:
135+
q: a (*,4) array containing quaternions.
136+
137+
Returns:
138+
A (*,3,3) array containing rotation matrices.
139+
"""
140+
x, y, z, w = jnp.split(q, 4, axis=-1)
141+
s = 1.0 / jnp.sum(q ** 2, axis=-1)
142+
return jnp.stack([
143+
jnp.stack([1 - 2 * s * (y ** 2 + z ** 2),
144+
2 * s * (x * y - z * w),
145+
2 * s * (x * z + y * w)], axis=0),
146+
jnp.stack([2 * s * (x * y + z * w),
147+
1 - s * 2 * (x ** 2 + z ** 2),
148+
2 * s * (y * z - x * w)], axis=0),
149+
jnp.stack([2 * s * (x * z - y * w),
150+
2 * s * (y * z + x * w),
151+
1 - 2 * s * (x ** 2 + y ** 2)], axis=0),
152+
], axis=0)
153+
154+
155+
def from_rotation_matrix(m, eps=1e-9):
156+
"""Construct quaternion from a rotation matrix.
157+
158+
Args:
159+
m: a (*,3,3) array containing rotation matrices.
160+
eps: a small number for numerical stability.
161+
162+
Returns:
163+
A (*,4) array containing quaternions.
164+
"""
165+
trace = jnp.trace(m)
166+
m00 = m[..., 0, 0]
167+
m01 = m[..., 0, 1]
168+
m02 = m[..., 0, 2]
169+
m10 = m[..., 1, 0]
170+
m11 = m[..., 1, 1]
171+
m12 = m[..., 1, 2]
172+
m20 = m[..., 2, 0]
173+
m21 = m[..., 2, 1]
174+
m22 = m[..., 2, 2]
175+
176+
def tr_positive():
177+
sq = jnp.sqrt(trace + 1.0) * 2. # sq = 4 * w.
178+
w = 0.25 * sq
179+
x = jnp.divide(m21 - m12, sq)
180+
y = jnp.divide(m02 - m20, sq)
181+
z = jnp.divide(m10 - m01, sq)
182+
return jnp.stack((x, y, z, w), axis=-1)
183+
184+
def cond_1():
185+
sq = jnp.sqrt(1.0 + m00 - m11 - m22 + eps) * 2. # sq = 4 * x.
186+
w = jnp.divide(m21 - m12, sq)
187+
x = 0.25 * sq
188+
y = jnp.divide(m01 + m10, sq)
189+
z = jnp.divide(m02 + m20, sq)
190+
return jnp.stack((x, y, z, w), axis=-1)
191+
192+
def cond_2():
193+
sq = jnp.sqrt(1.0 + m11 - m00 - m22 + eps) * 2. # sq = 4 * y.
194+
w = jnp.divide(m02 - m20, sq)
195+
x = jnp.divide(m01 + m10, sq)
196+
y = 0.25 * sq
197+
z = jnp.divide(m12 + m21, sq)
198+
return jnp.stack((x, y, z, w), axis=-1)
199+
200+
def cond_3():
201+
sq = jnp.sqrt(1.0 + m22 - m00 - m11 + eps) * 2. # sq = 4 * z.
202+
w = jnp.divide(m10 - m01, sq)
203+
x = jnp.divide(m02 + m20, sq)
204+
y = jnp.divide(m12 + m21, sq)
205+
z = 0.25 * sq
206+
return jnp.stack((x, y, z, w), axis=-1)
207+
208+
def cond_idx(cond):
209+
cond = jnp.expand_dims(cond, -1)
210+
cond = jnp.tile(cond, [1] * (len(m.shape) - 2) + [4])
211+
return cond
212+
213+
where_2 = jnp.where(cond_idx(m11 > m22), cond_2(), cond_3())
214+
where_1 = jnp.where(cond_idx((m00 > m11) & (m00 > m22)), cond_1(), where_2)
215+
return jnp.where(cond_idx(trace > 0), tr_positive(), where_1)
216+
217+
218+
def from_axis_angle(axis, theta):
219+
"""Constructs a quaternion for the given axis/angle rotation."""
220+
qx = axis[0] * jnp.sin(theta / 2)
221+
qy = axis[1] * jnp.sin(theta / 2)
222+
qz = axis[2] * jnp.sin(theta / 2)
223+
qw = jnp.cos(theta / 2)
224+
225+
return jnp.squeeze(jnp.array([qx, qy, qz, qw]))

jax3d/math/quaternion_test.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2022 The jax3d Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for quaternions."""
16+
17+
import functools
18+
import math
19+
import unittest
20+
21+
from jax import random
22+
import jax.numpy as jnp
23+
from jax3d.math import quaternion
24+
import pytest
25+
26+
27+
TEST_BATCH_SIZE = 128
28+
29+
30+
class QuaternionTest(unittest.TestCase):
31+
32+
def setUp(self):
33+
super().setUp()
34+
self._seed = 42
35+
self._key = random.PRNGKey(self._seed)
36+
37+
def test_identity(self):
38+
identity = quaternion.identity()
39+
self.assertLen(identity, 4)
40+
self.assertEqual(identity.tolist(), [0.0, 0.0, 0.0, 1.0])
41+
42+
@pytest.mark.parametrize(('single', (4,)), ('batched', (TEST_BATCH_SIZE, 4)))
43+
@pytest.mark.parametrize('shape', )
44+
def test_real_imaginary_part(self, shape):
45+
if len(shape) > 1:
46+
num_quaternions = shape[0]
47+
else:
48+
num_quaternions = 1
49+
random_quat = random.uniform(self._key, shape=shape)
50+
imaginary = quaternion.im(random_quat)
51+
real = quaternion.re(random_quat)
52+
53+
# The first three components are imaginary and the fourth is real.
54+
self.assertEqual(jnp.prod(jnp.array(imaginary.shape)), num_quaternions * 3)
55+
self.assertEqual(jnp.prod(jnp.array(real.shape)), num_quaternions)
56+
self.assertEqual(random_quat[..., :3].tolist(), imaginary[..., :].tolist())
57+
self.assertEqual(random_quat[..., 3:].tolist(), real[..., :].tolist())
58+
59+
@pytest.mark.parametrize('batch', [None, TEST_BATCH_SIZE])
60+
@pytest.mark.parametrize('func', [random.uniform, jnp.ones, jnp.zeros])
61+
@pytest.mark.parametrize('sign', [-1, 1])
62+
def test_safe_acos(self, batch, func, sign):
63+
# We need a seed to generate random numbers.
64+
if func == random.uniform:
65+
func = functools.partial(func, key=self._key)
66+
67+
if batch:
68+
shape = (batch, 4)
69+
else:
70+
shape = (4,)
71+
t = sign * func(shape=shape)
72+
73+
output = quaternion.safe_acos(t)
74+
75+
# All elements must be within the range of the arc-cosine function.
76+
self.assertTrue(jnp.all(output > 0))
77+
self.assertTrue(jnp.all(output < math.pi))
78+
79+
@pytest.mark.parametrize(('single', None), ('batched', TEST_BATCH_SIZE))
80+
def test_conjugate(self, batch):
81+
if batch:
82+
shape = (batch, 4)
83+
else:
84+
shape = (4,)
85+
quat = random.uniform(self._key, shape=shape)
86+
conjugate = quaternion.conjugate(quat)
87+
self.assertTrue(jnp.all(-1 * quat[..., :3] == conjugate[..., :3]))
88+
self.assertTrue(jnp.all(quat[..., 3:] == conjugate[..., 3:]))
89+
90+
@pytest.mark.parametrize(('single', None), ('batched', TEST_BATCH_SIZE))
91+
def test_normalize(self, batch):
92+
eps = 1e-6
93+
if batch:
94+
shape = (batch, 4)
95+
else:
96+
shape = (4,)
97+
q = random.uniform(self._key, shape=shape)
98+
self.assertTrue(jnp.all(jnp.abs(quaternion.norm(q) - 1) > eps))
99+
q_norm = quaternion.normalize(q)
100+
self.assertTrue(jnp.all(jnp.abs(quaternion.norm(q_norm) - 1) < eps))

0 commit comments

Comments
 (0)