generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
example.py
70 lines (54 loc) · 2.08 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Example script demonstrating forward pass usage of the Surgical Robot Transformer.
"""
import torch
from loguru import logger
from srt_torch.main import (
SurgicalRobotTransformer,
ModelConfig,
RobotObservation,
)
def tensor_to_numpy(tensor):
    """Convert a tensor to a NumPy array regardless of its device or grad state.

    ``Tensor.numpy()`` only works on CPU tensors that do not require grad, so
    the tensor is detached and copied to the CPU first.

    Args:
        tensor: A ``torch.Tensor``.

    Returns:
        A ``numpy.ndarray`` with the tensor's values.
    """
    return tensor.detach().cpu().numpy()


def run_forward_pass():
    """Run one inference forward pass of the Surgical Robot Transformer.

    Builds a model from the default ``ModelConfig`` and switches it to eval
    mode. It then feeds the model a single ``RobotObservation`` whose four
    camera views share one all-zeros placeholder image, and logs the
    predicted left/right arm position, rotation and gripper values.

    Returns:
        The object returned by ``model(observation)``. It exposes the
        ``left_pos``/``left_rot``/``left_gripper`` and
        ``right_pos``/``right_rot``/``right_gripper`` tensor attributes read
        below.

    Raises:
        Exception: Any error raised while running the model or reading its
            outputs is logged together with its traceback, then re-raised
            unchanged.
    """
    # Initialize model and config
    config = ModelConfig()
    model = SurgicalRobotTransformer(config)
    model.eval()  # Set to evaluation mode

    # Placeholder camera frame in [C, H, W] layout. On a real robot, each view
    # would come from its own camera. NOTE(review): 3x224x224 is assumed to
    # match what the default ModelConfig expects — confirm.
    sample_image = torch.zeros((3, 224, 224))

    # Create observation object containing all camera views
    observation = RobotObservation(
        stereo_left=sample_image,
        stereo_right=sample_image,
        wrist_left=sample_image,
        wrist_right=sample_image,
    )

    # Perform forward pass without tracking gradients
    with torch.no_grad():
        try:
            action = model(observation)

            # (log label, predicted tensor). The size hints come from the
            # original example's annotations.
            readouts = (
                ("Left arm position", action.left_pos),  # [3] - xyz position
                ("Left arm rotation", action.left_rot),  # [6] - 6D rotation
                ("Left gripper angle", action.left_gripper),  # [1]
                ("Right arm position", action.right_pos),  # [3]
                ("Right arm rotation", action.right_rot),  # [6]
                ("Right gripper angle", action.right_gripper),  # [1]
            )
            for label, tensor in readouts:
                logger.info(f"{label}: {tensor_to_numpy(tensor)}")
        except Exception as e:
            # logger.exception also records the traceback, not just str(e).
            logger.exception(f"Error during forward pass: {e}")
            raise

    return action
if __name__ == "__main__":
    # Register an additional loguru sink so this run's log is also written
    # to srt_inference.log.
    logger.add("srt_inference.log")
    logger.info("Starting SRT forward pass example")
    # The returned action is already logged inside run_forward_pass, so it is
    # not kept here.
    run_forward_pass()
    logger.info("Forward pass completed successfully")