Skip to content

Commit 11de9ad

Browse files
committed
Sample 2D rotations when given axis of rotation
1 parent dbb641a commit 11de9ad

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
2-
from pytorch_kinematics.transforms.rotation_conversions import axis_angle_to_matrix
2+
from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33
33

44

5-
def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma):
5+
def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma, axis_of_rotation=None,
6+
translation_perpendicular_to_axis_of_rotation=True):
67
"""
78
Sample perturbations around the given transform. The translation and rotation are sampled independently from
89
0 mean gaussians. The angular perturbations' directions are uniformly sampled from the unit sphere while its
@@ -11,18 +12,33 @@ def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma):
1112
:param num_perturbations: number of perturbations to sample
1213
:param radian_sigma: standard deviation of the gaussian angular perturbation in radians
1314
:param translation_sigma: standard deviation of the gaussian translation perturbation in meters / T units
15+
:param axis_of_rotation: if not None, the axis of rotation to sample the perturbations around
16+
:param translation_perpendicular_to_axis_of_rotation: if True and the axis_of_rotation is not None, the translation
17+
perturbations will be perpendicular to the axis of rotation
1418
:return: perturbed transforms; may not include the original transform
1519
"""
1620
dtype = T.dtype
1721
device = T.device
1822
perturbed = torch.eye(4, dtype=dtype, device=device).repeat(num_perturbations, 1, 1)
1923

20-
delta_R = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * radian_sigma
21-
delta_R = axis_angle_to_matrix(delta_R)
24+
delta_t = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * translation_sigma
25+
# consider sampling from the Bingham distribution
26+
theta = torch.randn(num_perturbations, dtype=dtype, device=device) * radian_sigma
27+
if axis_of_rotation is not None:
28+
axis_angle = axis_of_rotation
29+
# sample translation perturbation perpendicular to the axis of rotation
30+
# remove the component of delta_t along the axis_of_rotation
31+
if translation_perpendicular_to_axis_of_rotation:
32+
delta_t -= (delta_t * axis_of_rotation).sum(dim=1, keepdim=True) * axis_of_rotation
33+
else:
34+
axis_angle = torch.randn((num_perturbations, 3), dtype=dtype, device=device)
35+
# normalize to unit length
36+
axis_angle = axis_angle / axis_angle.norm(dim=1, keepdim=True)
37+
38+
delta_R = axis_and_angle_to_matrix_33(axis_angle, theta)
2239
perturbed[:, :3, :3] = delta_R @ T[..., :3, :3]
2340
perturbed[:, :3, 3] = T[..., :3, 3]
2441

25-
delta_t = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * translation_sigma
2642
perturbed[:, :3, 3] += delta_t
2743

2844
return perturbed

0 commit comments

Comments
 (0)