Skip to content

Commit 3ca731a

Browse files
committed
Move serial chain testing to separate file
1 parent 2baf77c commit 3ca731a

File tree

3 files changed

+90
-84
lines changed

3 files changed

+90
-84
lines changed

tests/test_inverse_kinematics.py

+3-83
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import pybullet as p
1111
import pybullet_data
1212

13-
visualize = True
13+
visualize = False
1414

15-
TEST_DIR = os.path.dirname(__file__)
1615

1716
def _make_robot_translucent(robot_id, alpha=0.4):
1817
def make_transparent(link):
@@ -207,87 +206,8 @@ def test_ik_in_place_no_err():
207206
assert torch.allclose(sol.err_rot[0], torch.zeros(1, device=device), atol=1e-6)
208207

209208

210-
def test_extract_serial_chain_from_tree():
211-
pytorch_seed.seed(2)
212-
device = "cuda" if torch.cuda.is_available() else "cpu"
213-
# device = "cpu"
214-
urdf = "widowx/wx250s.urdf"
215-
full_urdf = os.path.join(TEST_DIR, urdf)
216-
chain = pk.build_chain_from_urdf(open(full_urdf, mode="rb").read())
217-
# full frames
218-
full_frame_expected = """
219-
base_link
220-
└── shoulder_link
221-
└── upper_arm_link
222-
└── upper_forearm_link
223-
└── lower_forearm_link
224-
└── wrist_link
225-
└── gripper_link
226-
└── ee_arm_link
227-
├── gripper_prop_link
228-
└── gripper_bar_link
229-
└── fingers_link
230-
├── left_finger_link
231-
├── right_finger_link
232-
└── ee_gripper_link
233-
"""
234-
full_frame = chain.print_link_tree()
235-
assert full_frame_expected.strip() == full_frame.strip()
236-
237-
chain = pk.SerialChain(chain, "ee_gripper_link", "base_link")
238-
serial_frame = chain.print_link_tree()
239-
chain = chain.to(device=device)
240-
241-
# full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6
242-
dof = len(chain.get_joints(exclude_fixed=True))
243-
assert dof == 6
244-
245-
# robot frame
246-
pos = torch.tensor([0.0, 0.0, 0.0], device=device)
247-
rot = torch.tensor([0.0, 0.0, 0.0], device=device)
248-
rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device)
249-
250-
# world frame goal
251-
M = 1000
252-
# generate random goal joint angles (so these are all achievable)
253-
# use the joint limits to generate random joint angles
254-
lim = torch.tensor(chain.get_joint_limits(), device=device)
255-
goal_q = torch.rand(M, 7, device=device) * (lim[1] - lim[0]) + lim[0]
256-
257-
# get ee pose (in robot frame)
258-
goal_in_rob_frame_tf = chain.forward_kinematics(goal_q)
259-
260-
# transform to world frame for visualization
261-
goal_tf = rob_tf.compose(goal_in_rob_frame_tf)
262-
goal = goal_tf.get_matrix()
263-
goal_pos = goal[..., :3, 3]
264-
goal_rot = pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ")
265-
266-
num_retries = 10
267-
ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=num_retries,
268-
joint_limits=lim.T,
269-
early_stopping_any_converged=True,
270-
early_stopping_no_improvement="all",
271-
# line_search=pk.BacktrackingLineSearch(max_lr=0.2),
272-
debug=False,
273-
lr=0.2)
274-
275-
# do IK
276-
timer_start = timer()
277-
sol = ik.solve(goal_in_rob_frame_tf)
278-
timer_end = timer()
279-
print("IK took %f seconds" % (timer_end - timer_start))
280-
print("IK converged number: %d / %d" % (sol.converged.sum(), sol.converged.numel()))
281-
print("IK took %d iterations" % sol.iterations)
282-
print("IK solved %d / %d goals" % (sol.converged_any.sum(), M))
283-
284-
# check that solving again produces the same solutions
285-
sol_again = ik.solve(goal_in_rob_frame_tf)
286-
assert torch.allclose(sol.solutions, sol_again.solutions)
287-
assert torch.allclose(sol.converged, sol_again.converged)
288209

