From 4fa3bb1a841816414ceddb7815e21a7ba58b8102 Mon Sep 17 00:00:00 2001 From: TMats Date: Fri, 26 May 2023 16:27:01 +0900 Subject: [PATCH 1/7] feat: add memory replay buffer --- diffusion_policy/common/replay_buffer.py | 16 ++++ diffusion_policy/config/task/can_image.yaml | 2 +- .../config/task/can_image_abs.yaml | 2 +- diffusion_policy/config/task/lift_image.yaml | 2 +- .../config/task/lift_image_abs.yaml | 2 +- .../config/task/real_pusht_image.yaml | 2 +- .../config/task/square_image.yaml | 2 +- .../config/task/square_image_abs.yaml | 2 +- .../config/task/tool_hang_image.yaml | 2 +- .../config/task/tool_hang_image_abs.yaml | 2 +- .../config/task/transport_image.yaml | 2 +- .../config/task/transport_image_abs.yaml | 2 +- .../dataset/robomimic_replay_image_dataset.py | 93 ++++++++++++++++++- 13 files changed, 118 insertions(+), 13 deletions(-) diff --git a/diffusion_policy/common/replay_buffer.py b/diffusion_policy/common/replay_buffer.py index 022a704e..b5decba3 100644 --- a/diffusion_policy/common/replay_buffer.py +++ b/diffusion_policy/common/replay_buffer.py @@ -586,3 +586,19 @@ def set_compressors(self, compressors: dict): compressor = self.resolve_compressor(value) if compressor != arr.compressor: rechunk_recompress_array(self.data, key, compressor=compressor) + + +class MemoryReplayBuffer: + def __init__(self): + self.data = {} + self.meta = {} + self.n_episodes = 0 + + @property + def episode_ends(self): + return self.meta['episode_ends'] + + def keys(self): + return self.data.keys() + def __getitem__(self, key): + return self.data[key] \ No newline at end of file diff --git a/diffusion_policy/config/task/can_image.yaml b/diffusion_policy/config/task/can_image.yaml index 10158bbd..bc10d3dc 100644 --- a/diffusion_policy/config/task/can_image.yaml +++ b/diffusion_policy/config/task/can_image.yaml @@ -59,6 +59,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/can_image_abs.yaml b/diffusion_policy/config/task/can_image_abs.yaml index eee34dc6..1f824c34 100644 --- a/diffusion_policy/config/task/can_image_abs.yaml +++ b/diffusion_policy/config/task/can_image_abs.yaml @@ -59,6 +59,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/lift_image.yaml b/diffusion_policy/config/task/lift_image.yaml index 8ddde456..2f5734ee 100644 --- a/diffusion_policy/config/task/lift_image.yaml +++ b/diffusion_policy/config/task/lift_image.yaml @@ -59,6 +59,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/lift_image_abs.yaml b/diffusion_policy/config/task/lift_image_abs.yaml index b25002f3..9a29fac7 100644 --- a/diffusion_policy/config/task/lift_image_abs.yaml +++ b/diffusion_policy/config/task/lift_image_abs.yaml @@ -58,6 +58,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/real_pusht_image.yaml b/diffusion_policy/config/task/real_pusht_image.yaml index a3f7c3f0..41eb2094 100644 --- a/diffusion_policy/config/task/real_pusht_image.yaml +++ b/diffusion_policy/config/task/real_pusht_image.yaml @@ -39,7 +39,7 @@ dataset: 
pad_after: ${eval:'${n_action_steps}-1'} n_obs_steps: ${dataset_obs_steps} n_latency_steps: ${n_latency_steps} - use_cache: True + use_cache: false seed: 42 val_ratio: 0.00 max_train_episodes: null diff --git a/diffusion_policy/config/task/square_image.yaml b/diffusion_policy/config/task/square_image.yaml index 827151be..03e5edcd 100644 --- a/diffusion_policy/config/task/square_image.yaml +++ b/diffusion_policy/config/task/square_image.yaml @@ -59,6 +59,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/square_image_abs.yaml b/diffusion_policy/config/task/square_image_abs.yaml index d27a916a..53c9ebb5 100644 --- a/diffusion_policy/config/task/square_image_abs.yaml +++ b/diffusion_policy/config/task/square_image_abs.yaml @@ -59,6 +59,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/tool_hang_image.yaml b/diffusion_policy/config/task/tool_hang_image.yaml index 8a8b298f..6e046758 100644 --- a/diffusion_policy/config/task/tool_hang_image.yaml +++ b/diffusion_policy/config/task/tool_hang_image.yaml @@ -58,6 +58,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/tool_hang_image_abs.yaml b/diffusion_policy/config/task/tool_hang_image_abs.yaml index 068bdc93..92fa8e63 100644 --- a/diffusion_policy/config/task/tool_hang_image_abs.yaml +++ b/diffusion_policy/config/task/tool_hang_image_abs.yaml @@ -58,6 +58,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/transport_image.yaml b/diffusion_policy/config/task/transport_image.yaml index c9c24cc0..b1076e82 100644 --- a/diffusion_policy/config/task/transport_image.yaml +++ b/diffusion_policy/config/task/transport_image.yaml @@ -70,6 +70,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/config/task/transport_image_abs.yaml b/diffusion_policy/config/task/transport_image_abs.yaml index 74e76c60..3fe6f4da 100644 --- a/diffusion_policy/config/task/transport_image_abs.yaml +++ b/diffusion_policy/config/task/transport_image_abs.yaml @@ -70,6 +70,6 @@ dataset: abs_action: *abs_action rotation_rep: 'rotation_6d' use_legacy_normalizer: False - use_cache: True + use_cache: false seed: 42 val_ratio: 0.02 diff --git a/diffusion_policy/dataset/robomimic_replay_image_dataset.py b/diffusion_policy/dataset/robomimic_replay_image_dataset.py index 2728e9e9..3c7c4de2 100644 --- a/diffusion_policy/dataset/robomimic_replay_image_dataset.py +++ b/diffusion_policy/dataset/robomimic_replay_image_dataset.py @@ -19,7 +19,7 @@ from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer from diffusion_policy.model.common.rotation_transformer import RotationTransformer from diffusion_policy.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k -from diffusion_policy.common.replay_buffer import ReplayBuffer +from diffusion_policy.common.replay_buffer import ReplayBuffer, MemoryReplayBuffer from 
diffusion_policy.common.sampler import SequenceSampler, get_val_mask from diffusion_policy.common.normalize_util import ( robomimic_abs_action_only_normalizer_from_stat, @@ -81,7 +81,7 @@ def __init__(self, src_store=zip_store, store=zarr.MemoryStore()) print('Loaded!') else: - replay_buffer = _convert_robomimic_to_replay( + replay_buffer = _convert_robomimic_to_memory_replay( store=zarr.MemoryStore(), shape_meta=shape_meta, dataset_path=dataset_path, @@ -362,6 +362,95 @@ def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx): replay_buffer = ReplayBuffer(root) return replay_buffer + +def _convert_robomimic_to_memory_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer, + n_workers=None, max_inflight_tasks=None): + if n_workers is None: + n_workers = multiprocessing.cpu_count() + if max_inflight_tasks is None: + max_inflight_tasks = n_workers * 5 + + # parse shape_meta + rgb_keys = list() + lowdim_keys = list() + # construct compressors and chunks + obs_shape_meta = shape_meta['obs'] + for key, attr in obs_shape_meta.items(): + shape = attr['shape'] + type = attr.get('type', 'low_dim') + if type == 'rgb': + rgb_keys.append(key) + elif type == 'low_dim': + lowdim_keys.append(key) + + root = MemoryReplayBuffer() + + with h5py.File(dataset_path) as file: + # count total steps + demos = file['data'] + episode_ends = list() + prev_end = 0 + for i in range(len(demos)): + demo = demos[f'demo_{i}'] + episode_length = demo['actions'].shape[0] + episode_end = prev_end + episode_length + prev_end = episode_end + episode_ends.append(episode_end) + n_steps = episode_ends[-1] + episode_starts = [0] + episode_ends[:-1] + root.meta['episode_ends'] = np.array(episode_ends, dtype=np.int64) + root.n_episodes = len(episode_ends) + + # save lowdim data + for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"): + data_key = 'obs/' + key + if key == 'action': + data_key = 'actions' + this_data = list() + for i in range(len(demos)): + demo = demos[f'demo_{i}'] + this_data.append(demo[data_key][:].astype(np.float32)) + this_data = np.concatenate(this_data, axis=0) + if key == 'action': + this_data = _convert_actions( + raw_actions=this_data, + abs_action=abs_action, + rotation_transformer=rotation_transformer + ) + assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape']) + else: + assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape']) + + root.data[key] = np.array(this_data, dtype=this_data.dtype).reshape(this_data.shape) + # root.data[key] = torch.tensor(np.array(this_data, dtype=this_data.dtype).reshape(this_data.shape), dtype=torch.uint8) + + def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx): + try: + zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx] + # make sure we can successfully decode + _ = zarr_arr[zarr_idx] + return True + except Exception as e: + return False + # save img data + for key in tqdm(rgb_keys, desc="Loading image data"): + data_key = 'obs/' + key + hdf5_arr = list() + for episode_idx in range(len(demos)): + demo = demos[f'demo_{episode_idx}'] + # print(demo['obs'][key][:].astype(np.uint8).shape) + # print(demo['obs'][key][:].astype(np.uint8)) + shape = tuple(shape_meta['obs'][key]['shape']) + c,h,w = shape + hdf5_arr.append(demo['obs'][key][:].astype(np.uint8).reshape(-1,h,w,c)) + + hdf5_arr = np.concatenate(hdf5_arr, axis=0) + root.data[key] = np.array(hdf5_arr, dtype=hdf5_arr.dtype).reshape(hdf5_arr.shape) + # root.data[key] = torch.tensor(np.array(hdf5_arr, dtype=hdf5_arr.dtype).reshape(hdf5_arr.shape), dtype=torch.uint8) 
+ replay_buffer = root + return replay_buffer + + def normalizer_from_stat(stat): max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max()) scale = np.full_like(stat['max'], fill_value=1/max_abs) From e33dda9f1b6fe1f2d9462c6d7e43d1d461cdcd34 Mon Sep 17 00:00:00 2001 From: TMats Date: Fri, 26 May 2023 09:04:17 +0000 Subject: [PATCH 2/7] test: add mp and logger --- config.yaml | 379 ++++++++++++++++++ .../train_diffusion_unet_hybrid_workspace.py | 71 +++- train.py | 15 +- 3 files changed, 457 insertions(+), 8 deletions(-) create mode 100644 config.yaml diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..48a06f1c --- /dev/null +++ b/config.yaml @@ -0,0 +1,379 @@ +_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace +checkpoint: + save_last_ckpt: true + save_last_snapshot: false + topk: + format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt + k: 5 + mode: max + monitor_key: test_mean_score +dataloader: + batch_size: 64 + num_workers: 8 + persistent_workers: false + pin_memory: true + shuffle: true +dataset_obs_steps: 2 +ema: + _target_: diffusion_policy.model.diffusion.ema_model.EMAModel + inv_gamma: 1.0 + max_value: 0.9999 + min_value: 0.0 + power: 0.75 + update_after_step: 0 +exp_name: default +horizon: 16 +keypoint_visible_rate: 1.0 +logging: + group: null + id: null + mode: online + name: 2023.01.03-19.43.07_train_diffusion_unet_hybrid_transport_image + project: diffusion_policy_debug + resume: true + tags: + - train_diffusion_unet_hybrid + - transport_image + - default +multi_run: + run_dir: data/outputs/2023.01.03/19.43.07_train_diffusion_unet_hybrid_transport_image + wandb_name_base: 2023.01.03-19.43.07_train_diffusion_unet_hybrid_transport_image +n_action_steps: 8 +n_latency_steps: 0 +n_obs_steps: 2 +name: train_diffusion_unet_hybrid +obs_as_global_cond: true +optimizer: + _target_: torch.optim.AdamW + betas: + - 0.95 + - 0.999 + eps: 1.0e-08 + lr: 0.0001 + weight_decay: 1.0e-06 +past_action_visible: false +policy: + _target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy + cond_predict_scale: true + crop_shape: + - 76 + - 76 + diffusion_step_embed_dim: 128 + down_dims: + - 512 + - 1024 + - 2048 + eval_fixed_crop: true + horizon: 16 + kernel_size: 5 + n_action_steps: 8 + n_groups: 8 + n_obs_steps: 2 + noise_scheduler: + _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler + beta_end: 0.02 + beta_schedule: squaredcos_cap_v2 + beta_start: 0.0001 + clip_sample: true + num_train_timesteps: 100 + prediction_type: epsilon + variance_type: fixed_small + num_inference_steps: 100 + obs_as_global_cond: true + obs_encoder_group_norm: true + shape_meta: + action: + shape: + - 20 + obs: + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + robot1_eef_pos: + shape: + - 3 + robot1_eef_quat: + shape: + - 4 + robot1_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot1_gripper_qpos: + shape: + - 2 + shouldercamera0_image: + shape: + - 3 + - 84 + - 84 + type: rgb + shouldercamera1_image: + shape: + - 3 + - 84 + - 84 + type: rgb +shape_meta: + action: + shape: + - 20 + obs: + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + robot1_eef_pos: + shape: + - 3 + robot1_eef_quat: + shape: + - 4 + 
robot1_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot1_gripper_qpos: + shape: + - 2 + shouldercamera0_image: + shape: + - 3 + - 84 + - 84 + type: rgb + shouldercamera1_image: + shape: + - 3 + - 84 + - 84 + type: rgb +task: + abs_action: true + dataset: + _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset + abs_action: true + dataset_path: data/robomimic/datasets/transport/mh/image_abs.hdf5 + horizon: 16 + n_obs_steps: 2 + pad_after: 7 + pad_before: 1 + rotation_rep: rotation_6d + seed: 42 + shape_meta: + action: + shape: + - 20 + obs: + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + robot1_eef_pos: + shape: + - 3 + robot1_eef_quat: + shape: + - 4 + robot1_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot1_gripper_qpos: + shape: + - 2 + shouldercamera0_image: + shape: + - 3 + - 84 + - 84 + type: rgb + shouldercamera1_image: + shape: + - 3 + - 84 + - 84 + type: rgb + use_cache: false + val_ratio: 0.02 + dataset_path: data/robomimic/datasets/transport/mh/image_abs.hdf5 + dataset_type: mh + env_runner: + _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner + abs_action: true + crf: 22 + dataset_path: data/robomimic/datasets/transport/mh/image_abs.hdf5 + fps: 10 + max_steps: 700 + n_action_steps: 8 + n_envs: 28 + n_obs_steps: 2 + n_test: 50 + n_test_vis: 4 + n_train: 6 + n_train_vis: 2 + past_action: false + render_obs_key: shouldercamera0_image + shape_meta: + action: + shape: + - 20 + obs: + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + robot1_eef_pos: + shape: + - 3 + robot1_eef_quat: + shape: + - 4 + robot1_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot1_gripper_qpos: + shape: + - 2 + shouldercamera0_image: + shape: + - 3 + - 84 + - 84 + type: rgb + shouldercamera1_image: + shape: + - 3 + - 84 + - 84 + type: rgb + test_start_seed: 100000 + tqdm_interval_sec: 1.0 + train_start_idx: 0 + name: transport_image + shape_meta: + action: + shape: + - 20 + obs: + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + robot1_eef_pos: + shape: + - 3 + robot1_eef_quat: + shape: + - 4 + robot1_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot1_gripper_qpos: + shape: + - 2 + shouldercamera0_image: + shape: + - 3 + - 84 + - 84 + type: rgb + shouldercamera1_image: + shape: + - 3 + - 84 + - 84 + type: rgb + task_name: transport +task_name: transport_image +training: + checkpoint_every: 50 + debug: false + device: cuda:0 + gradient_accumulate_every: 1 + lr_scheduler: cosine + lr_warmup_steps: 500 + max_train_steps: null + max_val_steps: null + num_epochs: 3000 + resume: true + rollout_every: 50 + sample_every: 5 + seed: 42 + tqdm_interval_sec: 1.0 + use_ema: true + val_every: 1 +val_dataloader: + batch_size: 64 + num_workers: 8 + persistent_workers: false + pin_memory: true + shuffle: false diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py index 9219427c..fd0d8f2a 100644 --- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py +++ 
b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py @@ -31,6 +31,20 @@ OmegaConf.register_new_resolver("eval", eval, replace=True) +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + + class TrainDiffusionUnetHybridWorkspace(BaseWorkspace): include_keys = ['global_step', 'epoch'] @@ -38,10 +52,10 @@ def __init__(self, cfg: OmegaConf, output_dir=None): super().__init__(cfg, output_dir=output_dir) # set seed - seed = cfg.training.seed - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) + self.seed = cfg.training.seed + torch.manual_seed(self.seed) + np.random.seed(self.seed) + random.seed(self.seed) # configure model self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy) @@ -58,9 +72,29 @@ def __init__(self, cfg: OmegaConf, output_dir=None): self.global_step = 0 self.epoch = 0 - def run(self): + def run(self, rank, world_size): + print("rank:", rank, "world_size:", world_size) + setup(rank, world_size) + seed = self.seed + dist.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + device = f"cuda:{rank}" + cfg = copy.deepcopy(self.cfg) + # configure model + self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy) + # configure the model for DDP + self.model = DDP(self.model) + + self.ema_model: DiffusionUnetHybridImagePolicy = None + if cfg.training.use_ema: + self.ema_model = copy.deepcopy(self.model) + # configure training state + self.optimizer = hydra.utils.instantiate( + cfg.optimizer, params=self.model.parameters()) # resume training if cfg.training.resume: lastest_ckpt_path = self.get_checkpoint_path() @@ -72,7 +106,8 @@ def run(self): dataset: BaseImageDataset dataset = hydra.utils.instantiate(cfg.task.dataset) assert isinstance(dataset, BaseImageDataset) - train_dataloader = DataLoader(dataset, **cfg.dataloader) + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + train_dataloader = DataLoader(dataset, sampler=train_sampler,**cfg.dataloader) normalizer = dataset.get_normalizer() # configure validation dataset @@ -156,23 +191,41 @@ def run(self): train_losses = list() with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: + import time + end_time = time.time() for batch_idx, batch in enumerate(tepoch): + start_time = time.time() + print("Batch", start_time-end_time) + end_time = time.time() # device transfer batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) + start_time = time.time() + print("GPU", start_time-end_time) + end_time = time.time() if train_sampling_batch is None: train_sampling_batch = batch # compute loss raw_loss = self.model.compute_loss(batch) + loss = raw_loss / cfg.training.gradient_accumulate_every + start_time = time.time() + print("Loss", start_time-end_time) + end_time = time.time() + loss.backward() + start_time = time.time() + print("Backword", start_time-end_time) + end_time = time.time() # step optimizer if self.global_step % cfg.training.gradient_accumulate_every == 0: self.optimizer.step() self.optimizer.zero_grad() lr_scheduler.step() - + start_time = time.time() + print("Optimizer", 
start_time-end_time) + end_time = time.time() # update ema if cfg.training.use_ema: ema.step(self.model) @@ -198,6 +251,9 @@ def run(self): if (cfg.training.max_train_steps is not None) \ and batch_idx >= (cfg.training.max_train_steps-1): break + start_time = time.time() + print("Logger", start_time-end_time) + end_time = time.time() # at the end of each epoch # replace train_loss with epoch average @@ -283,6 +339,7 @@ def run(self): json_logger.log(step_log) self.global_step += 1 self.epoch += 1 + cleanup() @hydra.main( version_base=None, diff --git a/train.py b/train.py index 51c564e6..6ca73090 100644 --- a/train.py +++ b/train.py @@ -13,6 +13,8 @@ from omegaconf import OmegaConf import pathlib from diffusion_policy.workspace.base_workspace import BaseWorkspace +import torch +import torch.multiprocessing as mp # allows arbitrary python code execution in configs using the ${eval:''} resolver OmegaConf.register_new_resolver("eval", eval, replace=True) @@ -22,6 +24,15 @@ config_path=str(pathlib.Path(__file__).parent.joinpath( 'diffusion_policy','config')) ) +# def main(cfg: OmegaConf): +# # resolve immediately so all the ${now:} resolvers +# # will use the same time. +# OmegaConf.resolve(cfg) + +# cls = hydra.utils.get_class(cfg._target_) +# workspace: BaseWorkspace = cls(cfg) +# workspace.run() + def main(cfg: OmegaConf): # resolve immediately so all the ${now:} resolvers # will use the same time. @@ -29,7 +40,9 @@ def main(cfg: OmegaConf): cls = hydra.utils.get_class(cfg._target_) workspace: BaseWorkspace = cls(cfg) - workspace.run() + world_size = torch.cuda.device_count() + mp.spawn(workspace.run, args=(world_size,), nprocs=world_size, join=True) + if __name__ == "__main__": main() From e82291f1f2ce0e4a3888a2a37d3f9ced16ca147a Mon Sep 17 00:00:00 2001 From: TMats Date: Fri, 26 May 2023 19:21:54 +0900 Subject: [PATCH 3/7] debug: print debug for hydraconfig error --- diffusion_policy/workspace/base_workspace.py | 4 ++ .../train_diffusion_unet_hybrid_workspace.py | 54 +++++++++++++------ 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/diffusion_policy/workspace/base_workspace.py b/diffusion_policy/workspace/base_workspace.py index 1352404a..88e20c0b 100644 --- a/diffusion_policy/workspace/base_workspace.py +++ b/diffusion_policy/workspace/base_workspace.py @@ -9,6 +9,10 @@ import torch import threading +# @hydra.main( +# version_base=None, +# config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), +# config_name=pathlib.Path(__file__).stem) class BaseWorkspace: include_keys = tuple() diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py index fd0d8f2a..69f1448e 100644 --- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py +++ b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py @@ -31,9 +31,12 @@ OmegaConf.register_new_resolver("eval", eval, replace=True) + import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp + def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' @@ -57,16 +60,16 @@ def __init__(self, cfg: OmegaConf, output_dir=None): np.random.seed(self.seed) random.seed(self.seed) - # configure model - self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy) + # # configure model + # self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy) - self.ema_model: 
DiffusionUnetHybridImagePolicy = None - if cfg.training.use_ema: - self.ema_model = copy.deepcopy(self.model) + # self.ema_model: DiffusionUnetHybridImagePolicy = None + # if cfg.training.use_ema: + # self.ema_model = copy.deepcopy(self.model) - # configure training state - self.optimizer = hydra.utils.instantiate( - cfg.optimizer, params=self.model.parameters()) + # # configure training state + # self.optimizer = hydra.utils.instantiate( + # cfg.optimizer, params=self.model.parameters()) # configure training state self.global_step = 0 @@ -106,17 +109,17 @@ def run(self, rank, world_size): dataset: BaseImageDataset dataset = hydra.utils.instantiate(cfg.task.dataset) assert isinstance(dataset, BaseImageDataset) - train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - train_dataloader = DataLoader(dataset, sampler=train_sampler,**cfg.dataloader) + # train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + train_dataloader = DataLoader(dataset, **cfg.dataloader) normalizer = dataset.get_normalizer() # configure validation dataset val_dataset = dataset.get_validation_dataset() val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader) - self.model.set_normalizer(normalizer) + self.model.module.set_normalizer(normalizer) if cfg.training.use_ema: - self.ema_model.set_normalizer(normalizer) + self.ema_model.module.set_normalizer(normalizer) # configure lr scheduler lr_scheduler = get_scheduler( @@ -139,6 +142,8 @@ def run(self, rank, world_size): model=self.ema_model) # configure env + from hydra.core.hydra_config import HydraConfig + print(HydraConfig.get().runtime.output_dir) env_runner: BaseImageRunner env_runner = hydra.utils.instantiate( cfg.task.env_runner, @@ -164,7 +169,7 @@ def run(self, rank, world_size): ) # device transfer - device = torch.device(cfg.training.device) + # device = torch.device(cfg.training.device) self.model.to(device) if self.ema_model is not None: self.ema_model.to(device) @@ -345,9 +350,24 @@ def run(self, rank, world_size): version_base=None, config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), config_name=pathlib.Path(__file__).stem) -def main(cfg): - workspace = TrainDiffusionUnetHybridWorkspace(cfg) - workspace.run() +# def main(cfg): +# workspace = TrainDiffusionUnetHybridWorkspace(cfg) +# workspace.run() + +# if __name__ == "__main__": +# main() + + +def main(cfg: OmegaConf): + # resolve immediately so all the ${now:} resolvers + # will use the same time. 
+ OmegaConf.resolve(cfg) + + cls = hydra.utils.get_class(cfg._target_) + workspace: BaseWorkspace = cls(cfg) + world_size = torch.cuda.device_count() + mp.spawn(workspace.run, args=(world_size,), nprocs=world_size, join=True) + if __name__ == "__main__": - main() + main() \ No newline at end of file From 813693f98967589657bc54019eb2caf512269451 Mon Sep 17 00:00:00 2001 From: TMats Date: Fri, 26 May 2023 20:08:55 +0900 Subject: [PATCH 4/7] fix: ddp train --- .../train_diffusion_unet_hybrid_workspace.py | 108 +++++++++--------- 1 file changed, 57 insertions(+), 51 deletions(-) diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py index 69f1448e..b2168ad1 100644 --- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py +++ b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py @@ -53,6 +53,8 @@ class TrainDiffusionUnetHybridWorkspace(BaseWorkspace): def __init__(self, cfg: OmegaConf, output_dir=None): super().__init__(cfg, output_dir=output_dir) + print(self.output_dir) + self.output_dir_base = self.output_dir # set seed self.seed = cfg.training.seed @@ -142,31 +144,32 @@ def run(self, rank, world_size): model=self.ema_model) # configure env - from hydra.core.hydra_config import HydraConfig - print(HydraConfig.get().runtime.output_dir) - env_runner: BaseImageRunner - env_runner = hydra.utils.instantiate( - cfg.task.env_runner, - output_dir=self.output_dir) - assert isinstance(env_runner, BaseImageRunner) - - # configure logging - wandb_run = wandb.init( - dir=str(self.output_dir), - config=OmegaConf.to_container(cfg, resolve=True), - **cfg.logging - ) - wandb.config.update( - { - "output_dir": self.output_dir, - } - ) - - # configure checkpoint - topk_manager = TopKCheckpointManager( - save_dir=os.path.join(self.output_dir, 'checkpoints'), - **cfg.checkpoint.topk - ) + # from hydra.core.hydra_config import HydraConfig + # print(HydraConfig.get().runtime.output_dir) + if rank == 0: + env_runner: BaseImageRunner + env_runner = hydra.utils.instantiate( + cfg.task.env_runner, + output_dir=self.output_dir_base) + assert isinstance(env_runner, BaseImageRunner) + + # configure logging + wandb_run = wandb.init( + dir=str(self.output_dir_base), + config=OmegaConf.to_container(cfg, resolve=True), + **cfg.logging + ) + wandb.config.update( + { + "output_dir": self.output_dir_base, + } + ) + + # configure checkpoint + topk_manager = TopKCheckpointManager( + save_dir=os.path.join(self.output_dir_base, 'checkpoints'), + **cfg.checkpoint.topk + ) # device transfer # device = torch.device(cfg.training.device) @@ -188,7 +191,7 @@ def run(self, rank, world_size): cfg.training.sample_every = 1 # training loop - log_path = os.path.join(self.output_dir, 'logs.json.txt') + log_path = os.path.join(self.output_dir_base, 'logs.json.txt') with JsonLogger(log_path) as json_logger: for local_epoch_idx in range(cfg.training.num_epochs): step_log = dict() @@ -211,7 +214,7 @@ def run(self, rank, world_size): train_sampling_batch = batch # compute loss - raw_loss = self.model.compute_loss(batch) + raw_loss = self.model.module.compute_loss(batch) loss = raw_loss / cfg.training.gradient_accumulate_every start_time = time.time() @@ -236,29 +239,32 @@ def run(self, rank, world_size): ema.step(self.model) # logging - raw_loss_cpu = raw_loss.item() - tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) - train_losses.append(raw_loss_cpu) - step_log = { - 'train_loss': raw_loss_cpu, - 'global_step': 
self.global_step, - 'epoch': self.epoch, - 'lr': lr_scheduler.get_last_lr()[0] - } - - is_last_batch = (batch_idx == (len(train_dataloader)-1)) - if not is_last_batch: - # log of last step is combined with validation and rollout - wandb_run.log(step_log, step=self.global_step) - json_logger.log(step_log) - self.global_step += 1 - - if (cfg.training.max_train_steps is not None) \ - and batch_idx >= (cfg.training.max_train_steps-1): - break - start_time = time.time() - print("Logger", start_time-end_time) - end_time = time.time() + if rank == 0: + raw_loss_cpu = raw_loss.item() + tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) + train_losses.append(raw_loss_cpu) + step_log = { + 'train_loss': raw_loss_cpu, + 'global_step': self.global_step, + 'epoch': self.epoch, + 'lr': lr_scheduler.get_last_lr()[0] + } + + is_last_batch = (batch_idx == (len(train_dataloader)-1)) + + if not is_last_batch: + # log of last step is combined with validation and rollout + + wandb_run.log(step_log, step=self.global_step) + json_logger.log(step_log) + self.global_step += 1 + + if (cfg.training.max_train_steps is not None) \ + and batch_idx >= (cfg.training.max_train_steps-1): + break + start_time = time.time() + print("Logger", start_time-end_time) + end_time = time.time() # at the end of each epoch # replace train_loss with epoch average @@ -285,7 +291,7 @@ def run(self, rank, world_size): leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: for batch_idx, batch in enumerate(tepoch): batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) - loss = self.model.compute_loss(batch) + loss = self.model.module.compute_loss(batch) val_losses.append(loss) if (cfg.training.max_val_steps is not None) \ and batch_idx >= (cfg.training.max_val_steps-1): From 594665e2d1c44a0ea11e30a25cb3503be7033439 Mon Sep 17 00:00:00 2001 From: TMats Date: Fri, 26 May 2023 22:15:21 +0900 Subject: [PATCH 5/7] fix: output_dir --- diffusion_policy/common/sampler.py | 2 +- .../train_diffusion_unet_hybrid_workspace.py | 176 +++++++++--------- 2 files changed, 92 insertions(+), 86 deletions(-) diff --git a/diffusion_policy/common/sampler.py b/diffusion_policy/common/sampler.py index 4accaa09..ec8a13d8 100644 --- a/diffusion_policy/common/sampler.py +++ b/diffusion_policy/common/sampler.py @@ -109,7 +109,7 @@ def __init__(self, indices = np.zeros((0,4), dtype=np.int64) # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx) - self.indices = indices + self.indices = indices # [:200] self.keys = list(keys) # prevent OmegaConf list performance problem self.sequence_length = sequence_length self.replay_buffer = replay_buffer diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py index b2168ad1..d778d182 100644 --- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py +++ b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py @@ -53,8 +53,8 @@ class TrainDiffusionUnetHybridWorkspace(BaseWorkspace): def __init__(self, cfg: OmegaConf, output_dir=None): super().__init__(cfg, output_dir=output_dir) - print(self.output_dir) - self.output_dir_base = self.output_dir + + self.output_dir_base = super().output_dir # set seed self.seed = cfg.training.seed @@ -76,6 +76,21 @@ def __init__(self, cfg: OmegaConf, output_dir=None): # configure training state self.global_step = 0 self.epoch = 0 + # configure logging + self.wandb_run = wandb.init( + dir=str(self.output_dir_base), + 
config=OmegaConf.to_container(cfg, resolve=True), + **cfg.logging + ) + wandb.config.update( + { + "output_dir": self.output_dir_base, + } + ) + + @property + def output_dir(self): + return self.output_dir_base def run(self, rank, world_size): print("rank:", rank, "world_size:", world_size) @@ -153,17 +168,7 @@ def run(self, rank, world_size): output_dir=self.output_dir_base) assert isinstance(env_runner, BaseImageRunner) - # configure logging - wandb_run = wandb.init( - dir=str(self.output_dir_base), - config=OmegaConf.to_container(cfg, resolve=True), - **cfg.logging - ) - wandb.config.update( - { - "output_dir": self.output_dir_base, - } - ) + # configure checkpoint topk_manager = TopKCheckpointManager( @@ -239,7 +244,7 @@ def run(self, rank, world_size): ema.step(self.model) # logging - if rank == 0: + if True: raw_loss_cpu = raw_loss.item() tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) train_losses.append(raw_loss_cpu) @@ -255,7 +260,7 @@ def run(self, rank, world_size): if not is_last_batch: # log of last step is combined with validation and rollout - wandb_run.log(step_log, step=self.global_step) + self.wandb_run.log(step_log, step=self.global_step) json_logger.log(step_log) self.global_step += 1 @@ -272,81 +277,82 @@ def run(self, rank, world_size): step_log['train_loss'] = train_loss # ========= eval for this epoch ========== - policy = self.model - if cfg.training.use_ema: - policy = self.ema_model - policy.eval() - - # run rollout - if (self.epoch % cfg.training.rollout_every) == 0: - runner_log = env_runner.run(policy) - # log all - step_log.update(runner_log) - - # run validation - if (self.epoch % cfg.training.val_every) == 0: - with torch.no_grad(): - val_losses = list() - with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}", - leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: - for batch_idx, batch in enumerate(tepoch): - batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) - loss = self.model.module.compute_loss(batch) - val_losses.append(loss) - if (cfg.training.max_val_steps is not None) \ - and batch_idx >= (cfg.training.max_val_steps-1): - break - if len(val_losses) > 0: - val_loss = torch.mean(torch.tensor(val_losses)).item() - # log epoch average validation loss - step_log['val_loss'] = val_loss - - # run diffusion sampling on a training batch - if (self.epoch % cfg.training.sample_every) == 0: - with torch.no_grad(): - # sample trajectory from training set, and evaluate difference - batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True)) - obs_dict = batch['obs'] - gt_action = batch['action'] - - result = policy.predict_action(obs_dict) - pred_action = result['action_pred'] - mse = torch.nn.functional.mse_loss(pred_action, gt_action) - step_log['train_action_mse_error'] = mse.item() - del batch - del obs_dict - del gt_action - del result - del pred_action - del mse - - # checkpoint - if (self.epoch % cfg.training.checkpoint_every) == 0: - # checkpointing - if cfg.checkpoint.save_last_ckpt: - self.save_checkpoint() - if cfg.checkpoint.save_last_snapshot: - self.save_snapshot() - - # sanitize metric names - metric_dict = dict() - for key, value in step_log.items(): - new_key = key.replace('/', '_') - metric_dict[new_key] = value + if rank == 0: + policy = self.model.module + if cfg.training.use_ema: + policy = self.ema_model.module + policy.eval() + + # run rollout + # if (self.epoch % cfg.training.rollout_every) == 0: + # runner_log = env_runner.run(policy) + # # log all + # 
step_log.update(runner_log) + + # run validation + if (self.epoch % cfg.training.val_every) == 0: + with torch.no_grad(): + val_losses = list() + with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}", + leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: + for batch_idx, batch in enumerate(tepoch): + batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) + loss = self.model.module.compute_loss(batch) + val_losses.append(loss) + if (cfg.training.max_val_steps is not None) \ + and batch_idx >= (cfg.training.max_val_steps-1): + break + if len(val_losses) > 0: + val_loss = torch.mean(torch.tensor(val_losses)).item() + # log epoch average validation loss + step_log['val_loss'] = val_loss + + # run diffusion sampling on a training batch + if (self.epoch % cfg.training.sample_every) == 0: + with torch.no_grad(): + # sample trajectory from training set, and evaluate difference + batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True)) + obs_dict = batch['obs'] + gt_action = batch['action'] + + result = policy.predict_action(obs_dict) + pred_action = result['action_pred'] + mse = torch.nn.functional.mse_loss(pred_action, gt_action) + step_log['train_action_mse_error'] = mse.item() + del batch + del obs_dict + del gt_action + del result + del pred_action + del mse - # We can't copy the last checkpoint here - # since save_checkpoint uses threads. - # therefore at this point the file might have been empty! - topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) + # checkpoint + if (self.epoch % cfg.training.checkpoint_every) == 0: + # checkpointing + if cfg.checkpoint.save_last_ckpt: + self.save_checkpoint() + if cfg.checkpoint.save_last_snapshot: + self.save_snapshot() + + # sanitize metric names + metric_dict = dict() + for key, value in step_log.items(): + new_key = key.replace('/', '_') + metric_dict[new_key] = value + + # We can't copy the last checkpoint here + # since save_checkpoint uses threads. + # therefore at this point the file might have been empty! 
+ topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) - if topk_ckpt_path is not None: - self.save_checkpoint(path=topk_ckpt_path) - # ========= eval end for this epoch ========== - policy.train() + if topk_ckpt_path is not None: + self.save_checkpoint(path=topk_ckpt_path) + # ========= eval end for this epoch ========== + policy.train() # end of epoch # log of last step is combined with validation and rollout - wandb_run.log(step_log, step=self.global_step) + self.wandb_run.log(step_log, step=self.global_step) json_logger.log(step_log) self.global_step += 1 self.epoch += 1 From 625c1cc91ea22eb66bec1372331533b37b54d7bf Mon Sep 17 00:00:00 2001 From: TMats Date: Fri, 26 May 2023 22:41:22 +0900 Subject: [PATCH 6/7] fix: log interval --- .../train_diffusion_unet_hybrid_workspace.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py index d778d182..68bcd636 100644 --- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py +++ b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py @@ -161,7 +161,7 @@ def run(self, rank, world_size): # configure env # from hydra.core.hydra_config import HydraConfig # print(HydraConfig.get().runtime.output_dir) - if rank == 0: + if False: env_runner: BaseImageRunner env_runner = hydra.utils.instantiate( cfg.task.env_runner, @@ -170,11 +170,11 @@ def run(self, rank, world_size): - # configure checkpoint - topk_manager = TopKCheckpointManager( - save_dir=os.path.join(self.output_dir_base, 'checkpoints'), - **cfg.checkpoint.topk - ) + # configure checkpoint + topk_manager = TopKCheckpointManager( + save_dir=os.path.join(self.output_dir_base, 'checkpoints'), + **cfg.checkpoint.topk + ) # device transfer # device = torch.device(cfg.training.device) @@ -244,7 +244,7 @@ def run(self, rank, world_size): ema.step(self.model) # logging - if True: + if self.global_step % 100 == 0: raw_loss_cpu = raw_loss.item() tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) train_losses.append(raw_loss_cpu) From 911ca3352c45ee056c4065e6e5a7248ddc60aa4a Mon Sep 17 00:00:00 2001 From: TMats Date: Sun, 28 May 2023 15:45:50 +0900 Subject: [PATCH 7/7] fix: tqdm text --- .../workspace/train_diffusion_unet_hybrid_workspace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py index 68bcd636..7304d9bf 100644 --- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py +++ b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py @@ -202,7 +202,7 @@ def run(self, rank, world_size): step_log = dict() # ========= train for this epoch ========== train_losses = list() - with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", + with tqdm.tqdm(train_dataloader, desc=f"GPU {rank}: Training epoch {self.epoch}", leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: import time end_time = time.time()
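
Note on the launch pattern: the patches above move the workspace to DistributedDataParallel training via a per-rank setup/cleanup pair, a DDP-wrapped policy, and an mp.spawn entry point in train.py (PATCH 2 also adds, and PATCH 3 temporarily comments out, a DistributedSampler). The following is a minimal, self-contained sketch of that single-node pattern, assuming one process per GPU; the toy model, toy dataset, and the run_worker name are illustrative placeholders rather than code from this repository, and it initializes the NCCL backend where the patches use "gloo".

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # NCCL is the usual backend for CUDA tensors; the patches initialize "gloo".
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def run_worker(rank, world_size):
    # mp.spawn passes the rank as the first argument, matching workspace.run(rank, world_size).
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}")

    # Toy stand-ins for the policy and dataset; illustrative only.
    model = DDP(torch.nn.Linear(10, 10).to(device), device_ids=[rank])
    dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 10))

    # DistributedSampler shards the dataset so each rank trains on a distinct subset.
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reshuffles consistently across ranks each epoch
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            loss = torch.nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        if rank == 0:
            print(f"epoch {epoch} done, last loss {loss.item():.4f}")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

Hard-coding MASTER_ADDR/MASTER_PORT works for a single node, as here; a common alternative is launching with torchrun and reading the rendezvous settings from the environment.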