From d1fa58cdf053aa100ad13a6eb2fc6c8f18c2fcb4 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Mon, 8 Apr 2024 12:49:25 -0400 Subject: [PATCH] rotation --- .../_compose_rotation_quaternion.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/beignet/ops/_geometry/_transformations/_rotations/_compose_rotation_quaternion.py b/src/beignet/ops/_geometry/_transformations/_rotations/_compose_rotation_quaternion.py index 444e299c57..49da39c89d 100644 --- a/src/beignet/ops/_geometry/_transformations/_rotations/_compose_rotation_quaternion.py +++ b/src/beignet/ops/_geometry/_transformations/_rotations/_compose_rotation_quaternion.py @@ -39,12 +39,27 @@ def compose_rotation_quaternion( ) for j in range(max(input.shape[0], other.shape[0])): - output[j, 0] = input[j, 3] * other[j, 0] + other[j, 3] * input[j, 0] + input[j, 1] * other[j, 2] - input[j, 2] * other[j, 1] # fmt: skip - output[j, 1] = input[j, 3] * other[j, 1] + other[j, 3] * input[j, 1] + input[j, 2] * other[j, 0] - input[j, 0] * other[j, 2] # fmt: skip - output[j, 2] = input[j, 3] * other[j, 2] + other[j, 3] * input[j, 2] + input[j, 0] * other[j, 1] - input[j, 1] * other[j, 0] # fmt: skip - output[j, 3] = input[j, 3] * other[j, 3] - input[j, 0] * other[j, 0] - input[j, 1] * other[j, 1] - input[j, 2] * other[j, 2] # fmt: skip + a = input[j, 0] + b = input[j, 1] + c = input[j, 2] + d = input[j, 3] - x = torch.sqrt(output[j, 0] ** 2.0 + output[j, 1] ** 2.0 + output[j, 2] ** 2.0 + output[j, 3] ** 2.0) # fmt: skip + p = other[j, 0] + q = other[j, 1] + r = other[j, 2] + s = other[j, 3] + + t = output[j, 0] + u = output[j, 1] + v = output[j, 2] + w = output[j, 3] + + output[j, 0] = d * p + s * a + b * r - c * q + output[j, 1] = d * q + s * b + c * p - a * r + output[j, 2] = d * r + s * c + a * q - b * p + output[j, 3] = d * s - a * p - b * q - c * r + + x = torch.sqrt(t**2.0 + u**2.0 + v**2.0 + w**2.0) if x == 0.0: output[j] = torch.nan @@ -52,7 +67,7 @@ def compose_rotation_quaternion( output[j] = output[j] / x if canonical: - if output[j, 3] == 0 and (output[j, 0] == 0 and (output[j, 1] == 0 and output[j, 2] < 0 or output[j, 1] < 0) or output[j, 0] < 0) or output[j, 3] < 0: # fmt: skip + if w == 0 and (t == 0 and (u == 0 and v < 0 or u < 0) or t < 0) or w < 0: output[j] = -output[j] return output