Skip to content

Commit 42645b7

Browse files
authored
Merge pull request #29 from UM-ARM-Lab/inverse_kinematics
Inverse kinematics
2 parents 31f1422 + c85e294 commit 42645b7

File tree

10 files changed

+751
-6
lines changed

10 files changed

+751
-6
lines changed

.github/workflows/python-package.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ jobs:
4242
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
4343
- name: Test with pytest
4444
run: |
45-
pytest
45+
pytest --ignore=tests/mujoco_menagerie

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.idea
2+
*.mp4
23
*.so
34
*.pkl
45
*.egg-info

pyproject.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,15 @@ dependencies = [
4646
'numpy',
4747
'pyyaml',
4848
'torch',
49+
'matplotlib',
50+
'pytorch_seed',
4951
]
5052

5153
[project.optional-dependencies]
52-
test = ["pytest"]
54+
test = [
55+
"pytest",
56+
"pybullet",
57+
]
5358

5459
[project.urls]
5560
"Homepage" = "https://github.com/UM-ARM-Lab/pytorch_kinematics"

src/pytorch_kinematics/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from pytorch_kinematics.sdf import *
22
from pytorch_kinematics.urdf import *
3-
from pytorch_kinematics.mjcf import *
3+
4+
try:
5+
from pytorch_kinematics.mjcf import *
6+
except ImportError:
7+
pass
48
from pytorch_kinematics.transforms import *
59
from pytorch_kinematics.chain import *
10+
from pytorch_kinematics.ik import *

src/pytorch_kinematics/chain.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -458,10 +458,10 @@ def _generate_serial_chain_recurse(root_frame, end_frame_name):
458458
return [child] + frames
459459
return None
460460

461-
def jacobian(self, th, locations=None):
461+
def jacobian(self, th, locations=None, **kwargs):
462462
if locations is not None:
463463
locations = tf.Transform3d(pos=locations)
464-
return jacobian.calc_jacobian(self, th, tool=locations)
464+
return jacobian.calc_jacobian(self, th, tool=locations, **kwargs)
465465

466466
def forward_kinematics(self, th, end_only: bool = True):
467467
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """

src/pytorch_kinematics/ik.py

+425
Large diffs are not rendered by default.

src/pytorch_kinematics/jacobian.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pytorch_kinematics import transforms
44

55

6-
def calc_jacobian(serial_chain, th, tool=None):
6+
def calc_jacobian(serial_chain, th, tool=None, ret_eef_pose=False):
77
"""
88
Return robot Jacobian J in base frame (N,6,DOF) where dot{x} = J dot{q}
99
The first 3 rows relate the translational velocities and the
@@ -57,4 +57,6 @@ def calc_jacobian(serial_chain, th, tool=None):
5757
j_tr[:, :3, :3] = rotation
5858
j_tr[:, 3:, 3:] = rotation
5959
j_w = j_tr @ j_eef
60+
if ret_eef_pose:
61+
return j_w, pose
6062
return j_w

tests/meshes/cone.mtl

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Blender 3.6.4 MTL File: 'ycb.blend'
2+
# www.blender.org

tests/meshes/cone.obj

