Skip to content

Commit

Permalink
add test_R_to_quat
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlao committed Jul 9, 2024
1 parent ce986c9 commit fa9d459
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 50 deletions.
16 changes: 16 additions & 0 deletions camtools/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,22 @@ def create_zeros(
return torch.zeros(shape, dtype=dtype)


def create_empty(
shape: Tuple[int, ...],
dtype: Any,
backend: Literal["numpy", "torch"],
) -> Tensor:
"""
Call np.empty() or torch.empty() depending on the backend.
"""
if backend == "numpy":
return np.empty(shape, dtype=dtype)
elif backend == "torch":
if not is_torch_available():
raise ValueError("Torch is not available.")
return torch.empty(shape, dtype=dtype)


def get_tensor_backend(arr: Tensor) -> Literal["numpy", "torch"]:
"""
Get the backend of a tensor.
Expand Down
1 change: 1 addition & 0 deletions camtools/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
torch,
create_array,
create_ones,
create_empty,
get_tensor_backend,
)
from jaxtyping import Float
Expand Down
102 changes: 52 additions & 50 deletions test/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,58 @@ def test_from_homo():
ct.convert.from_homo(incorrect_in_val_b)


def test_R_to_quat():
theta = np.pi / 2 # 90 degree rotation

# Test for rotation around the x-axis
R_x = np.array(
[
[1, 0, 0],
[0, np.cos(theta), -np.sin(theta)],
[0, np.sin(theta), np.cos(theta)],
]
)
gt_quat_x = np.array([np.cos(theta / 2), np.sin(theta / 2), 0, 0])
quat_x = ct.convert.R_to_quat(R_x)
assert np.allclose(quat_x, gt_quat_x, atol=1e-5)

# Test for rotation around the y-axis
R_y = np.array(
[
[np.cos(theta), 0, np.sin(theta)],
[0, 1, 0],
[-np.sin(theta), 0, np.cos(theta)],
]
)
gt_quat_y = np.array([np.cos(theta / 2), 0, np.sin(theta / 2), 0])
quat_y = ct.convert.R_to_quat(R_y)
assert np.allclose(quat_y, gt_quat_y, atol=1e-5)

# Test for rotation around the z-axis
R_z = np.array(
[
[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0],
[0, 0, 1],
]
)
gt_quat_z = np.array([np.cos(theta / 2), 0, 0, np.sin(theta / 2)])
quat_z = ct.convert.R_to_quat(R_z)
assert np.allclose(quat_z, gt_quat_z, atol=1e-5)

# Test with a batch of complex rotations about different axes
R_batch = np.array([R_x, R_y, R_z])
gt_quat_batch = np.array(
[
[np.cos(theta / 2), np.sin(theta / 2), 0, 0],
[np.cos(theta / 2), 0, np.sin(theta / 2), 0],
[np.cos(theta / 2), 0, 0, np.sin(theta / 2)],
]
)
quat_batch = ct.convert.R_to_quat(R_batch)
assert np.allclose(quat_batch, gt_quat_batch, atol=1e-5)


def test_R_t_to_cameracenter():
T = np.array(
[
Expand Down Expand Up @@ -505,53 +557,3 @@ def gen_random_T():
rtol=1e-5,
atol=1e-5,
)


# def test_to_homo():
# # Regular case
# src = np.array(
# [
# [1, 2],
# [3, 4],
# ]
# )
# dst_gt = np.array(
# [
# [1, 2, 1],
# [3, 4, 1],
# ]
# )
# dst = ct.convert.to_homo(src)
# np.testing.assert_array_equal(dst, dst_gt)

# # Exception case
# with pytest.raises(ValueError) as _:
# src = np.array([1, 2, 3])
# ct.convert.to_homo(src)


# def test_from_homo():
# src = np.array(
# [
# [2, 4, 2],
# [6, 8, 1],
# ]
# )
# dst_gt = np.array(
# [
# [1, 2],
# [6, 8],
# ]
# )
# dst = ct.convert.from_homo(src)
# np.testing.assert_array_equal(dst, dst_gt)

# # Exception case for non-2D input
# with pytest.raises(ValueError) as _:
# src = np.array([1, 2, 3])
# ct.convert.from_homo(src)

# # Exception case for insufficient columns
# with pytest.raises(ValueError) as _:
# src = np.array([[1]])
# ct.convert.from_homo(src)

0 comments on commit fa9d459

Please sign in to comment.