Skip to content

Commit

Permalink
rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 10, 2024
1 parent 366f002 commit d60b947
Showing 1 changed file with 29 additions and 107 deletions.
136 changes: 29 additions & 107 deletions src/beignet/ops/_geometry/_transformations/_quaternion_slerp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,128 +44,50 @@ def quaternion_slerp(
Rotation quaternions. Rotation quaternions are normalized to unit norm.
""" # noqa: E501
if time.shape[-1] != rotation.shape[-2]:
raise ValueError("`times` and `rotations` must match in size.")
raise ValueError

interpolated_rotation_quaternions = torch.empty(
output = torch.empty(
[*input.shape, 4],
dtype=input.dtype,
layout=input.layout,
device=input.device,
)

for index, t in enumerate(input):
# FIND KEYFRAMES IN INTERVAL, ASSUME `times` IS SORTED:
b = torch.min(torch.nonzero(torch.greater_equal(time, t)))

if b > 0:
a = b - 1
else:
a = b

# IF `times` MATCHES A KEYFRAME:
if time[b] == t or b == a:
interpolated_rotation_quaternions[index] = rotation[b]
output[index] = rotation[b]

continue

time_0, time_1 = time[a], time[b]

relative_time = torch.divide(
torch.subtract(
t,
time_0,
),
torch.subtract(
time_1,
time_0,
),
)

quaternion_0, quaternion_1 = rotation[a], rotation[b]

cos_theta = torch.dot(quaternion_0, quaternion_1)

# NOTE: IF `cos_theta` IS NEGATIVE, `rotations` HAS OPPOSITE HANDEDNESS
# SO `slerp` WON’T TAKE THE SHORTER PATH. INVERT `quaternion_1`.
# @0X00B1, FEBRUARY, 26, 2024
if cos_theta < 0.0:
quaternion_1 = -quaternion_1

cos_theta = torch.negative(cos_theta)

# USE LINEAR INTERPOLATION IF QUATERNIONS ARE CLOSE:
if cos_theta > 0.9995:
interpolated_rotation_quaternion = torch.add(
torch.multiply(
torch.subtract(
torch.tensor(1.0),
relative_time,
),
quaternion_0,
),
torch.multiply(
relative_time,
quaternion_1,
),
)
p, q = time[a], time[b]

r = (t - p) / (q - p)

t = rotation[a]
u = rotation[b]

v = torch.dot(t, u)

if v < 0.0:
u = -u
v = -v

if v > 0.9995:
z = (1.0 - r) * t + r * u
else:
interpolated_rotation_quaternion = torch.add(
torch.multiply(
quaternion_0,
torch.divide(
torch.sin(
torch.multiply(
torch.subtract(
torch.tensor(1.0),
relative_time,
),
torch.atan2(
torch.sqrt(
torch.subtract(
torch.tensor(1.0),
torch.square(cos_theta),
)
),
cos_theta,
),
),
),
torch.sqrt(
torch.subtract(
torch.tensor(1.0),
torch.square(cos_theta),
)
),
),
),
torch.multiply(
quaternion_1,
torch.divide(
torch.sin(
torch.multiply(
relative_time,
torch.atan2(
torch.sqrt(
torch.subtract(
torch.tensor(1.0),
torch.square(cos_theta),
)
),
cos_theta,
),
),
),
torch.sqrt(
torch.subtract(
torch.tensor(1.0),
torch.square(cos_theta),
),
),
),
),
)

interpolated_rotation_quaternions[index] = torch.divide(
interpolated_rotation_quaternion,
torch.norm(interpolated_rotation_quaternion),
)

return interpolated_rotation_quaternions
x = torch.sqrt(1.0 - v**2.0)

y = torch.atan2(x, v)

z = t * torch.sin((1.0 - r) * y) / x + u * torch.sin(r * y) / x

output[index] = z / torch.linalg.norm(z)

return output

0 comments on commit d60b947

Please sign in to comment.