1
1
import torch
2
- from pytorch_kinematics .transforms .rotation_conversions import axis_angle_to_matrix
2
+ from pytorch_kinematics .transforms .rotation_conversions import axis_and_angle_to_matrix_33
3
3
4
4
5
- def sample_perturbations (T , num_perturbations , radian_sigma , translation_sigma ):
5
+ def sample_perturbations (T , num_perturbations , radian_sigma , translation_sigma , axis_of_rotation = None ,
6
+ translation_perpendicular_to_axis_of_rotation = True ):
6
7
"""
7
8
Sample perturbations around the given transform. The translation and rotation are sampled independently from
8
9
0 mean gaussians. The angular perturbations' directions are uniformly sampled from the unit sphere while its
@@ -11,18 +12,33 @@ def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma):
11
12
:param num_perturbations: number of perturbations to sample
12
13
:param radian_sigma: standard deviation of the gaussian angular perturbation in radians
13
14
:param translation_sigma: standard deviation of the gaussian translation perturbation in meters / T units
15
+ :param axis_of_rotation: if not None, the axis of rotation to sample the perturbations around
16
+ :param translation_perpendicular_to_axis_of_rotation: if True and the axis_of_rotation is not None, the translation
17
+ perturbations will be perpendicular to the axis of rotation
14
18
:return: perturbed transforms; may not include the original transform
15
19
"""
16
20
dtype = T .dtype
17
21
device = T .device
18
22
perturbed = torch .eye (4 , dtype = dtype , device = device ).repeat (num_perturbations , 1 , 1 )
19
23
20
- delta_R = torch .randn ((num_perturbations , 3 ), dtype = dtype , device = device ) * radian_sigma
21
- delta_R = axis_angle_to_matrix (delta_R )
24
+ delta_t = torch .randn ((num_perturbations , 3 ), dtype = dtype , device = device ) * translation_sigma
25
+ # consider sampling from the Bingham distribution
26
+ theta = torch .randn (num_perturbations , dtype = dtype , device = device ) * radian_sigma
27
+ if axis_of_rotation is not None :
28
+ axis_angle = axis_of_rotation
29
+ # sample translation perturbation perpendicular to the axis of rotation
30
+ # remove the component of delta_t along the axis_of_rotation
31
+ if translation_perpendicular_to_axis_of_rotation :
32
+ delta_t -= (delta_t * axis_of_rotation ).sum (dim = 1 , keepdim = True ) * axis_of_rotation
33
+ else :
34
+ axis_angle = torch .randn ((num_perturbations , 3 ), dtype = dtype , device = device )
35
+ # normalize to unit length
36
+ axis_angle = axis_angle / axis_angle .norm (dim = 1 , keepdim = True )
37
+
38
+ delta_R = axis_and_angle_to_matrix_33 (axis_angle , theta )
22
39
perturbed [:, :3 , :3 ] = delta_R @ T [..., :3 , :3 ]
23
40
perturbed [:, :3 , 3 ] = T [..., :3 , 3 ]
24
41
25
- delta_t = torch .randn ((num_perturbations , 3 ), dtype = dtype , device = device ) * translation_sigma
26
42
perturbed [:, :3 , 3 ] += delta_t
27
43
28
44
return perturbed
0 commit comments