+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Blender 3.6.4
2+
# www.blender.org
3+
mtllib cone.mtl
4+
o Cone
5+
v 0.097681 0.000000 -0.097681
6+
v 0.095804 -0.019057 -0.097681
7+
v 0.090246 -0.037381 -0.097681
8+
v 0.081219 -0.054269 -0.097681
9+
v 0.069071 -0.069071 -0.097681
10+
v 0.054269 -0.081219 -0.097681
11+
v 0.037381 -0.090246 -0.097681
12+
v 0.019057 -0.095804 -0.097681
13+
v 0.000000 -0.097681 -0.097681
14+
v -0.019057 -0.095804 -0.097681
15+
v -0.037381 -0.090246 -0.097681
16+
v -0.054269 -0.081219 -0.097681
17+
v -0.069071 -0.069071 -0.097681
18+
v -0.081219 -0.054269 -0.097681
19+
v -0.090246 -0.037381 -0.097681
20+
v -0.095804 -0.019057 -0.097681
21+
v -0.097681 0.000000 -0.097681
22+
v -0.095804 0.019057 -0.097681
23+
v -0.090246 0.037381 -0.097681
24+
v -0.081219 0.054269 -0.097681
25+
v -0.069071 0.069071 -0.097681
26+
v -0.054269 0.081219 -0.097681
27+
v -0.037381 0.090246 -0.097681
28+
v -0.019057 0.095804 -0.097681
29+
v 0.000000 0.097681 -0.097681
30+
v 0.019057 0.095804 -0.097681
31+
v 0.037381 0.090246 -0.097681
32+
v 0.054269 0.081219 -0.097681
33+
v 0.069071 0.069071 -0.097681
34+
v 0.081219 0.054269 -0.097681
35+
v 0.090246 0.037381 -0.097681
36+
v 0.095804 0.019057 -0.097681
37+
v 0.000000 0.000000 0.097681
38+
vn 0.8910 -0.0878 0.4455
39+
vn 0.8567 -0.2599 0.4455
40+
vn 0.7896 -0.4220 0.4455
41+
vn 0.6921 -0.5680 0.4455
42+
vn 0.5680 -0.6921 0.4455
43+
vn 0.4220 -0.7896 0.4455
44+
vn 0.2599 -0.8567 0.4455
45+
vn 0.0878 -0.8910 0.4455
46+
vn -0.0878 -0.8910 0.4455
47+
vn -0.2599 -0.8567 0.4455
48+
vn -0.4220 -0.7896 0.4455
49+
vn -0.5680 -0.6921 0.4455
50+
vn -0.6921 -0.5680 0.4455
51+
vn -0.7896 -0.4220 0.4455
52+
vn -0.8567 -0.2599 0.4455
53+
vn -0.8910 -0.0878 0.4455
54+
vn -0.8910 0.0878 0.4455
55+
vn -0.8567 0.2599 0.4455
56+
vn -0.7896 0.4220 0.4455
57+
vn -0.6921 0.5680 0.4455
58+
vn -0.5680 0.6921 0.4455
59+
vn -0.4220 0.7896 0.4455
60+
vn -0.2599 0.8567 0.4455
61+
vn -0.0878 0.8910 0.4455
62+
vn 0.0878 0.8910 0.4455
63+
vn 0.2599 0.8567 0.4455
64+
vn 0.4220 0.7896 0.4455
65+
vn 0.5680 0.6921 0.4455
66+
vn 0.6921 0.5680 0.4455
67+
vn 0.7896 0.4220 0.4455
68+
vn -0.0000 -0.0000 -1.0000
69+
vn 0.8567 0.2599 0.4455
70+
vn 0.8910 0.0878 0.4455
71+
vt 0.250000 0.490000
72+
vt 0.250000 0.250000
73+
vt 0.296822 0.485388
74+
vt 0.341844 0.471731
75+
vt 0.383337 0.449553
76+
vt 0.419706 0.419706
77+
vt 0.449553 0.383337
78+
vt 0.471731 0.341844
79+
vt 0.485388 0.296822
80+
vt 0.490000 0.250000
81+
vt 0.485388 0.203178
82+
vt 0.471731 0.158156
83+
vt 0.449553 0.116663
84+
vt 0.419706 0.080294
85+
vt 0.383337 0.050447
86+
vt 0.341844 0.028269
87+
vt 0.296822 0.014612
88+
vt 0.250000 0.010000
89+
vt 0.203178 0.014612
90+
vt 0.158156 0.028269
91+
vt 0.116663 0.050447
92+
vt 0.080294 0.080294
93+
vt 0.050447 0.116663
94+
vt 0.028269 0.158156
95+
vt 0.014612 0.203178
96+
vt 0.010000 0.250000
97+
vt 0.014612 0.296822
98+
vt 0.028269 0.341844
99+
vt 0.050447 0.383337
100+
vt 0.080294 0.419706
101+
vt 0.116663 0.449553
102+
vt 0.158156 0.471731
103+
vt 0.750000 0.490000
104+
vt 0.796822 0.485388
105+
vt 0.841844 0.471731
106+
vt 0.883337 0.449553
107+
vt 0.919706 0.419706
108+
vt 0.949553 0.383337
109+
vt 0.971731 0.341844
110+
vt 0.985388 0.296822
111+
vt 0.990000 0.250000
112+
vt 0.985388 0.203178
113+
vt 0.971731 0.158156
114+
vt 0.949553 0.116663
115+
vt 0.919706 0.080294
116+
vt 0.883337 0.050447
117+
vt 0.841844 0.028269
118+
vt 0.796822 0.014612
119+
vt 0.750000 0.010000
120+
vt 0.703178 0.014612
121+
vt 0.658156 0.028269
122+
vt 0.616663 0.050447
123+
vt 0.580294 0.080294
124+
vt 0.550447 0.116663
125+
vt 0.528269 0.158156
126+
vt 0.514612 0.203178
127+
vt 0.510000 0.250000
128+
vt 0.514612 0.296822
129+
vt 0.528269 0.341844
130+
vt 0.550447 0.383337
131+
vt 0.580294 0.419706
132+
vt 0.616663 0.449553
133+
vt 0.658156 0.471731
134+
vt 0.703178 0.485388
135+
vt 0.203178 0.485388
136+
s 0
137+
f 1/1/1 33/2/1 2/3/1
138+
f 2/3/2 33/2/2 3/4/2
139+
f 3/4/3 33/2/3 4/5/3
140+
f 4/5/4 33/2/4 5/6/4
141+
f 5/6/5 33/2/5 6/7/5
142+
f 6/7/6 33/2/6 7/8/6
143+
f 7/8/7 33/2/7 8/9/7
144+
f 8/9/8 33/2/8 9/10/8
145+
f 9/10/9 33/2/9 10/11/9
146+
f 10/11/10 33/2/10 11/12/10
147+
f 11/12/11 33/2/11 12/13/11
148+
f 12/13/12 33/2/12 13/14/12
149+
f 13/14/13 33/2/13 14/15/13
150+
f 14/15/14 33/2/14 15/16/14
151+
f 15/16/15 33/2/15 16/17/15
152+
f 16/17/16 33/2/16 17/18/16
153+
f 17/18/17 33/2/17 18/19/17
154+
f 18/19/18 33/2/18 19/20/18
155+
f 19/20/19 33/2/19 20/21/19
156+
f 20/21/20 33/2/20 21/22/20
157+
f 21/22/21 33/2/21 22/23/21
158+
f 22/23/22 33/2/22 23/24/22
159+
f 23/24/23 33/2/23 24/25/23
160+
f 24/25/24 33/2/24 25/26/24
161+
f 25/26/25 33/2/25 26/27/25
162+
f 26/27/26 33/2/26 27/28/26
163+
f 27/28/27 33/2/27 28/29/27
164+
f 28/29/28 33/2/28 29/30/28
165+
f 29/30/29 33/2/29 30/31/29
166+
f 30/31/30 33/2/30 31/32/30
167+
f 1/33/31 2/34/31 3/35/31 4/36/31 5/37/31 6/38/31 7/39/31 8/40/31 9/41/31 10/42/31 11/43/31 12/44/31 13/45/31 14/46/31 15/47/31 16/48/31 17/49/31 18/50/31 19/51/31 20/52/31 21/53/31 22/54/31 23/55/31 24/56/31 25/57/31 26/58/31 27/59/31 28/60/31 29/61/31 30/62/31 31/63/31 32/64/31
168+
f 31/32/32 33/2/32 32/65/32
169+
f 32/65/33 33/2/33 1/1/33

