diff --git a/config.yaml b/config.yaml
new file mode 100644
index 000000000..48a06f1c5
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,379 @@
+_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
+checkpoint:
+  save_last_ckpt: true
+  save_last_snapshot: false
+  topk:
+    format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
+    k: 5
+    mode: max
+    monitor_key: test_mean_score
+dataloader:
+  batch_size: 64
+  num_workers: 8
+  persistent_workers: false
+  pin_memory: true
+  shuffle: true
+dataset_obs_steps: 2
+ema:
+  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
+  inv_gamma: 1.0
+  max_value: 0.9999
+  min_value: 0.0
+  power: 0.75
+  update_after_step: 0
+exp_name: default
+horizon: 16
+keypoint_visible_rate: 1.0
+logging:
+  group: null
+  id: null
+  mode: online
+  name: 2023.01.03-19.43.07_train_diffusion_unet_hybrid_transport_image
+  project: diffusion_policy_debug
+  resume: true
+  tags:
+  - train_diffusion_unet_hybrid
+  - transport_image
+  - default
+multi_run:
+  run_dir: data/outputs/2023.01.03/19.43.07_train_diffusion_unet_hybrid_transport_image
+  wandb_name_base: 2023.01.03-19.43.07_train_diffusion_unet_hybrid_transport_image
+n_action_steps: 8
+n_latency_steps: 0
+n_obs_steps: 2
+name: train_diffusion_unet_hybrid
+obs_as_global_cond: true
+optimizer:
+  _target_: torch.optim.AdamW
+  betas:
+  - 0.95
+  - 0.999
+  eps: 1.0e-08
+  lr: 0.0001
+  weight_decay: 1.0e-06
+past_action_visible: false
+policy:
+  _target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
+  cond_predict_scale: true
+  crop_shape:
+  - 76
+  - 76
+  diffusion_step_embed_dim: 128
+  down_dims:
+  - 512
+  - 1024
+  - 2048
+  eval_fixed_crop: true
+  horizon: 16
+  kernel_size: 5
+  n_action_steps: 8
+  n_groups: 8
+  n_obs_steps: 2
+  noise_scheduler:
+    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
+    beta_end: 0.02
+    beta_schedule: squaredcos_cap_v2
+    beta_start: 0.0001
+    clip_sample: true
+    num_train_timesteps: 100
+    prediction_type: epsilon
+    variance_type: fixed_small
+  num_inference_steps: 100
+  obs_as_global_cond: true
+  obs_encoder_group_norm: true
+  shape_meta:
+    action:
+      shape:
+      - 20
+    obs:
+      robot0_eef_pos:
+        shape:
+        - 3
+      robot0_eef_quat:
+        shape:
+        - 4
+      robot0_eye_in_hand_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot0_gripper_qpos:
+        shape:
+        - 2
+      robot1_eef_pos:
+        shape:
+        - 3
+      robot1_eef_quat:
+        shape:
+        - 4
+      robot1_eye_in_hand_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot1_gripper_qpos:
+        shape:
+        - 2
+      shouldercamera0_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      shouldercamera1_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+shape_meta:
+  action:
+    shape:
+    - 20
+  obs:
+    robot0_eef_pos:
+      shape:
+      - 3
+    robot0_eef_quat:
+      shape:
+      - 4
+    robot0_eye_in_hand_image:
+      shape:
+      - 3
+      - 84
+      - 84
+      type: rgb
+    robot0_gripper_qpos:
+      shape:
+      - 2
+    robot1_eef_pos:
+      shape:
+      - 3
+    robot1_eef_quat:
+      shape:
+      - 4
+    robot1_eye_in_hand_image:
+      shape:
+      - 3
+      - 84
+      - 84
+      type: rgb
+    robot1_gripper_qpos:
+      shape:
+      - 2
+    shouldercamera0_image:
+      shape:
+      - 3
+      - 84
+      - 84
+      type: rgb
+    shouldercamera1_image:
+      shape:
+      - 3
+      - 84
+      - 84
+      type: rgb
+task:
+  abs_action: true
+  dataset:
+    _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
+    abs_action: true
+    dataset_path: data/robomimic/datasets/transport/mh/image_abs.hdf5
+    horizon: 16
+    n_obs_steps: 2
+    pad_after: 7
+    pad_before: 1
+    rotation_rep: rotation_6d
+    seed: 42
+    shape_meta:
+      action:
+        shape:
+        - 20
+      obs:
+        robot0_eef_pos:
+          shape:
+          - 3
+        robot0_eef_quat:
+          shape:
+          - 4
+        robot0_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_gripper_qpos:
+          shape:
+          - 2
+        robot1_eef_pos:
+          shape:
+          - 3
+        robot1_eef_quat:
+          shape:
+          - 4
+        robot1_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot1_gripper_qpos:
+          shape:
+          - 2
+        shouldercamera0_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        shouldercamera1_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+    use_cache: false
+    val_ratio: 0.02
+  dataset_path: data/robomimic/datasets/transport/mh/image_abs.hdf5
+  dataset_type: mh
+  env_runner:
+    _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner
+    abs_action: true
+    crf: 22
+    dataset_path: data/robomimic/datasets/transport/mh/image_abs.hdf5
+    fps: 10
+    max_steps: 700
+    n_action_steps: 8
+    n_envs: 28
+    n_obs_steps: 2
+    n_test: 50
+    n_test_vis: 4
+    n_train: 6
+    n_train_vis: 2
+    past_action: false
+    render_obs_key: shouldercamera0_image
+    shape_meta:
+      action:
+        shape:
+        - 20
+      obs:
+        robot0_eef_pos:
+          shape:
+          - 3
+        robot0_eef_quat:
+          shape:
+          - 4
+        robot0_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_gripper_qpos:
+          shape:
+          - 2
+        robot1_eef_pos:
+          shape:
+          - 3
+        robot1_eef_quat:
+          shape:
+          - 4
+        robot1_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot1_gripper_qpos:
+          shape:
+          - 2
+        shouldercamera0_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        shouldercamera1_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+    test_start_seed: 100000
+    tqdm_interval_sec: 1.0
+    train_start_idx: 0
+  name: transport_image
+  shape_meta:
+    action:
+      shape:
+      - 20
+    obs:
+      robot0_eef_pos:
+        shape:
+        - 3
+      robot0_eef_quat:
+        shape:
+        - 4
+      robot0_eye_in_hand_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot0_gripper_qpos:
+        shape:
+        - 2
+      robot1_eef_pos:
+        shape:
+        - 3
+      robot1_eef_quat:
+        shape:
+        - 4
+      robot1_eye_in_hand_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot1_gripper_qpos:
+        shape:
+        - 2
+      shouldercamera0_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      shouldercamera1_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+  task_name: transport
+task_name: transport_image
+training:
+  checkpoint_every: 50
+  debug: false
+  device: cuda:0
+  gradient_accumulate_every: 1
+  lr_scheduler: cosine
+  lr_warmup_steps: 500
+  max_train_steps: null
+  max_val_steps: null
+  num_epochs: 3000
+  resume: true
+  rollout_every: 50
+  sample_every: 5
+  seed: 42
+  tqdm_interval_sec: 1.0
+  use_ema: true
+  val_every: 1
+val_dataloader:
+  batch_size: 64
+  num_workers: 8
+  persistent_workers: false
+  pin_memory: true
+  shuffle: false
diff --git a/diffusion_policy/common/replay_buffer.py b/diffusion_policy/common/replay_buffer.py
index 022a704ed..b5decba3d 100644
--- a/diffusion_policy/common/replay_buffer.py
+++ b/diffusion_policy/common/replay_buffer.py
@@ -586,3 +586,19 @@ def set_compressors(self, compressors: dict):
             compressor = self.resolve_compressor(value)
             if compressor != arr.compressor:
                 rechunk_recompress_array(self.data, key, compressor=compressor)
+
+
+class MemoryReplayBuffer:
+    def __init__(self):
+        self.data = {}
+        self.meta = {}
+        self.n_episodes = 0
+
+    @property
+    def episode_ends(self):
+        return self.meta['episode_ends']
+
+    def keys(self):
+        return self.data.keys()
+    def __getitem__(self, key):
+        return self.data[key]
\ No newline at end of file
diff --git a/diffusion_policy/common/sampler.py b/diffusion_policy/common/sampler.py
index 4accaa093..ec8a13d84 100644
--- a/diffusion_policy/common/sampler.py
+++ b/diffusion_policy/common/sampler.py
@@ -109,7 +109,7 @@ def __init__(self,
             indices = np.zeros((0,4), dtype=np.int64)
 
         # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
-        self.indices = indices
+        self.indices = indices # [:200]
         self.keys = list(keys) # prevent OmegaConf list performance problem
         self.sequence_length = sequence_length
         self.replay_buffer = replay_buffer
diff --git a/diffusion_policy/config/task/can_image.yaml b/diffusion_policy/config/task/can_image.yaml
index 10158bbd0..bc10d3dc7 100644
--- a/diffusion_policy/config/task/can_image.yaml
+++ b/diffusion_policy/config/task/can_image.yaml
@@ -59,6 +59,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/can_image_abs.yaml b/diffusion_policy/config/task/can_image_abs.yaml
index eee34dc63..1f824c342 100644
--- a/diffusion_policy/config/task/can_image_abs.yaml
+++ b/diffusion_policy/config/task/can_image_abs.yaml
@@ -59,6 +59,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/lift_image.yaml b/diffusion_policy/config/task/lift_image.yaml
index 8ddde4567..2f5734ee2 100644
--- a/diffusion_policy/config/task/lift_image.yaml
+++ b/diffusion_policy/config/task/lift_image.yaml
@@ -59,6 +59,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/lift_image_abs.yaml b/diffusion_policy/config/task/lift_image_abs.yaml
index b25002f38..9a29fac7c 100644
--- a/diffusion_policy/config/task/lift_image_abs.yaml
+++ b/diffusion_policy/config/task/lift_image_abs.yaml
@@ -58,6 +58,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/real_pusht_image.yaml b/diffusion_policy/config/task/real_pusht_image.yaml
index a3f7c3f08..41eb2094c 100644
--- a/diffusion_policy/config/task/real_pusht_image.yaml
+++ b/diffusion_policy/config/task/real_pusht_image.yaml
@@ -39,7 +39,7 @@ dataset:
   pad_after: ${eval:'${n_action_steps}-1'}
   n_obs_steps: ${dataset_obs_steps}
   n_latency_steps: ${n_latency_steps}
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.00
   max_train_episodes: null
diff --git a/diffusion_policy/config/task/square_image.yaml b/diffusion_policy/config/task/square_image.yaml
index 827151be6..03e5edcd2 100644
--- a/diffusion_policy/config/task/square_image.yaml
+++ b/diffusion_policy/config/task/square_image.yaml
@@ -59,6 +59,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/square_image_abs.yaml b/diffusion_policy/config/task/square_image_abs.yaml
index d27a916af..53c9ebb59 100644
--- a/diffusion_policy/config/task/square_image_abs.yaml
+++ b/diffusion_policy/config/task/square_image_abs.yaml
@@ -59,6 +59,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/tool_hang_image.yaml b/diffusion_policy/config/task/tool_hang_image.yaml
index 8a8b298f4..6e046758c 100644
--- a/diffusion_policy/config/task/tool_hang_image.yaml
+++ b/diffusion_policy/config/task/tool_hang_image.yaml
@@ -58,6 +58,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/tool_hang_image_abs.yaml b/diffusion_policy/config/task/tool_hang_image_abs.yaml
index 068bdc93a..92fa8e630 100644
--- a/diffusion_policy/config/task/tool_hang_image_abs.yaml
+++ b/diffusion_policy/config/task/tool_hang_image_abs.yaml
@@ -58,6 +58,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/transport_image.yaml b/diffusion_policy/config/task/transport_image.yaml
index c9c24cc05..b1076e820 100644
--- a/diffusion_policy/config/task/transport_image.yaml
+++ b/diffusion_policy/config/task/transport_image.yaml
@@ -70,6 +70,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/config/task/transport_image_abs.yaml b/diffusion_policy/config/task/transport_image_abs.yaml
index 74e76c602..3fe6f4da4 100644
--- a/diffusion_policy/config/task/transport_image_abs.yaml
+++ b/diffusion_policy/config/task/transport_image_abs.yaml
@@ -70,6 +70,6 @@ dataset:
   abs_action: *abs_action
   rotation_rep: 'rotation_6d'
   use_legacy_normalizer: False
-  use_cache: True
+  use_cache: false
   seed: 42
   val_ratio: 0.02
diff --git a/diffusion_policy/dataset/robomimic_replay_image_dataset.py b/diffusion_policy/dataset/robomimic_replay_image_dataset.py
index 2728e9e9e..3c7c4de27 100644
--- a/diffusion_policy/dataset/robomimic_replay_image_dataset.py
+++ b/diffusion_policy/dataset/robomimic_replay_image_dataset.py
@@ -19,7 +19,7 @@
 from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
 from diffusion_policy.model.common.rotation_transformer import RotationTransformer
 from diffusion_policy.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
-from diffusion_policy.common.replay_buffer import ReplayBuffer
+from diffusion_policy.common.replay_buffer import ReplayBuffer, MemoryReplayBuffer
 from diffusion_policy.common.sampler import SequenceSampler, get_val_mask
 from diffusion_policy.common.normalize_util import (
     robomimic_abs_action_only_normalizer_from_stat,
@@ -81,7 +81,7 @@ def __init__(self,
                     src_store=zip_store, store=zarr.MemoryStore())
             print('Loaded!')
         else:
-            replay_buffer = _convert_robomimic_to_replay(
+            replay_buffer = _convert_robomimic_to_memory_replay(
                 store=zarr.MemoryStore(),
                 shape_meta=shape_meta,
                 dataset_path=dataset_path,
@@ -362,6 +362,95 @@ def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
     replay_buffer = ReplayBuffer(root)
     return replay_buffer
 
+
+def _convert_robomimic_to_memory_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
+        n_workers=None, max_inflight_tasks=None):
+    if n_workers is None:
+        n_workers = multiprocessing.cpu_count()
+    if max_inflight_tasks is None:
+        max_inflight_tasks = n_workers * 5
+
+    # parse shape_meta
+    rgb_keys = list()
+    lowdim_keys = list()
+    # construct compressors and chunks
+    obs_shape_meta = shape_meta['obs']
+    for key, attr in obs_shape_meta.items():
+        shape = attr['shape']
+        type = attr.get('type', 'low_dim')
+        if type == 'rgb':
+            rgb_keys.append(key)
+        elif type == 'low_dim':
+            lowdim_keys.append(key)
+
+    root = MemoryReplayBuffer()
+
+    with h5py.File(dataset_path) as file:
+        # count total steps
+        demos = file['data']
+        episode_ends = list()
+        prev_end = 0
+        for i in range(len(demos)):
+            demo = demos[f'demo_{i}']
+            episode_length = demo['actions'].shape[0]
+            episode_end = prev_end + episode_length
+            prev_end = episode_end
+            episode_ends.append(episode_end)
+        n_steps = episode_ends[-1]
+        episode_starts = [0] + episode_ends[:-1]
+        root.meta['episode_ends'] = np.array(episode_ends, dtype=np.int64)
+        root.n_episodes = len(episode_ends)
+
+        # save lowdim data
+        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
+            data_key = 'obs/' + key
+            if key == 'action':
+                data_key = 'actions'
+            this_data = list()
+            for i in range(len(demos)):
+                demo = demos[f'demo_{i}']
+                this_data.append(demo[data_key][:].astype(np.float32))
+            this_data = np.concatenate(this_data, axis=0)
+            if key == 'action':
+                this_data = _convert_actions(
+                    raw_actions=this_data,
+                    abs_action=abs_action,
+                    rotation_transformer=rotation_transformer
+                )
+                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
+            else:
+                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
+            root.data[key] = np.array(this_data, dtype=this_data.dtype).reshape(this_data.shape)
+            # root.data[key] = torch.tensor(np.array(this_data, dtype=this_data.dtype).reshape(this_data.shape), dtype=torch.uint8)
+
+        def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
+            try:
+                zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
+                # make sure we can successfully decode
+                _ = zarr_arr[zarr_idx]
+                return True
+            except Exception as e:
+                return False
+        # save img data
+        for key in tqdm(rgb_keys, desc="Loading image data"):
+            data_key = 'obs/' + key
+            hdf5_arr = list()
+            for episode_idx in range(len(demos)):
+                demo = demos[f'demo_{episode_idx}']
+                # print(demo['obs'][key][:].astype(np.uint8).shape)
+                # print(demo['obs'][key][:].astype(np.uint8))
+                shape = tuple(shape_meta['obs'][key]['shape'])
+                c,h,w = shape
+                hdf5_arr.append(demo['obs'][key][:].astype(np.uint8).reshape(-1,h,w,c))
+
+            hdf5_arr = np.concatenate(hdf5_arr, axis=0)
+            root.data[key] = np.array(hdf5_arr, dtype=hdf5_arr.dtype).reshape(hdf5_arr.shape)
+            # root.data[key] = torch.tensor(np.array(hdf5_arr, dtype=hdf5_arr.dtype).reshape(hdf5_arr.shape), dtype=torch.uint8)
+    replay_buffer = root
+    return replay_buffer
+
+
 def normalizer_from_stat(stat):
     max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
     scale = np.full_like(stat['max'], fill_value=1/max_abs)
diff --git a/diffusion_policy/workspace/base_workspace.py b/diffusion_policy/workspace/base_workspace.py
index 1352404a3..88e20c0b4 100644
--- a/diffusion_policy/workspace/base_workspace.py
+++ b/diffusion_policy/workspace/base_workspace.py
@@ -9,6 +9,10 @@
 import torch
 import threading
 
+# @hydra.main(
+#     version_base=None,
+#     config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
+#     config_name=pathlib.Path(__file__).stem)
 class BaseWorkspace:
     include_keys = tuple()
diff --git a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py
index 9219427ca..7304d9bf5 100644
--- a/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py
+++ b/diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py
@@ -31,20 +31,82 @@
 OmegaConf.register_new_resolver("eval", eval, replace=True)
 
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.multiprocessing as mp
+
+
+def setup(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+def cleanup():
+    dist.destroy_process_group()
+
+
 class TrainDiffusionUnetHybridWorkspace(BaseWorkspace):
     include_keys = ['global_step', 'epoch']
 
     def __init__(self, cfg: OmegaConf, output_dir=None):
         super().__init__(cfg, output_dir=output_dir)
+
+        self.output_dir_base = super().output_dir
 
         # set seed
-        seed = cfg.training.seed
+        self.seed = cfg.training.seed
+        torch.manual_seed(self.seed)
+        np.random.seed(self.seed)
+        random.seed(self.seed)
+
+        # # configure model
+        # self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy)
+
+        # self.ema_model: DiffusionUnetHybridImagePolicy = None
+        # if cfg.training.use_ema:
+        #     self.ema_model = copy.deepcopy(self.model)
+
+        # # configure training state
+        # self.optimizer = hydra.utils.instantiate(
+        #     cfg.optimizer, params=self.model.parameters())
+
+        # configure training state
+        self.global_step = 0
+        self.epoch = 0
+
+        # configure logging
+        self.wandb_run = wandb.init(
+            dir=str(self.output_dir_base),
+            config=OmegaConf.to_container(cfg, resolve=True),
+            **cfg.logging
+        )
+        wandb.config.update(
+            {
+                "output_dir": self.output_dir_base,
+            }
+        )
+
+    @property
+    def output_dir(self):
+        return self.output_dir_base
+
+    def run(self, rank, world_size):
+        print("rank:", rank, "world_size:", world_size)
+        setup(rank, world_size)
+        seed = self.seed + dist.get_rank()
         torch.manual_seed(seed)
         np.random.seed(seed)
         random.seed(seed)
+
+        device = f"cuda:{rank}"
+        cfg = copy.deepcopy(self.cfg)
 
         # configure model
         self.model: DiffusionUnetHybridImagePolicy = hydra.utils.instantiate(cfg.policy)
+        # configure the model for DDP
+        self.model = DDP(self.model)
 
         self.ema_model: DiffusionUnetHybridImagePolicy = None
         if cfg.training.use_ema:
@@ -53,14 +115,6 @@ def __init__(self, cfg: OmegaConf, output_dir=None):
         # configure training state
         self.optimizer = hydra.utils.instantiate(
             cfg.optimizer, params=self.model.parameters())
-
-        # configure training state
-        self.global_step = 0
-        self.epoch = 0
-
-    def run(self):
-        cfg = copy.deepcopy(self.cfg)
-
         # resume training
         if cfg.training.resume:
             lastest_ckpt_path = self.get_checkpoint_path()
@@ -72,6 +126,7 @@ def run(self):
         dataset: BaseImageDataset
         dataset = hydra.utils.instantiate(cfg.task.dataset)
         assert isinstance(dataset, BaseImageDataset)
+        # train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         train_dataloader = DataLoader(dataset, **cfg.dataloader)
         normalizer = dataset.get_normalizer()
 
@@ -79,9 +134,9 @@ def run(self):
         val_dataset = dataset.get_validation_dataset()
         val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)
 
-        self.model.set_normalizer(normalizer)
+        self.model.module.set_normalizer(normalizer)
         if cfg.training.use_ema:
-            self.ema_model.set_normalizer(normalizer)
+            self.ema_model.module.set_normalizer(normalizer)
 
         # configure lr scheduler
         lr_scheduler = get_scheduler(
@@ -104,32 +159,25 @@ def run(self):
             model=self.ema_model)
 
         # configure env
-        env_runner: BaseImageRunner
-        env_runner = hydra.utils.instantiate(
-            cfg.task.env_runner,
-            output_dir=self.output_dir)
-        assert isinstance(env_runner, BaseImageRunner)
+        # from hydra.core.hydra_config import HydraConfig
+        # print(HydraConfig.get().runtime.output_dir)
+        if False:
+            env_runner: BaseImageRunner
+            env_runner = hydra.utils.instantiate(
+                cfg.task.env_runner,
+                output_dir=self.output_dir_base)
+            assert isinstance(env_runner, BaseImageRunner)
 
-        # configure logging
-        wandb_run = wandb.init(
-            dir=str(self.output_dir),
-            config=OmegaConf.to_container(cfg, resolve=True),
-            **cfg.logging
-        )
-        wandb.config.update(
-            {
-                "output_dir": self.output_dir,
-            }
-        )
+
 
         # configure checkpoint
         topk_manager = TopKCheckpointManager(
-            save_dir=os.path.join(self.output_dir, 'checkpoints'),
+            save_dir=os.path.join(self.output_dir_base, 'checkpoints'),
             **cfg.checkpoint.topk
         )
 
         # device transfer
-        device = torch.device(cfg.training.device)
+        # device = torch.device(cfg.training.device)
         self.model.to(device)
         if self.ema_model is not None:
             self.ema_model.to(device)
@@ -148,56 +196,80 @@ def run(self):
             cfg.training.sample_every = 1
 
         # training loop
-        log_path = os.path.join(self.output_dir, 'logs.json.txt')
+        log_path = os.path.join(self.output_dir_base, 'logs.json.txt')
         with JsonLogger(log_path) as json_logger:
             for local_epoch_idx in range(cfg.training.num_epochs):
                 step_log = dict()
                 # ========= train for this epoch ==========
                 train_losses = list()
-                with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
+                with tqdm.tqdm(train_dataloader, desc=f"GPU {rank}: Training epoch {self.epoch}",
                         leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
+                    import time
+                    end_time = time.time()
                     for batch_idx, batch in enumerate(tepoch):
+                        start_time = time.time()
+                        print("Batch", start_time-end_time)
+                        end_time = time.time()
                         # device transfer
                         batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
+                        start_time = time.time()
+                        print("GPU", start_time-end_time)
+                        end_time = time.time()
                         if train_sampling_batch is None:
                             train_sampling_batch = batch
 
                         # compute loss
-                        raw_loss = self.model.compute_loss(batch)
+                        raw_loss = self.model.module.compute_loss(batch)
+
                         loss = raw_loss / cfg.training.gradient_accumulate_every
+                        start_time = time.time()
+                        print("Loss", start_time-end_time)
+                        end_time = time.time()
+
                         loss.backward()
+                        start_time = time.time()
+                        print("Backward", start_time-end_time)
+                        end_time = time.time()
 
                         # step optimizer
                         if self.global_step % cfg.training.gradient_accumulate_every == 0:
                             self.optimizer.step()
                             self.optimizer.zero_grad()
                             lr_scheduler.step()
-
+                        start_time = time.time()
+                        print("Optimizer", start_time-end_time)
+                        end_time = time.time()
+
                         # update ema
                         if cfg.training.use_ema:
                             ema.step(self.model)
 
                         # logging
-                        raw_loss_cpu = raw_loss.item()
-                        tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
-                        train_losses.append(raw_loss_cpu)
-                        step_log = {
-                            'train_loss': raw_loss_cpu,
-                            'global_step': self.global_step,
-                            'epoch': self.epoch,
-                            'lr': lr_scheduler.get_last_lr()[0]
-                        }
-
-                        is_last_batch = (batch_idx == (len(train_dataloader)-1))
-                        if not is_last_batch:
-                            # log of last step is combined with validation and rollout
-                            wandb_run.log(step_log, step=self.global_step)
-                            json_logger.log(step_log)
-                            self.global_step += 1
-
-                        if (cfg.training.max_train_steps is not None) \
-                            and batch_idx >= (cfg.training.max_train_steps-1):
-                            break
+                        if self.global_step % 100 == 0:
+                            raw_loss_cpu = raw_loss.item()
+                            tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
+                            train_losses.append(raw_loss_cpu)
+                            step_log = {
+                                'train_loss': raw_loss_cpu,
+                                'global_step': self.global_step,
+                                'epoch': self.epoch,
+                                'lr': lr_scheduler.get_last_lr()[0]
+                            }
+
+                            is_last_batch = (batch_idx == (len(train_dataloader)-1))
+
+                            if not is_last_batch:
+                                # log of last step is combined with validation and rollout
+
+                                self.wandb_run.log(step_log, step=self.global_step)
+                                json_logger.log(step_log)
+                                self.global_step += 1
+
+                            if (cfg.training.max_train_steps is not None) \
+                                and batch_idx >= (cfg.training.max_train_steps-1):
+                                break
+                        start_time = time.time()
+                        print("Logger", start_time-end_time)
+                        end_time = time.time()
 
                 # at the end of each epoch
                 # replace train_loss with epoch average
@@ -205,92 +277,109 @@ def run(self):
                 step_log['train_loss'] = train_loss
 
                 # ========= eval for this epoch ==========
-                policy = self.model
-                if cfg.training.use_ema:
-                    policy = self.ema_model
-                policy.eval()
-
-                # run rollout
-                if (self.epoch % cfg.training.rollout_every) == 0:
-                    runner_log = env_runner.run(policy)
-                    # log all
-                    step_log.update(runner_log)
-
-                # run validation
-                if (self.epoch % cfg.training.val_every) == 0:
-                    with torch.no_grad():
-                        val_losses = list()
-                        with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
-                                leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
-                            for batch_idx, batch in enumerate(tepoch):
-                                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
-                                loss = self.model.compute_loss(batch)
-                                val_losses.append(loss)
-                                if (cfg.training.max_val_steps is not None) \
-                                    and batch_idx >= (cfg.training.max_val_steps-1):
-                                    break
-                        if len(val_losses) > 0:
-                            val_loss = torch.mean(torch.tensor(val_losses)).item()
-                            # log epoch average validation loss
-                            step_log['val_loss'] = val_loss
-
-                # run diffusion sampling on a training batch
-                if (self.epoch % cfg.training.sample_every) == 0:
-                    with torch.no_grad():
-                        # sample trajectory from training set, and evaluate difference
-                        batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
-                        obs_dict = batch['obs']
-                        gt_action = batch['action']
-
-                        result = policy.predict_action(obs_dict)
-                        pred_action = result['action_pred']
-                        mse = torch.nn.functional.mse_loss(pred_action, gt_action)
-                        step_log['train_action_mse_error'] = mse.item()
-                        del batch
-                        del obs_dict
-                        del gt_action
-                        del result
-                        del pred_action
-                        del mse
-
-                # checkpoint
-                if (self.epoch % cfg.training.checkpoint_every) == 0:
-                    # checkpointing
-                    if cfg.checkpoint.save_last_ckpt:
-                        self.save_checkpoint()
-                    if cfg.checkpoint.save_last_snapshot:
-                        self.save_snapshot()
-
-                    # sanitize metric names
-                    metric_dict = dict()
-                    for key, value in step_log.items():
-                        new_key = key.replace('/', '_')
-                        metric_dict[new_key] = value
+                if rank == 0:
+                    policy = self.model.module
+                    if cfg.training.use_ema:
+                        policy = self.ema_model.module
+                    policy.eval()
+
+                    # run rollout
+                    # if (self.epoch % cfg.training.rollout_every) == 0:
+                    #     runner_log = env_runner.run(policy)
+                    #     # log all
+                    #     step_log.update(runner_log)
+
+                    # run validation
+                    if (self.epoch % cfg.training.val_every) == 0:
+                        with torch.no_grad():
+                            val_losses = list()
+                            with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
                                    leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
+                                for batch_idx, batch in enumerate(tepoch):
+                                    batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
+                                    loss = self.model.module.compute_loss(batch)
+                                    val_losses.append(loss)
+                                    if (cfg.training.max_val_steps is not None) \
+                                        and batch_idx >= (cfg.training.max_val_steps-1):
+                                        break
+                            if len(val_losses) > 0:
+                                val_loss = torch.mean(torch.tensor(val_losses)).item()
+                                # log epoch average validation loss
+                                step_log['val_loss'] = val_loss
+
+                    # run diffusion sampling on a training batch
+                    if (self.epoch % cfg.training.sample_every) == 0:
+                        with torch.no_grad():
+                            # sample trajectory from training set, and evaluate difference
+                            batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
+                            obs_dict = batch['obs']
+                            gt_action = batch['action']
+
+                            result = policy.predict_action(obs_dict)
+                            pred_action = result['action_pred']
+                            mse = torch.nn.functional.mse_loss(pred_action, gt_action)
+                            step_log['train_action_mse_error'] = mse.item()
+                            del batch
+                            del obs_dict
+                            del gt_action
+                            del result
+                            del pred_action
+                            del mse
 
-                    # We can't copy the last checkpoint here
-                    # since save_checkpoint uses threads.
-                    # therefore at this point the file might have been empty!
-                    topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
+                    # checkpoint
+                    if (self.epoch % cfg.training.checkpoint_every) == 0:
+                        # checkpointing
+                        if cfg.checkpoint.save_last_ckpt:
+                            self.save_checkpoint()
+                        if cfg.checkpoint.save_last_snapshot:
+                            self.save_snapshot()
+
+                        # sanitize metric names
+                        metric_dict = dict()
+                        for key, value in step_log.items():
+                            new_key = key.replace('/', '_')
+                            metric_dict[new_key] = value
+
+                        # We can't copy the last checkpoint here
+                        # since save_checkpoint uses threads.
+                        # therefore at this point the file might have been empty!
+                        topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
 
-                    if topk_ckpt_path is not None:
-                        self.save_checkpoint(path=topk_ckpt_path)
-                # ========= eval end for this epoch ==========
-                policy.train()
+                        if topk_ckpt_path is not None:
+                            self.save_checkpoint(path=topk_ckpt_path)
+                    # ========= eval end for this epoch ==========
+                    policy.train()
 
                 # end of epoch
                 # log of last step is combined with validation and rollout
-                wandb_run.log(step_log, step=self.global_step)
+                self.wandb_run.log(step_log, step=self.global_step)
                 json_logger.log(step_log)
                 self.global_step += 1
                 self.epoch += 1
+        cleanup()
 
 @hydra.main(
     version_base=None,
     config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
     config_name=pathlib.Path(__file__).stem)
-def main(cfg):
-    workspace = TrainDiffusionUnetHybridWorkspace(cfg)
-    workspace.run()
+# def main(cfg):
+#     workspace = TrainDiffusionUnetHybridWorkspace(cfg)
+#     workspace.run()
+
+# if __name__ == "__main__":
+#     main()
+
+
+def main(cfg: OmegaConf):
+    # resolve immediately so all the ${now:} resolvers
+    # will use the same time.
+    OmegaConf.resolve(cfg)
+
+    cls = hydra.utils.get_class(cfg._target_)
+    workspace: BaseWorkspace = cls(cfg)
+    world_size = torch.cuda.device_count()
+    mp.spawn(workspace.run, args=(world_size,), nprocs=world_size, join=True)
+
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/train.py b/train.py
index 51c564e6f..6ca730905 100644
--- a/train.py
+++ b/train.py
@@ -13,6 +13,8 @@
 from omegaconf import OmegaConf
 import pathlib
 from diffusion_policy.workspace.base_workspace import BaseWorkspace
+import torch
+import torch.multiprocessing as mp
 
 # allows arbitrary python code execution in configs using the ${eval:''} resolver
 OmegaConf.register_new_resolver("eval", eval, replace=True)
@@ -22,6 +24,15 @@
     config_path=str(pathlib.Path(__file__).parent.joinpath(
         'diffusion_policy','config'))
 )
+# def main(cfg: OmegaConf):
+#     # resolve immediately so all the ${now:} resolvers
+#     # will use the same time.
+#     OmegaConf.resolve(cfg)
+
+#     cls = hydra.utils.get_class(cfg._target_)
+#     workspace: BaseWorkspace = cls(cfg)
+#     workspace.run()
+
 def main(cfg: OmegaConf):
     # resolve immediately so all the ${now:} resolvers
     # will use the same time.
@@ -29,7 +40,9 @@ def main(cfg: OmegaConf):
 
     cls = hydra.utils.get_class(cfg._target_)
     workspace: BaseWorkspace = cls(cfg)
-    workspace.run()
+    world_size = torch.cuda.device_count()
+    mp.spawn(workspace.run, args=(world_size,), nprocs=world_size, join=True)
+
 
 if __name__ == "__main__":
     main()
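
Note: the DDP wiring introduced by this patch follows the standard one-process-per-GPU recipe: `mp.spawn` launches `world_size` workers, each worker calls `dist.init_process_group` (the `setup()` helper above), wraps the policy in `DistributedDataParallel`, trains, and finally calls `dist.destroy_process_group` (`cleanup()`). For reference, below is a minimal self-contained sketch of that same pattern, not the patch's code: it uses a toy `nn.Linear` in place of the diffusion policy and the `nccl` backend instead of the `gloo` backend initialized above (`gloo` works, but is generally much slower for CUDA tensors). It also includes the `DistributedSampler` that the patch leaves commented out; without a sampler, every rank iterates over the full dataset each epoch.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def train(rank, world_size):
    # rendezvous over localhost, one process per GPU (mirrors setup() above)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    device = f'cuda:{rank}'
    model = torch.nn.Linear(10, 1).to(device)    # toy stand-in for the policy
    ddp_model = DDP(model, device_ids=[rank])    # grads all-reduce on backward()

    # DistributedSampler shards the dataset across ranks so each process
    # sees a disjoint 1/world_size slice per epoch
    dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    opt = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)
    for epoch in range(2):
        sampler.set_epoch(epoch)                 # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = torch.nn.functional.mse_loss(ddp_model(x), y)
            loss.backward()                      # gradient sync happens here
            opt.step()
            opt.zero_grad()

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)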