1+ # Copyright 2022 The jax3d Authors.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ """Quaternion math.
16+
17+ This module assumes the xyzw quaternion format where xyz is the imaginary part
18+ and w is the real part.
19+
20+ Functions in this module support both batched and unbatched quaternions.
21+ """
22+ from jax import numpy as jnp
23+ from jax .numpy import linalg
24+
25+
26+ def safe_acos (t , eps = 1e-7 ):
27+ """A safe version of arccos which avoids evaluating at -1 or 1."""
28+ return jnp .arccos (jnp .clip (t , - 1.0 + eps , 1.0 - eps ))
29+
30+
31+ def im (q ):
32+ """Fetch the imaginary part of the quaternion."""
33+ return q [..., :3 ]
34+
35+
36+ def re (q ):
37+ """Fetch the real part of the quaternion."""
38+ return q [..., 3 :]
39+
40+
41+ def identity ():
42+ return jnp .array ([0.0 , 0.0 , 0.0 , 1.0 ])
43+
44+
45+ def conjugate (q ):
46+ """Compute the conjugate of a quaternion."""
47+ return jnp .concatenate ([- im (q ), re (q )], axis = - 1 )
48+
49+
50+ def inverse (q ):
51+ """Compute the inverse of a quaternion."""
52+ return normalize (conjugate (q ))
53+
54+
55+ def normalize (q ):
56+ """Normalize a quaternion."""
57+ return q / norm (q )
58+
59+
60+ def norm (q ):
61+ return linalg .norm (q , axis = - 1 , keepdims = True )
62+
63+
64+ def multiply (q1 , q2 ):
65+ """Multiply two quaternions."""
66+ c = (re (q1 ) * im (q2 )
67+ + re (q2 ) * im (q1 )
68+ + jnp .cross (im (q1 ), im (q2 )))
69+ w = re (q1 ) * re (q2 ) - jnp .dot (im (q1 ), im (q2 ))
70+ return jnp .concatenate ([c , w ], axis = - 1 )
71+
72+
73+ def rotate (q , v ):
74+ """Rotate a vector using a quaternion."""
75+ # Create the quaternion representation of the vector.
76+ q_v = jnp .concatenate ([v , jnp .zeros_like (v [..., :1 ])], axis = - 1 )
77+ return im (multiply (multiply (q , q_v ), conjugate (q )))
78+
79+
80+ def log (q , eps = 1e-8 ):
81+ """Computes the quaternion logarithm.
82+
83+ References:
84+ https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
85+
86+ Args:
87+ q: the quaternion in (x,y,z,w) format.
88+ eps: an epsilon value for numerical stability.
89+
90+ Returns:
91+ The logarithm of q.
92+ """
93+ mag = linalg .norm (q , axis = - 1 , keepdims = True )
94+ v = im (q )
95+ s = re (q )
96+ w = jnp .log (mag )
97+ denom = jnp .maximum (
98+ linalg .norm (v , axis = - 1 , keepdims = True ), eps * jnp .ones_like (v ))
99+ xyz = v / denom * safe_acos (s / eps )
100+ return jnp .concatenate ((xyz , w ), axis = - 1 )
101+
102+
103+ def exp (q , eps = 1e-8 ):
104+ """Computes the quaternion exponential.
105+
106+ References:
107+ https://en.wikipedia.org/wiki/Quaternion#Exponential,_logarithm,_and_power_functions
108+
109+ Args:
110+ q: the quaternion in (x,y,z,w) format or (x,y,z) if is_pure is True.
111+ eps: an epsilon value for numerical stability.
112+
113+ Returns:
114+ The exponential of q.
115+ """
116+ is_pure = q .shape [- 1 ] == 3
117+ if is_pure :
118+ s = jnp .zeros_like (q [..., - 1 :])
119+ v = q
120+ else :
121+ v = im (q )
122+ s = re (q )
123+
124+ norm_v = linalg .norm (v , axis = - 1 , keepdims = True )
125+ exp_s = jnp .exp (s )
126+ w = jnp .cos (norm_v )
127+ xyz = jnp .sin (norm_v ) * v / jnp .maximum (norm_v , eps * jnp .ones_like (norm_v ))
128+ return exp_s * jnp .concatenate ((xyz , w ), axis = - 1 )
129+
130+
131+ def to_rotation_matrix (q ):
132+ """Constructs a rotation matrix from a quaternion.
133+
134+ Args:
135+ q: a (*,4) array containing quaternions.
136+
137+ Returns:
138+ A (*,3,3) array containing rotation matrices.
139+ """
140+ x , y , z , w = jnp .split (q , 4 , axis = - 1 )
141+ s = 1.0 / jnp .sum (q ** 2 , axis = - 1 )
142+ return jnp .stack ([
143+ jnp .stack ([1 - 2 * s * (y ** 2 + z ** 2 ),
144+ 2 * s * (x * y - z * w ),
145+ 2 * s * (x * z + y * w )], axis = 0 ),
146+ jnp .stack ([2 * s * (x * y + z * w ),
147+ 1 - s * 2 * (x ** 2 + z ** 2 ),
148+ 2 * s * (y * z - x * w )], axis = 0 ),
149+ jnp .stack ([2 * s * (x * z - y * w ),
150+ 2 * s * (y * z + x * w ),
151+ 1 - 2 * s * (x ** 2 + y ** 2 )], axis = 0 ),
152+ ], axis = 0 )
153+
154+
155+ def from_rotation_matrix (m , eps = 1e-9 ):
156+ """Construct quaternion from a rotation matrix.
157+
158+ Args:
159+ m: a (*,3,3) array containing rotation matrices.
160+ eps: a small number for numerical stability.
161+
162+ Returns:
163+ A (*,4) array containing quaternions.
164+ """
165+ trace = jnp .trace (m )
166+ m00 = m [..., 0 , 0 ]
167+ m01 = m [..., 0 , 1 ]
168+ m02 = m [..., 0 , 2 ]
169+ m10 = m [..., 1 , 0 ]
170+ m11 = m [..., 1 , 1 ]
171+ m12 = m [..., 1 , 2 ]
172+ m20 = m [..., 2 , 0 ]
173+ m21 = m [..., 2 , 1 ]
174+ m22 = m [..., 2 , 2 ]
175+
176+ def tr_positive ():
177+ sq = jnp .sqrt (trace + 1.0 ) * 2. # sq = 4 * w.
178+ w = 0.25 * sq
179+ x = jnp .divide (m21 - m12 , sq )
180+ y = jnp .divide (m02 - m20 , sq )
181+ z = jnp .divide (m10 - m01 , sq )
182+ return jnp .stack ((x , y , z , w ), axis = - 1 )
183+
184+ def cond_1 ():
185+ sq = jnp .sqrt (1.0 + m00 - m11 - m22 + eps ) * 2. # sq = 4 * x.
186+ w = jnp .divide (m21 - m12 , sq )
187+ x = 0.25 * sq
188+ y = jnp .divide (m01 + m10 , sq )
189+ z = jnp .divide (m02 + m20 , sq )
190+ return jnp .stack ((x , y , z , w ), axis = - 1 )
191+
192+ def cond_2 ():
193+ sq = jnp .sqrt (1.0 + m11 - m00 - m22 + eps ) * 2. # sq = 4 * y.
194+ w = jnp .divide (m02 - m20 , sq )
195+ x = jnp .divide (m01 + m10 , sq )
196+ y = 0.25 * sq
197+ z = jnp .divide (m12 + m21 , sq )
198+ return jnp .stack ((x , y , z , w ), axis = - 1 )
199+
200+ def cond_3 ():
201+ sq = jnp .sqrt (1.0 + m22 - m00 - m11 + eps ) * 2. # sq = 4 * z.
202+ w = jnp .divide (m10 - m01 , sq )
203+ x = jnp .divide (m02 + m20 , sq )
204+ y = jnp .divide (m12 + m21 , sq )
205+ z = 0.25 * sq
206+ return jnp .stack ((x , y , z , w ), axis = - 1 )
207+
208+ def cond_idx (cond ):
209+ cond = jnp .expand_dims (cond , - 1 )
210+ cond = jnp .tile (cond , [1 ] * (len (m .shape ) - 2 ) + [4 ])
211+ return cond
212+
213+ where_2 = jnp .where (cond_idx (m11 > m22 ), cond_2 (), cond_3 ())
214+ where_1 = jnp .where (cond_idx ((m00 > m11 ) & (m00 > m22 )), cond_1 (), where_2 )
215+ return jnp .where (cond_idx (trace > 0 ), tr_positive (), where_1 )
216+
217+
218+ def from_axis_angle (axis , theta ):
219+ """Constructs a quaternion for the given axis/angle rotation."""
220+ qx = axis [0 ] * jnp .sin (theta / 2 )
221+ qy = axis [1 ] * jnp .sin (theta / 2 )
222+ qz = axis [2 ] * jnp .sin (theta / 2 )
223+ qw = jnp .cos (theta / 2 )
224+
225+ return jnp .squeeze (jnp .array ([qx , qy , qz , qw ]))
0 commit comments