289210

290211
if __name__ == "__main__":
291-
# test_jacobian_follower()
292-
# test_ik_in_place_no_err()
293-
test_extract_serial_chain_from_tree()
212+
test_jacobian_follower()
213+
test_ik_in_place_no_err()

tests/test_jacobian.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def get_pt(th):
182182
try:
183183
import functorch
184184
ft_start = timer()
185-
grad_func = functorch.vmap(functorch.jacrev(get_pt))
185+
grad_func = torch.vmap(functorch.jacrev(get_pt))
186186
j3 = grad_func(th).squeeze(1)
187187
ft_end = timer()
188188
assert torch.allclose(j1_, j3, atol=1e-6)

tests/test_serial_chain_creation.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
from timeit import default_timer as timer
3+
4+
import torch
5+
6+
import pytorch_kinematics as pk
7+
8+
TEST_DIR = os.path.dirname(__file__)
9+
10+
11+
def test_extract_serial_chain_from_tree():
12+
urdf = "widowx/wx250s.urdf"
13+
full_urdf = os.path.join(TEST_DIR, urdf)
14+
chain = pk.build_chain_from_urdf(open(full_urdf, mode="rb").read())
15+
# full frames
16+
full_frame_expected = """
17+
base_link
18+
└── shoulder_link
19+
└── upper_arm_link
20+
└── upper_forearm_link
21+
└── lower_forearm_link
22+
└── wrist_link
23+
└── gripper_link
24+
└── ee_arm_link
25+
├── gripper_prop_link
26+
└── gripper_bar_link
27+
└── fingers_link
28+
├── left_finger_link
29+
├── right_finger_link
30+
└── ee_gripper_link
31+
"""
32+
full_frame = chain.print_link_tree()
33+
assert full_frame_expected.strip() == full_frame.strip()
34+
35+
serial_chain = pk.SerialChain(chain, "ee_gripper_link", "base_link")
36+
serial_frame_expected = """
37+
base_link
38+
└── shoulder_link
39+
└── upper_arm_link
40+
└── upper_forearm_link
41+
└── lower_forearm_link
42+
└── wrist_link
43+
└── gripper_link
44+
└── ee_arm_link
45+
└── gripper_bar_link
46+
└── fingers_link
47+
└── ee_gripper_link
48+
"""
49+
serial_frame = serial_chain.print_link_tree()
50+
assert serial_frame_expected.strip() == serial_frame.strip()
51+
52+
# full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6
53+
assert chain.n_joints == 8
54+
assert serial_chain.n_joints == 6
55+
56+
serial_chain = pk.SerialChain(chain, "gripper_prop_link", "base_link")
57+
serial_frame_expected = """
58+
base_link
59+
└── shoulder_link
60+
└── upper_arm_link
61+
└── upper_forearm_link
62+
└── lower_forearm_link
63+
└── wrist_link
64+
└── gripper_link
65+
└── ee_arm_link
66+
└── gripper_prop_link
67+
"""
68+
serial_frame = serial_chain.print_link_tree()
69+
assert serial_frame_expected.strip() == serial_frame.strip()
70+
71+
serial_chain = pk.SerialChain(chain, "ee_gripper_link", "gripper_link")
72+
serial_frame_expected = """
73+
gripper_link
74+
└── ee_arm_link
75+
└── gripper_bar_link
76+
└── fingers_link
77+
└── ee_gripper_link
78+
"""
79+
serial_frame = serial_chain.print_link_tree()
80+
assert serial_frame_expected.strip() == serial_frame.strip()
81+
# only gripper_link is the parent frame of a joint in this serial chain
82+
assert serial_chain.n_joints == 1
83+
84+
85+
if __name__ == "__main__":
86+
test_extract_serial_chain_from_tree()

0 commit comments

Comments
 (0)