Skip to content

Commit

Permalink
Handle crop_shape=None in Diffusion Policy (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-soare authored May 28, 2024
1 parent e3b9f1c commit 3d625ae
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __post_init__(self):
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if (
if self.crop_shape is not None and (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
Expand Down
8 changes: 6 additions & 2 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,15 @@ def __init__(self, config: DiffusionConfig):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape`.
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
Expand Down
4 changes: 1 addition & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3d625ae

Please sign in to comment.