Skip to content

Commit

Permalink
rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 5, 2024
1 parent 7af0eec commit bb22697
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/beignet/_rotation_matrix_to_rotation_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def rotation_matrix_to_rotation_quaternion(
c = output[j, 2]
d = output[j, 3]

if d == 0 and (a == 0 and (b == 0 and c < 0 or b < 0) or a < 0) or d < 0:
if d == 0 and (a == 0 & (b == 0 & c < 0 | b < 0) | a < 0) | d < 0:
output[j] = -output[j]

return output
35 changes: 14 additions & 21 deletions src/beignet/_rotation_quaternion_magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,30 @@ def rotation_quaternion_magnitude(input: Tensor, canonical=False) -> Tensor:
Parameters
----------
input : Tensor, shape (..., 4)
input : Tensor, shape=(..., 4)
Rotation quaternions.
Returns
-------
rotation_quaternion_magnitudes: Tensor, shape (...)
output : Tensor, shape=(...)
Angles in radians. Magnitudes will be in the range :math:`[0, \pi]`.
"""
angles = torch.empty(
output = torch.empty(
input.shape[0],
dtype=input.dtype,
layout=input.layout,
device=input.device,
requires_grad=input.requires_grad,
)

for index in range(input.shape[0]):
angles[index] = torch.multiply(
torch.atan2(
torch.sqrt(
torch.add(
torch.add(
torch.square(input[index, 0]),
torch.square(input[index, 1]),
),
torch.square(input[index, 2]),
),
),
torch.abs(input[index, 3]),
),
2.0,
)

return angles
for j in range(input.shape[0]):
a = input[j, 0]
b = input[j, 1]
c = input[j, 2]
d = input[j, 3]

x = torch.atan2(torch.sqrt(a**2 + b**2 + c**2), torch.abs(d))

output[j] = x * 2.0

return output
2 changes: 1 addition & 1 deletion src/beignet/_rotation_vector_to_rotation_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def rotation_vector_to_rotation_quaternion(
c = output[j, 2]
d = output[j, 3]

if d == 0 and (a == 0 and (b == 0 and c < 0 or b < 0) or a < 0) or d < 0:
if d == 0 and (a == 0 & (b == 0 & c < 0 | b < 0) | a < 0) | d < 0:
output[j] = -output[j]

return output

0 comments on commit bb22697

Please sign in to comment.