backup wip
alexander-soare committed Aug 22, 2024
1 parent 6d8fbf3 commit d236ead
Showing 4 changed files with 75 additions and 17 deletions.
11 changes: 9 additions & 2 deletions lerobot/common/policies/rollout_wrapper.py
@@ -55,7 +55,7 @@ def __init__(self, policy: Policy, fps: float, n_action_buffer: int = 0):
inference run.
"""
self.policy = policy
self.period_us = int(round(MICROSEC * (1 / fps)))
self.period_us = int(round(MICROSEC * 1 / fps))
# We'll allow half a clock cycle of tolerance on timestamp retrieval.
self.timestamp_tolerance_us = int(round(MICROSEC * (1 / fps / 2)))
self.n_action_buffer = n_action_buffer
@@ -149,7 +149,7 @@ def run_inference(
}
)

logging.info(f"Inference time: {(time.perf_counter() - start_inference_t) * 1000 :.0f} ms")
inference_time = time.perf_counter() - start_inference_t
# logging.info(f"Inference time: {inference_time * 1000 :.0f} ms")
if inference_time > (self.n_action_buffer * self.period_us + self.period_us) / MICROSEC:
logging.warning(
"Inference is taking longer than your buffer.\n"
f" Buffer time : {self.n_action_buffer * self.period_us + self.period_us / 1000=} ms\n"
f" Inference time: {inference_time * 1000 :.0f} ms"
)

def _get_contiguous_action_sequence_from_cache(self, first_action_timestamp_us: float) -> Tensor | None:
with self._thread_lock:
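For reference, a minimal sketch of the timing arithmetic behind the new warning, assuming MICROSEC = 1e6 (microseconds per second), as implied by the conversions above:

    MICROSEC = 1_000_000  # assumed: microseconds per second

    fps = 30
    n_action_buffer = 2
    period_us = int(round(MICROSEC * (1 / fps)))  # one control cycle ~= 33_333 us

    # Inference must finish before the buffered actions (plus the current cycle) run out.
    budget_s = (n_action_buffer * period_us + period_us) / MICROSEC  # 3 cycles = 0.1 s
    inference_time = 0.12  # example measurement, in seconds
    if inference_time > budget_s:
        print(f"Over budget by {(inference_time - budget_s) * 1000:.0f} ms")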
20 changes: 20 additions & 0 deletions lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -88,6 +88,26 @@ def reset(self):
"action": deque(maxlen=self.config.action_chunk_size),
}

@property
def n_obs_steps(self) -> int:
return self.config.n_obs_steps

@property
def input_keys(self) -> set[str]:
return set(self.config.input_shapes)

@torch.no_grad
def run_inference(self, observation_batch: dict[str, Tensor]) -> Tensor:
observation_batch = self.normalize_inputs(observation_batch)
# shallow copy so that adding a key doesn't modify the original
observation_batch = dict(observation_batch)
observation_batch["observation.images"] = torch.stack(
[observation_batch[k] for k in self.expected_image_keys], dim=-4
)
actions = self.vqbet(observation_batch, rollout=True)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions

@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
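A hypothetical usage sketch for the new run_inference method. Here policy is assumed to be a loaded VQBeTPolicy, and the key names and shapes are illustrative stand-ins for whatever config.input_shapes defines:

    import torch

    n_obs_steps = policy.n_obs_steps  # exposed by the new property above
    observation_batch = {
        "observation.state": torch.zeros(1, n_obs_steps, 6),
        "observation.images.top": torch.zeros(1, n_obs_steps, 3, 84, 84),
    }
    actions = policy.run_inference(observation_batch)  # unnormalized action chunk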
28 changes: 26 additions & 2 deletions lerobot/scripts/control_robot.py
@@ -287,6 +287,7 @@ def record(
robot: Robot,
policy: torch.nn.Module | None = None,
hydra_cfg: DictConfig | None = None,
policy_action_safety_cap: torch.Tensor | None = None,
fps: int | None = None,
root="data",
repo_id="lerobot/debug",
@@ -313,6 +314,15 @@
if not video:
raise NotImplementedError()

if policy_action_safety_cap is None and policy is not None:
policy_action_safety_cap = torch.tensor([10.0, 10.0, 10.0, 10.0, 10.0, 16.0])
logging.info(
"Actions from the policy will be clamped such that they result in a maximum relative positional "
f"target magnitude of no greater than {policy_action_safety_cap.tolist()}. This is for safety "
"reasons (mostly to avoid damaging your motors). Any instances of capping will be logged. You "
"may override these values by passing `policy_action_safety_cap`."
)

if not robot.is_connected:
robot.connect()

@@ -485,9 +495,23 @@ def on_press(key):
# Move to cpu, if not already the case
action = action.to("cpu")

# Cap relative action target magnitude for safety.
current_pos = observation["observation.state"].cpu().squeeze(0)
diff = action - current_pos
safe_diff = torch.minimum(diff, policy_action_safety_cap)
safe_diff = torch.maximum(safe_diff, -policy_action_safety_cap)
safe_action = current_pos + safe_diff
if not torch.allclose(safe_action, action):
logging.warning(
"Relative action magnitude had to be clamped to be safe.\n"
f" requested relative action target: {diff}\n"
f" clamped relative action target: {safe_diff}"
)

# Order the robot to move
robot.send_action(action)
action = {"action": action}
robot.send_action(safe_action)
action = {"action": safe_action}

for key in action:
if key not in ep_dict:
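The clamping above amounts to the following sketch; clamp_relative_action is a hypothetical helper (torch.clamp can replace the minimum/maximum pair):

    import torch

    def clamp_relative_action(
        action: torch.Tensor, current_pos: torch.Tensor, cap: torch.Tensor
    ) -> torch.Tensor:
        """Limit the relative positional target to +/-cap per joint."""
        safe_diff = torch.clamp(action - current_pos, -cap, cap)
        return current_pos + safe_diff

    cap = torch.tensor([10.0, 10.0, 10.0, 10.0, 10.0, 16.0])
    current_pos = torch.zeros(6)
    action = torch.tensor([25.0, -3.0, 0.0, 12.0, 5.0, -20.0])
    print(clamp_relative_action(action, current_pos, cap))
    # tensor([10., -3.,  0., 10.,  5., -16.])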
33 changes: 20 additions & 13 deletions lerobot/scripts/eval_real.py
@@ -51,9 +51,15 @@ def to_relative_time(t):
to_visualize = {}
while True:
is_dropped_cycle = False
over_time = False
start_step_time = to_relative_time(time.perf_counter())
observation: dict[str, torch.Tensor] = robot.capture_observation()

elapsed = to_relative_time(time.perf_counter()) - start_step_time
if elapsed > period:
over_time = True
logging.warning(f"Over time after capturing observation! {elapsed=}")

# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if name.startswith("observation.image"):
@@ -68,7 +74,7 @@ def to_relative_time(t):
# based on the current observation
with torch.inference_mode():
timeout = (
period - (to_relative_time(time.perf_counter()) - start_step_time) - 0.005
period - (to_relative_time(time.perf_counter()) - start_step_time) - 0.01
if step > 0
else None
)
@@ -79,10 +85,6 @@
strict_observation_timestamps=step > 0,
timeout=timeout,
)
elapsed = to_relative_time(time.perf_counter()) - start_step_time
if elapsed > period:
logging.warning(f"C: Step took too long! {elapsed=}")
# print(Timer.render_timing_statistics())

if action_sequence is not None:
action_sequence = action_sequence.squeeze(1) # remove batch dim
@@ -118,29 +120,33 @@
to_visualize[name][-10:] = red
to_visualize[name][:, :10] = red
to_visualize[name][:, -10:] = red
if over_time:
purple = np.array([255, 0, 255], dtype=np.uint8)
to_visualize[name][:20] = purple
to_visualize[name][-20:] = purple
to_visualize[name][:, :20] = purple
to_visualize[name][:, -20:] = purple
cv2.imshow(name, cv2.cvtColor(to_visualize[name], cv2.COLOR_RGB2BGR))
k = cv2.waitKey(1)
if k == ord("q"):
return

elapsed = to_relative_time(time.perf_counter()) - start_step_time
if elapsed > period:
logging.warning(f"B: Step took too long! {elapsed=}")

# Order the robot to move
if start_step_time < warmup_s:
if start_step_time <= warmup_s:
policy_rollout_wrapper.reset()
logging.info("Warming up.")
else:
robot_pos = torch.tensor(robot.follower_arms["main"].read("Present_Position"))
# Cap action magnitude at 10 degrees
robot_pos = observation["observation.state"].cpu().squeeze(0)
# Cap action magnitude.
diff = action - robot_pos
safe_diff = diff.clone()
maximum_diff = torch.tensor([10, 10, 10, 10, 10, 15])
safe_diff = torch.minimum(diff, maximum_diff)
safe_diff = torch.maximum(diff, -maximum_diff)
safe_diff = torch.maximum(safe_diff, -maximum_diff)
safe_action = robot_pos + safe_diff
if not torch.equal(safe_action, action):
if not torch.allclose(safe_action, action):
logging.warning(
"Action diff had to be clamped to be safe.\n"
f" requested diff: {diff}\n"
@@ -154,7 +160,8 @@
else:
busy_wait(period - elapsed - 0.001)

step += 1
if start_step_time > warmup_s:
step += 1


if __name__ == "__main__":
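The control loop's fixed-period timing pattern, in isolation (a sketch; busy_wait here is a stand-in for the helper the scripts use):

    import logging
    import time

    def busy_wait(seconds: float) -> None:
        # Stand-in for lerobot's busy-wait helper: spin to avoid sleep() jitter.
        end = time.perf_counter() + seconds
        while time.perf_counter() < end:
            pass

    fps = 30
    period = 1 / fps
    t0 = time.perf_counter()
    for step in range(3):
        start_step_time = time.perf_counter() - t0
        # ... capture observation, run the policy, send the action ...
        elapsed = (time.perf_counter() - t0) - start_step_time
        if elapsed > period:
            logging.warning(f"Step took too long! {elapsed=}")
        else:
            busy_wait(period - elapsed - 0.001)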