tests/test_inverse_kinematics.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import os
2+
from timeit import default_timer as timer
3+
4+
import torch
5+
6+
import pytorch_kinematics as pk
7+
import pytorch_seed
8+
9+
import pybullet as p
10+
import pybullet_data
11+
12+
visualize = False
13+
14+
15+
def _make_robot_translucent(robot_id, alpha=0.4):
16+
def make_transparent(link):
17+
link_id = link[1]
18+
rgba = list(link[7])
19+
rgba[3] = alpha
20+
p.changeVisualShape(robot_id, link_id, rgbaColor=rgba)
21+
22+
visual_data = p.getVisualShapeData(robot_id)
23+
for link in visual_data:
24+
make_transparent(link)
25+
26+
27+
def test_jacobian_follower():
28+
pytorch_seed.seed(2)
29+
device = "cuda" if torch.cuda.is_available() else "cpu"
30+
# device = "cpu"
31+
urdf = "kuka_iiwa/model.urdf"
32+
search_path = pybullet_data.getDataPath()
33+
full_urdf = os.path.join(search_path, urdf)
34+
chain = pk.build_serial_chain_from_urdf(open(full_urdf).read(), "lbr_iiwa_link_7")
35+
chain = chain.to(device=device)
36+
37+
# robot frame
38+
pos = torch.tensor([0.0, 0.0, 0.0], device=device)
39+
rot = torch.tensor([0.0, 0.0, 0.0], device=device)
40+
rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device)
41+
42+
# world frame goal
43+
M = 1000
44+
# generate random goal joint angles (so these are all achievable)
45+
# use the joint limits to generate random joint angles
46+
lim = torch.tensor(chain.get_joint_limits(), device=device)
47+
goal_q = torch.rand(M, 7, device=device) * (lim[1] - lim[0]) + lim[0]
48+
49+
# get ee pose (in robot frame)
50+
goal_in_rob_frame_tf = chain.forward_kinematics(goal_q)
51+
52+
# transform to world frame for visualization
53+
goal_tf = rob_tf.compose(goal_in_rob_frame_tf)
54+
goal = goal_tf.get_matrix()
55+
goal_pos = goal[..., :3, 3]
56+
goal_rot = pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ")
57+
58+
ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=10,
59+
joint_limits=lim.T,
60+
early_stopping_any_converged=True,
61+
early_stopping_no_improvement="all",
62+
# line_search=pk.BacktrackingLineSearch(max_lr=0.2),
63+
debug=False,
64+
lr=0.2)
65+
66+
# do IK
67+
timer_start = timer()
68+
sol = ik.solve(goal_in_rob_frame_tf)
69+
timer_end = timer()
70+
print("IK took %f seconds" % (timer_end - timer_start))
71+
print("IK converged number: %d / %d" % (sol.converged.sum(), sol.converged.numel()))
72+
print("IK took %d iterations" % sol.iterations)
73+
print("IK solved %d / %d goals" % (sol.converged_any.sum(), M))
74+
75+
# visualize everything
76+
if visualize:
77+
p.connect(p.GUI)
78+
p.setRealTimeSimulation(False)
79+
p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)
80+
p.setAdditionalSearchPath(search_path)
81+
82+
yaw = 90
83+
pitch = -55
84+
dist = 1.
85+
target = goal_pos[0].cpu().numpy()
86+
p.resetDebugVisualizerCamera(dist, yaw, pitch, target)
87+
88+
p.loadURDF("plane.urdf", [0, 0, 0], useFixedBase=True)
89+
m = rob_tf.get_matrix()
90+
pos = m[0, :3, 3]
91+
rot = m[0, :3, :3]
92+
quat = pk.matrix_to_quaternion(rot)
93+
pos = pos.cpu().numpy()
94+
rot = pk.wxyz_to_xyzw(quat).cpu().numpy()
95+
armId = p.loadURDF(urdf, basePosition=pos, baseOrientation=rot, useFixedBase=True)
96+
97+
_make_robot_translucent(armId, alpha=0.6)
98+
# p.resetBasePositionAndOrientation(armId, [0, 0, 0], [0, 0, 0, 1])
99+
# draw goal
100+
# place a translucent sphere at the goal
101+
show_max_num_retries_per_goal = 10
102+
for goal_num in range(M):
103+
# draw cone to indicate pose instead of sphere
104+
visId = p.createVisualShape(p.GEOM_MESH, fileName="meshes/cone.obj", meshScale=1.0,
105+
rgbaColor=[0., 1., 0., 0.5])
106+
# visId = p.createVisualShape(p.GEOM_SPHERE, radius=0.05, rgbaColor=[0., 1., 0., 0.5])
107+
r = goal_rot[goal_num]
108+
xyzw = pk.wxyz_to_xyzw(pk.matrix_to_quaternion(pk.euler_angles_to_matrix(r, "XYZ")))
109+
goalId = p.createMultiBody(baseMass=0, baseVisualShapeIndex=visId,
110+
basePosition=goal_pos[goal_num].cpu().numpy(),
111+
baseOrientation=xyzw.cpu().numpy())
112+
113+
solutions = sol.solutions[goal_num]
114+
# sort based on if they converged
115+
converged = sol.converged[goal_num]
116+
idx = torch.argsort(converged.to(int), descending=True)
117+
solutions = solutions[idx]
118+
119+
# print how many retries converged for this one
120+
print("Goal %d converged %d / %d" % (goal_num, converged.sum(), converged.numel()))
121+
122+
for i, q in enumerate(solutions):
123+
if i > show_max_num_retries_per_goal:
124+
break
125+
for dof in range(q.shape[0]):
126+
p.resetJointState(armId, dof, q[dof])
127+
input("Press enter to continue")
128+
129+
p.removeBody(goalId)
130+
131+
while True:
132+
p.stepSimulation()
133+
134+
135+
if __name__ == "__main__":
136+
test_jacobian_follower()

0 commit comments

Comments
 (0)