Skip to content

Commit 0eed466

Browse files
committed
Ensure IK solutions are deterministic given random seed
1 parent baa6ced commit 0eed466

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

src/pytorch_kinematics/ik.py

+10
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,12 @@ def __init__(self, serial_chain: SerialChain,
205205
# could give a batch of initial configs
206206
self.num_retries = self.initial_config.shape[-2]
207207

208+
def clear(self):
209+
self.err = None
210+
self.err_all = None
211+
self.err_min = None
212+
self.no_improve_counter = None
213+
208214
def sample_configs(self, num_configs: int) -> torch.Tensor:
209215
if self.config_sampling_method == "uniform":
210216
# bound by joint_limits
@@ -274,6 +280,8 @@ def compute_dq(self, J, dx):
274280
return dq
275281

276282
def solve(self, target_poses: Transform3d) -> IKSolution:
283+
self.clear()
284+
277285
target = target_poses.get_matrix()
278286

279287
M = target.shape[0]
@@ -291,6 +299,8 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
291299
elif q.numel() == self.dof * self.num_retries:
292300
# repeat and manually flatten it
293301
q = self.initial_config.repeat(M, 1)
302+
elif q.numel() == self.dof:
303+
q = q.unsqueeze(0).repeat(M * self.num_retries, 1)
294304
else:
295305
raise ValueError(
296306
f"initial_config must have shape ({M}, {self.num_retries}, {self.dof}) or ({self.num_retries}, {self.dof})")

tests/test_inverse_kinematics.py

+53
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ def test_jacobian_follower():
7272
print("IK took %d iterations" % sol.iterations)
7373
print("IK solved %d / %d goals" % (sol.converged_any.sum(), M))
7474

75+
# check that solving again produces the same solutions
76+
sol_again = ik.solve(goal_in_rob_frame_tf)
77+
assert torch.allclose(sol.solutions, sol_again.solutions)
78+
assert torch.allclose(sol.converged, sol_again.converged)
79+
7580
# visualize everything
7681
if visualize:
7782
p.connect(p.GUI)
@@ -132,5 +137,53 @@ def test_jacobian_follower():
132137
p.stepSimulation()
133138

134139

140+
def test_ik_in_place_no_err():
141+
pytorch_seed.seed(2)
142+
device = "cuda" if torch.cuda.is_available() else "cpu"
143+
# device = "cpu"
144+
urdf = "kuka_iiwa/model.urdf"
145+
search_path = pybullet_data.getDataPath()
146+
full_urdf = os.path.join(search_path, urdf)
147+
chain = pk.build_serial_chain_from_urdf(open(full_urdf).read(), "lbr_iiwa_link_7")
148+
chain = chain.to(device=device)
149+
150+
# robot frame
151+
pos = torch.tensor([0.0, 0.0, 0.0], device=device)
152+
rot = torch.tensor([0.0, 0.0, 0.0], device=device)
153+
rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device)
154+
155+
# goal equal to current configuration
156+
lim = torch.tensor(chain.get_joint_limits(), device=device)
157+
cur_q = torch.rand(7, device=device) * (lim[1] - lim[0]) + lim[0]
158+
M = 1
159+
goal_q = cur_q.unsqueeze(0).repeat(M, 1)
160+
161+
# get ee pose (in robot frame)
162+
goal_in_rob_frame_tf = chain.forward_kinematics(goal_q)
163+
164+
# transform to world frame for visualization
165+
goal_tf = rob_tf.compose(goal_in_rob_frame_tf)
166+
goal = goal_tf.get_matrix()
167+
goal_pos = goal[..., :3, 3]
168+
goal_rot = pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ")
169+
170+
ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=10,
171+
joint_limits=lim.T,
172+
early_stopping_any_converged=True,
173+
early_stopping_no_improvement="all",
174+
retry_configs=cur_q.reshape(1, -1),
175+
# line_search=pk.BacktrackingLineSearch(max_lr=0.2),
176+
debug=False,
177+
lr=0.2)
178+
179+
# do IK
180+
sol = ik.solve(goal_in_rob_frame_tf)
181+
assert sol.converged.sum() == M
182+
assert torch.allclose(sol.solutions[0][0], cur_q)
183+
assert torch.allclose(sol.err_pos[0], torch.zeros(1, device=device), atol=1e-6)
184+
assert torch.allclose(sol.err_rot[0], torch.zeros(1, device=device), atol=1e-6)
185+
186+
135187
if __name__ == "__main__":
136188
test_jacobian_follower()
189+
test_ik_in_place_no_err()

0 commit comments

Comments
 (0)