Skip to content

Commit 535d1e1

Browse files
committed
Add quaternion angular distance and slerp
1 parent 19b722b commit 535d1e1

File tree

3 files changed

+108
-1
lines changed

3 files changed

+108
-1
lines changed

src/pytorch_kinematics/transforms/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
quaternion_raw_multiply,
1515
quaternion_to_matrix,
1616
quaternion_from_euler,
17+
quaternion_to_axis_angle,
1718
random_quaternions,
1819
random_rotation,
1920
random_rotations,
@@ -35,5 +36,11 @@
3536
so3_rotation_angle,
3637
)
3738
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
39+
from pytorch_kinematics.transforms.math import (
40+
quaternion_angular_distance,
41+
acos_linear_extrapolation,
42+
quaternion_close,
43+
quaternion_slerp,
44+
)
3845

3946
__all__ = [k for k in globals().keys() if not k.startswith("_")]

src/pytorch_kinematics/transforms/math.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,73 @@
1010
import torch
1111

1212

13+
def quaternion_angular_distance(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
14+
"""
15+
Computes the angular distance between two quaternions.
16+
Args:
17+
q1: First quaternion (assume normalized).
18+
q2: Second quaternion (assume normalized).
19+
Returns:
20+
Angular distance between the two quaternions.
21+
"""
22+
23+
# Compute the cosine of the angle between the two quaternions
24+
cos_theta = torch.sum(q1 * q2, dim=-1)
25+
# we use atan2 instead of acos for better numerical stability
26+
cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
27+
abs_dot = torch.abs(cos_theta)
28+
# identity sin^2(theta) = 1 - cos^2(theta)
29+
sin_half_theta = torch.sqrt(1.0 - torch.square(abs_dot))
30+
theta = 2.0 * torch.atan2(sin_half_theta, abs_dot)
31+
32+
# theta for the ones that are close gets 0 and we don't care about them
33+
close = quaternion_close(q1, q2)
34+
theta[close] = 0
35+
return theta
36+
37+
1338
def quaternion_close(q1: torch.Tensor, q2: torch.Tensor, eps: float = 1e-4):
1439
"""
1540
Returns true if two quaternions are close to each other. Assumes the quaternions are normalized.
1641
Based on: https://math.stackexchange.com/a/90098/516340
1742
1843
"""
19-
dist = 1 - torch.square(torch.sum(q1*q2, dim=-1))
44+
dist = 1 - torch.square(torch.sum(q1 * q2, dim=-1))
2045
return torch.all(dist < eps)
2146

2247

48+
def quaternion_slerp(q1: torch.Tensor, q2: torch.Tensor, t: Union[float, torch.tensor]) -> torch.Tensor:
49+
"""
50+
Spherical linear interpolation between two quaternions.
51+
Args:
52+
q1: First quaternion (assume normalized).
53+
q2: Second quaternion (assume normalized).
54+
t: Interpolation parameter.
55+
Returns:
56+
Interpolated quaternion.
57+
"""
58+
# Compute the cosine of the angle between the two quaternions
59+
cos_theta = torch.sum(q1 * q2, dim=-1)
60+
61+
# reverse the direction of q2 if q1 and q2 are not in the same hemisphere
62+
to_invert = cos_theta < 0
63+
q2[to_invert] = -q2[to_invert]
64+
cos_theta[to_invert] = -cos_theta[to_invert]
65+
66+
# If the quaternions are close, perform a linear interpolation
67+
if torch.all(cos_theta > 1.0 - 1e-6):
68+
return q1 + t * (q2 - q1)
69+
70+
# Ensure the angle is between 0 and pi
71+
theta = torch.acos(cos_theta)
72+
sin_theta = torch.sin(theta)
73+
74+
# Perform the interpolation
75+
w1 = torch.sin((1.0 - t) * theta) / sin_theta
76+
w2 = torch.sin(t * theta) / sin_theta
77+
return w1[:, None] * q1 + w2[:, None] * q2
78+
79+
2380
def acos_linear_extrapolation(
2481
x: torch.Tensor,
2582
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,

tests/test_transform.py

+43
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22

33
import pytorch_kinematics.transforms as tf
4+
import pytorch_kinematics as pk
45

56

67
def test_transform():
@@ -106,11 +107,39 @@ def test_euler():
106107

107108

108109
def test_quaternions():
110+
import pytorch_seed
111+
pytorch_seed.seed(0)
112+
109113
n = 10
110114
q = tf.random_quaternions(n)
111115
q_tf = tf.wxyz_to_xyzw(q)
112116
assert torch.allclose(q, tf.xyzw_to_wxyz(q_tf))
113117

118+
qq = pk.standardize_quaternion(q)
119+
assert torch.allclose(qq.norm(dim=-1), torch.ones(n))
120+
121+
# random quaternions should already be unit quaternions
122+
assert torch.allclose(q, qq)
123+
124+
# distances to themselves should be zero
125+
d = pk.quaternion_angular_distance(q, q)
126+
assert torch.allclose(d, torch.zeros(n))
127+
# q = -q
128+
d = pk.quaternion_angular_distance(q, -q)
129+
assert torch.allclose(d, torch.zeros(n))
130+
131+
axis = torch.tensor([0.0, 0.5, 0.5])
132+
axis = axis / axis.norm()
133+
magnitudes = torch.tensor([2.32, 1.56, -0.52, 0.1])
134+
n = len(magnitudes)
135+
aa_1 = axis.repeat(n, 1)
136+
aa_2 = axis * magnitudes[:, None]
137+
q1 = pk.axis_angle_to_quaternion(aa_1)
138+
q2 = pk.axis_angle_to_quaternion(aa_2)
139+
d = pk.quaternion_angular_distance(q1, q2)
140+
expected_d = (magnitudes - 1).abs()
141+
assert torch.allclose(d, expected_d, atol=1e-4)
142+
114143

115144
def test_compose():
116145
import torch
@@ -124,6 +153,19 @@ def test_compose():
124153
print(a2c.transform_points(torch.zeros([1, 3])))
125154

126155

156+
def test_quaternion_slerp():
157+
q = tf.random_quaternions(20)
158+
q1 = q[:10]
159+
q2 = q[10:]
160+
t = torch.rand(10)
161+
q_interp = pk.quaternion_slerp(q1, q2, t)
162+
# check the distance between them is consistent
163+
full_dist = pk.quaternion_angular_distance(q1, q2)
164+
interp_dist = pk.quaternion_angular_distance(q1, q_interp)
165+
# print(f"full_dist: {full_dist} interp_dist: {interp_dist} t: {t}")
166+
assert torch.allclose(full_dist * t, interp_dist, atol=1e-5)
167+
168+
127169
if __name__ == "__main__":
128170
test_compose()
129171
test_transform()
@@ -132,3 +174,4 @@ def test_compose():
132174
test_rotate()
133175
test_euler()
134176
test_quaternions()
177+
test_quaternion_slerp()

0 commit comments

Comments
 (0)