Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Simon Alibert <[email protected]>
  • Loading branch information
marinabar and aliberts authored Jun 12, 2024
1 parent 58faf75 commit 3646f6d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ test-act-ete-train:
policy.n_action_steps=20 \
policy.chunk_size=20 \
training.batch_size=2 \
training.image_transforms.enable=true \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/act/

test-act-ete-eval:
Expand Down Expand Up @@ -74,7 +74,7 @@ test-act-ete-train-amp:
policy.chunk_size=20 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act_amp/ \
training.image_transforms.enable=true \
training.image_transforms.enable=true \
use_amp=true

test-act-ete-eval-amp:
Expand Down Expand Up @@ -102,7 +102,7 @@ test-diffusion-ete-train:
training.save_checkpoint=true \
training.save_freq=2 \
training.batch_size=2 \
training.image_transforms.enable=true \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/diffusion/

test-diffusion-ete-eval:
Expand Down Expand Up @@ -130,7 +130,7 @@ test-tdmpc-ete-train:
training.save_checkpoint=true \
training.save_freq=2 \
training.batch_size=2 \
training.image_transforms.enable=true \
training.image_transforms.enable=true \
hydra.run.dir=tests/outputs/tdmpc/

test-tdmpc-ete-eval:
Expand Down
25 changes: 13 additions & 12 deletions lerobot/common/datasets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,20 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData

image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
image_transforms = get_image_transforms(
brightness_weight=cfg.training.image_transforms.brightness.weight,
brightness_min_max=cfg.training.image_transforms.brightness.min_max,
contrast_weight=cfg.training.image_transforms.contrast.weight,
contrast_min_max=cfg.training.image_transforms.contrast.min_max,
saturation_weight=cfg.training.image_transforms.saturation.weight,
saturation_min_max=cfg.training.image_transforms.saturation.min_max,
hue_weight=cfg.training.image_transforms.hue.weight,
hue_min_max=cfg.training.image_transforms.hue.min_max,
sharpness_weight=cfg.training.image_transforms.sharpness.weight,
sharpness_min_max=cfg.training.image_transforms.sharpness.min_max,
max_num_transforms=cfg.training.image_transforms.max_num_transforms,
random_order=cfg.training.image_transforms.random_order,
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)

if isinstance(cfg.dataset_repo_id, str):
Expand Down

0 comments on commit 3646f6d

Please sign in to comment.