From 003cac4758b6c049c15f5af685cca6ce3356bea7 Mon Sep 17 00:00:00 2001
From: Andrii-Sheba
Date: Tue, 12 Nov 2024 17:28:28 +0100
Subject: [PATCH] Forklift 7-keypoint detection on the new MMPose version

---
 .gitignore                                    |   3 +
 configs/_base_/datasets/coco.py               | 150 ++--------
 configs/_base_/default_runtime.py             |   6 +-
 .../td-hm_res50_8xb64-210e_coco-256x192.py    |  96 +++++--
 docker/Dockerfile                             |   2 +
 docker/Dockerfile_aws                         |  36 +++
 .../datasets/transforms/topdown_transforms.py |  19 ++
 run_docker_gpu0.sh                            |   3 +
 tools/test.py                                 |  15 +
 tools/train_grid.py                           | 261 ++++++++++++++++++
 10 files changed, 445 insertions(+), 146 deletions(-)
 create mode 100644 docker/Dockerfile_aws
 create mode 100644 run_docker_gpu0.sh
 create mode 100644 tools/train_grid.py

diff --git a/.gitignore b/.gitignore
index 2b337460f3..cc60fb71a0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,6 @@ docs/**/modelzoo.md
 
 *.pth
 *.DS_Store
+
+data/*
+work_dirs/*
diff --git a/configs/_base_/datasets/coco.py b/configs/_base_/datasets/coco.py
index 865a95bc02..a8514fc52c 100644
--- a/configs/_base_/datasets/coco.py
+++ b/configs/_base_/datasets/coco.py
@@ -12,170 +12,74 @@
     ),
     keypoint_info={
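+        # Forklift keypoint scheme (7 points, replacing COCO's 17): four
+        # chassis corners plus the left, right and center fork tips. C_Fork
+        # lies on the mirror axis, so its horizontal-flip swap stays empty.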
         0:
-        dict(name='nose', id=0, color=[51, 153, 255], type='upper', swap=''),
+        dict(
+            name='rear_left',
+            id=0,
+            color=[51, 153, 255],
+            type='upper',
+            swap='rear_right'),
         1:
         dict(
-            name='left_eye',
+            name='rear_right',
             id=1,
             color=[51, 153, 255],
             type='upper',
-            swap='right_eye'),
+            swap='rear_left'),
         2:
         dict(
-            name='right_eye',
+            name='front_left',
             id=2,
             color=[51, 153, 255],
             type='upper',
-            swap='left_eye'),
+            swap='front_right'),
         3:
         dict(
-            name='left_ear',
+            name='front_right',
             id=3,
             color=[51, 153, 255],
             type='upper',
-            swap='right_ear'),
+            swap='front_left'),
         4:
         dict(
-            name='right_ear',
+            name='L_Fork',
             id=4,
             color=[51, 153, 255],
             type='upper',
-            swap='left_ear'),
+            swap='R_Fork'),
         5:
         dict(
-            name='left_shoulder',
+            name='R_Fork',
             id=5,
             color=[0, 255, 0],
             type='upper',
-            swap='right_shoulder'),
+            swap='L_Fork'),
         6:
         dict(
-            name='right_shoulder',
+            name='C_Fork',
             id=6,
             color=[255, 128, 0],
             type='upper',
-            swap='left_shoulder'),
-        7:
-        dict(
-            name='left_elbow',
-            id=7,
-            color=[0, 255, 0],
-            type='upper',
-            swap='right_elbow'),
-        8:
-        dict(
-            name='right_elbow',
-            id=8,
-            color=[255, 128, 0],
-            type='upper',
-            swap='left_elbow'),
-        9:
-        dict(
-            name='left_wrist',
-            id=9,
-            color=[0, 255, 0],
-            type='upper',
-            swap='right_wrist'),
-        10:
-        dict(
-            name='right_wrist',
-            id=10,
-            color=[255, 128, 0],
-            type='upper',
-            swap='left_wrist'),
-        11:
-        dict(
-            name='left_hip',
-            id=11,
-            color=[0, 255, 0],
-            type='lower',
-            swap='right_hip'),
-        12:
-        dict(
-            name='right_hip',
-            id=12,
-            color=[255, 128, 0],
-            type='lower',
-            swap='left_hip'),
-        13:
-        dict(
-            name='left_knee',
-            id=13,
-            color=[0, 255, 0],
-            type='lower',
-            swap='right_knee'),
-        14:
-        dict(
-            name='right_knee',
-            id=14,
-            color=[255, 128, 0],
-            type='lower',
-            swap='left_knee'),
-        15:
-        dict(
-            name='left_ankle',
-            id=15,
-            color=[0, 255, 0],
-            type='lower',
-            swap='right_ankle'),
-        16:
-        dict(
-            name='right_ankle',
-            id=16,
-            color=[255, 128, 0],
-            type='lower',
-            swap='left_ankle')
+            swap=''),
     },
     skeleton_info={
         0:
-        dict(link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]),
+        dict(link=('rear_left', 'rear_right'), id=0, color=[0, 255, 0]),
         1:
-        dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]),
+        dict(link=('front_left', 'front_right'), id=1, color=[0, 255, 0]),
         2:
-        dict(link=('right_ankle', 'right_knee'), id=2, color=[255, 128, 0]),
+        dict(link=('rear_left', 'front_left'), id=2, color=[0, 255, 0]),
         3:
-        dict(link=('right_knee', 'right_hip'), id=3, color=[255, 128, 0]),
+        dict(link=('rear_right', 'front_right'), id=3, color=[0, 255, 0]),
         4:
-        dict(link=('left_hip', 'right_hip'), id=4, color=[51, 153, 255]),
+        dict(link=('L_Fork', 'R_Fork'), id=4, color=[255, 128, 0]),
         5:
-        dict(link=('left_shoulder', 'left_hip'), id=5, color=[51, 153, 255]),
+        dict(link=('L_Fork', 'C_Fork'), id=5, color=[255, 128, 0]),
         6:
-        dict(link=('right_shoulder', 'right_hip'), id=6, color=[51, 153, 255]),
-        7:
-        dict(
-            link=('left_shoulder', 'right_shoulder'),
-            id=7,
-            color=[51, 153, 255]),
-        8:
-        dict(link=('left_shoulder', 'left_elbow'), id=8, color=[0, 255, 0]),
-        9:
-        dict(
-            link=('right_shoulder', 'right_elbow'), id=9, color=[255, 128, 0]),
-        10:
-        dict(link=('left_elbow', 'left_wrist'), id=10, color=[0, 255, 0]),
-        11:
-        dict(link=('right_elbow', 'right_wrist'), id=11, color=[255, 128, 0]),
-        12:
-        dict(link=('left_eye', 'right_eye'), id=12, color=[51, 153, 255]),
-        13:
-        dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]),
-        14:
-        dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]),
-        15:
-        dict(link=('left_eye', 'left_ear'), id=15, color=[51, 153, 255]),
-        16:
-        dict(link=('right_eye', 'right_ear'), id=16, color=[51, 153, 255]),
-        17:
-        dict(link=('left_ear', 'left_shoulder'), id=17, color=[51, 153, 255]),
-        18:
-        dict(
-            link=('right_ear', 'right_shoulder'), id=18, color=[51, 153, 255])
+        dict(link=('C_Fork', 'R_Fork'), id=6, color=[255, 128, 0]),
     },
     joint_weights=[
-        1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5,
-        1.5
+        1., 1., 1., 1., 1., 1., 1.,
     ],
     sigmas=[
-        0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062,
-        0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089
+        0.05, 0.05, 0.05, 0.05, 0.06, 0.06, 0.07,
     ])
diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py
index 6f27c0345a..c3db7f5266 100644
--- a/configs/_base_/default_runtime.py
+++ b/configs/_base_/default_runtime.py
@@ -5,9 +5,9 @@
     timer=dict(type='IterTimerHook'),
     logger=dict(type='LoggerHook', interval=50),
     param_scheduler=dict(type='ParamSchedulerHook'),
-    checkpoint=dict(type='CheckpointHook', interval=10),
+    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
-    visualization=dict(type='PoseVisualizationHook', enable=False),
+    visualization=dict(type='PoseVisualizationHook', enable=True),
     badcase=dict(
         type='BadCaseAnalysisHook',
         enable=False,
@@ -32,7 +32,7 @@
 # visualizer
 vis_backends = [
     dict(type='LocalVisBackend'),
-    # dict(type='TensorboardVisBackend'),
+    dict(type='TensorboardVisBackend'),
     # dict(type='WandbVisBackend'),
 ]
 visualizer = dict(
diff --git a/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192.py b/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192.py
index 7dbe1b43f7..a814624493 100644
--- a/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192.py
+++ b/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192.py
@@ -1,7 +1,7 @@
 _base_ = ['../../../_base_/default_runtime.py']
 
 # runtime
-train_cfg = dict(max_epochs=210, val_interval=10)
+train_cfg = dict(max_epochs=300, val_interval=10)
 
 # optimizer
 optim_wrapper = dict(optimizer=dict(
@@ -17,17 +17,33 @@
     dict(
         type='MultiStepLR',
         begin=0,
-        end=210,
-        milestones=[170, 200],
+        end=300,
+        milestones=[200, 250],
         gamma=0.1,
         by_epoch=True)
 ]
 
 # automatically scaling LR based on the actual training batch size
-auto_scale_lr = dict(base_batch_size=512)
+auto_scale_lr = dict(base_batch_size=64)
 
 # hooks
-default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
+default_hooks = dict(
+    checkpoint=dict(save_best='coco/AP', rule='greater'),
+)
+
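+# Early stopping tracks '5pr_/PCK': the PCK metric reported by the
+# PCKAccuracy evaluator configured with prefix="5pr_" in this config.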
+custom_hooks = [
+    # dict(type='PCKAccuracyTrainHook', interval=10, thr=0.05),
+    dict(
+        type='EarlyStoppingHook',
+        monitor='5pr_/PCK',
+        rule='greater',
+        min_delta=0.001,
+        patience=20,
+        stopping_threshold=None,
+        strict=False,
+        check_finite=True
+    ),
+]
 
 # codec settings
 codec = dict(
@@ -43,13 +59,13 @@
         bgr_to_rgb=True),
     backbone=dict(
         type='ResNet',
-        depth=50,
-        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
+        depth=18,
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
     ),
     head=dict(
         type='HeatmapHead',
-        in_channels=2048,
-        out_channels=17,
+        in_channels=512,
+        out_channels=7,
         loss=dict(type='KeypointMSELoss', use_target_weight=True),
         decoder=codec),
     test_cfg=dict(
@@ -61,16 +77,39 @@
 # base dataset settings
 dataset_type = 'CocoDataset'
 data_mode = 'topdown'
-data_root = 'data/coco/'
+# data_root = 'data/2144_split_exported_data_project_id_422/'
+# data_root = 'data/2769_split_exported_data_project_id_422/'
+data_root = "data/joined/"
 
 # pipelines
 train_pipeline = [
     dict(type='LoadImage'),
     dict(type='GetBBoxCenterScale'),
     dict(type='RandomFlip', direction='horizontal'),
-    dict(type='RandomHalfBody'),
     dict(type='RandomBBoxTransform'),
     dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(
+        type='Albumentation',
+        transforms=[
+            dict(type='RandomBrightnessContrast', brightness_limit=[-0.4, 0.4], contrast_limit=[-0.4, 0.4], p=0.6),
+
+            dict(
+                type='OneOf',
+                transforms=[
+                    dict(type='MotionBlur', blur_limit=5, p=0.3),
+                    dict(type='MedianBlur', blur_limit=5, p=0.3),
+                    dict(type='Blur', blur_limit=5, p=0.3),
+                ], p=0.4),
+
+            dict(
+                type='OneOf',
+                transforms=[
+                    dict(type='GaussNoise', var_limit=(10.0, 50.0), p=0.4),
+                    dict(type='MultiplicativeNoise', multiplier=(0.9, 1.1), p=0.4),
+                ], p=0.4),
+
+            dict(type='HueSaturationValue', hue_shift_limit=20, sat_shift_limit=20, val_shift_limit=20, p=0.5),
+        ]),
     dict(type='GenerateTarget', encoder=codec),
     dict(type='PackPoseInputs')
 ]
@@ -84,20 +123,20 @@
 # data loaders
 train_dataloader = dict(
     batch_size=64,
-    num_workers=2,
+    num_workers=4,
     persistent_workers=True,
     sampler=dict(type='DefaultSampler', shuffle=True),
     dataset=dict(
         type=dataset_type,
         data_root=data_root,
         data_mode=data_mode,
-        ann_file='annotations/person_keypoints_train2017.json',
+        ann_file='annotations/forklift_keypoints_train2017.json',
         data_prefix=dict(img='train2017/'),
         pipeline=train_pipeline,
     ))
 val_dataloader = dict(
     batch_size=32,
-    num_workers=2,
+    num_workers=4,
     persistent_workers=True,
     drop_last=False,
     sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
@@ -105,9 +144,8 @@
         type=dataset_type,
         data_root=data_root,
         data_mode=data_mode,
-        ann_file='annotations/person_keypoints_val2017.json',
-        bbox_file='data/coco/person_detection_results/'
-        'COCO_val2017_detections_AP_H_56_person.json',
+        ann_file='annotations/forklift_keypoints_train2017.json',
+        bbox_file='',
         data_prefix=dict(img='val2017/'),
         test_mode=True,
         pipeline=val_pipeline,
@@ -115,7 +153,25 @@
 test_dataloader = val_dataloader
 
 # evaluators
-val_evaluator = dict(
-    type='CocoMetric',
-    ann_file=data_root + 'annotations/person_keypoints_val2017.json')
+val_evaluator = [
+    dict(
+        type='CocoMetric',
+        ann_file=data_root + 'annotations/forklift_keypoints_train2017.json'
+    ),
+    dict(
+        type='EPE',
+    ),
+    dict(
+        type='PCKAccuracy',
+        prefix="5pr_",
+    ),
+    dict(
+        type='PCKAccuracy',
+        thr=0.1,
+        prefix="10pr_",
+    ),
+    dict(
+        type='AUC',
+    ),
+]
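+# PCKAccuracy normalizes by bbox size and defaults to thr=0.05, so the two
+# instances above report PCK@0.05 (prefix 5pr_) and PCK@0.1 (prefix 10pr_).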
 test_evaluator = val_evaluator
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 064b803979..1c18fcf2c7 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -32,3 +32,5 @@ RUN git checkout main
 ENV FORCE_CUDA="1"
 RUN pip install -r requirements/build.txt
 RUN pip install --no-cache-dir -e .
+RUN mim install "mmdet>=3.1.0"
+RUN pip install future tensorboard albumentations
diff --git a/docker/Dockerfile_aws b/docker/Dockerfile_aws
new file mode 100644
index 0000000000..069f06d276
--- /dev/null
+++ b/docker/Dockerfile_aws
@@ -0,0 +1,36 @@
+ARG PYTORCH="1.9.0"
+ARG CUDA="11.1"
+ARG CUDNN="8"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX 8.6"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+
+# To fix GPG key error when running apt-get update
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+
+RUN apt-get update && apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install xtcocotools
+RUN pip install cython
+RUN pip install xtcocotools
+
+# Install MMEngine and MMCV
+RUN pip install openmim
+RUN mim install mmengine "mmcv>=2.0.0"
+
+# Install MMPose
+RUN conda clean --all
+RUN git clone https://github.com/open-mmlab/mmpose.git /mmpose
+WORKDIR /mmpose
+RUN git checkout main
+ENV FORCE_CUDA="1"
+RUN pip install -r requirements/build.txt
+RUN pip install --no-cache-dir -e .
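+# Assumption: mmdet serves the top-down detector stage, while tensorboard and
+# albumentations back the TensorboardVisBackend and Albumentation transforms
+# enabled in the updated configs.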
+RUN mim install "mmdet>=3.1.0"
+RUN pip install future tensorboard albumentations
diff --git a/mmpose/datasets/transforms/topdown_transforms.py b/mmpose/datasets/transforms/topdown_transforms.py
index c76d45e46a..ece0504814 100644
--- a/mmpose/datasets/transforms/topdown_transforms.py
+++ b/mmpose/datasets/transforms/topdown_transforms.py
@@ -5,6 +5,7 @@
 import numpy as np
 from mmcv.transforms import BaseTransform
 from mmengine import is_seq_of
+import os.path as osp
 
 from mmpose.registry import TRANSFORMS
 from mmpose.structures.bbox import get_udp_warp_matrix, get_warp_matrix
@@ -51,6 +52,7 @@ def __init__(self,
 
         self.input_size = input_size
         self.use_udp = use_udp
+        self.idx = 0
 
     @staticmethod
     def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float):
@@ -134,6 +136,23 @@ def transform(self, results: Dict) -> Optional[dict]:
         results['input_center'] = center
         results['input_scale'] = scale
 
+        # output_dir = 'augmented_images'
+        # img_filename = f'aug_{self.idx}.jpg'
+        # self.idx += 1
+
+        # img = results['img']
+        # keypoints = results['transformed_keypoints']
+
+        # for obj in keypoints:
+        #     for kp in obj:
+        #         print(kp)
+        #         x, y = kp
+        #         v = 1
+        #         if v > 0:  # Visibility flag
+        #             cv2.circle(img, (int(x), int(y)), 3, (0, 255, 0), -1)
+
+        # cv2.imwrite(osp.join(output_dir, img_filename), img)
+        # print(f'========== {self.idx}')
         return results
 
     def __repr__(self) -> str:
diff --git a/run_docker_gpu0.sh b/run_docker_gpu0.sh
new file mode 100644
index 0000000000..38c839bf70
--- /dev/null
+++ b/run_docker_gpu0.sh
@@ -0,0 +1,3 @@
+docker run --rm --network host -w /data/new_mmpose/mmpose \
+    -v /data:/data \
+    --gpus '"device=0"' --shm-size=8g -it mmpose:1.3.2_cuda
diff --git a/tools/test.py b/tools/test.py
index 12fd6b4423..e7381c9a3a 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -71,6 +71,11 @@ def merge_args(cfg, args):
     cfg.launcher = args.launcher
     cfg.load_from = args.checkpoint
 
+    args.show, args.badcase, args.show_dir, args.dump = False, False, None, None
+    cfg.show, cfg.badcase, cfg.show_dir, cfg.dump = False, False, None, None
+    cfg.default_hooks.badcase.enable = False
+    cfg.default_hooks.visualization.enable = False
+
     # -------------------- work directory --------------------
     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None:
@@ -83,6 +88,7 @@
 
     # -------------------- visualization --------------------
     if (args.show and not args.badcase) or (args.show_dir is not None):
+        print(f"---------------------- visualization {args.show} {args.show_dir}")
         assert 'visualization' in cfg.default_hooks, \
             'PoseVisualizationHook is not set in the ' \
             '`default_hooks` field of config. Please set ' \
@@ -98,6 +104,7 @@
 
     # -------------------- badcase analyze --------------------
     if args.badcase:
+        print(f"---------------------- badcase {args.badcase}")
         assert 'badcase' in cfg.default_hooks, \
             'BadcaseAnalyzeHook is not set in the ' \
             '`default_hooks` field of config. Please set ' \
@@ -123,6 +130,7 @@
 
     # -------------------- Dump predictions --------------------
     if args.dump is not None:
+        print(f"---------------------- dump {args.dump}")
         assert args.dump.endswith(('.pkl', '.pickle')), \
             'The dump file must be a pkl file.'
         dump_metric = dict(type='DumpResults', out_file_path=args.dump)
@@ -135,6 +143,13 @@
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
+    args.show, args.badcase, args.show_dir, args.dump = False, False, None, None
+    cfg.show, cfg.badcase, cfg.show_dir, cfg.dump = False, False, None, None
+    cfg.default_hooks.badcase.enable = False
+    cfg.default_hooks.visualization.enable = False
+    cfg.vis_backends = None
+    del cfg.visualizer.vis_backends[1]
+
     return cfg
 
diff --git a/tools/train_grid.py b/tools/train_grid.py
new file mode 100644
index 0000000000..d8a8206110
--- /dev/null
+++ b/tools/train_grid.py
@@ -0,0 +1,261 @@
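+# Grid-search driver around MMPose's standard train/test tools. Assumed
+# usage (the config path is an example):
+#   python tools/train_grid.py \
+#       configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192.py
+# Each combination trains into its own work_dirs/grid_* folder and appends
+# its test metrics to training_results.csv.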
+import argparse
+import os
+import os.path as osp
+import csv
+import shutil
+import itertools
+import time
+import subprocess
+import glob
+import json
+
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
+from mmengine.utils import mkdir_or_exist
+
+csv_file = 'training_results.csv'
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a pose model')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--resume',
+        nargs='?',
+        type=str,
+        const='auto',
+        help='If specify checkpoint path, resume from it, while if not '
+        'specify, try to auto resume from the latest checkpoint '
+        'in the work directory.')
+    parser.add_argument(
+        '--amp',
+        action='store_true',
+        default=False,
+        help='enable automatic-mixed-precision training')
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help='whether not to evaluate the checkpoint during training')
+    parser.add_argument(
+        '--auto-scale-lr',
+        action='store_true',
+        help='whether to auto scale the learning rate according to the '
+        'actual batch size and the original batch size.')
+    parser.add_argument(
+        '--show-dir',
+        help='directory where the visualization images will be saved.')
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='whether to display the prediction results in a window.')
+    parser.add_argument(
+        '--interval',
+        type=int,
+        default=1,
+        help='visualize per interval samples.')
+    parser.add_argument(
+        '--wait-time',
+        type=float,
+        default=1,
+        help='display time of every window. (second)')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    return args
+
+
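+# training_results.csv doubles as a resume ledger: a run whose 'run_name'
+# row is marked 'completed' is skipped when the sweep is restarted.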
+def check_existing_run(work_dir):
+    if not os.path.exists(csv_file):
+        return False
+    with open(csv_file, mode='r') as file:
+        reader = csv.DictReader(file)
+        for row in reader:
+            if row['run_name'] == work_dir and row['status'] == "completed":
+                return True
+    return False
+
+
+def run_test(work_dir, config, checkpoint):
+    out_file = os.path.join(work_dir, 'out.json')
+
+    command = f"python tools/test.py {config} {checkpoint} --work-dir {work_dir} --out {out_file}"
+    print("-----------", command)
+    result = subprocess.run(command, shell=True, capture_output=True, text=True)
+
+    if result.returncode == 0:
+        try:
+            with open(out_file, 'r') as f:
+                metrics = json.load(f)
+                print(f"Metrics read from {out_file}: {metrics}")
+
+            metrics['run_name'] = work_dir
+            metrics['status'] = "completed"
+
+            file_exists = os.path.isfile(csv_file)
+            with open(csv_file, mode='a', newline='') as csvfile:
+                writer = csv.DictWriter(csvfile, fieldnames=metrics.keys())
+                if not file_exists:
+                    writer.writeheader()  # Write header if the file doesn't exist
+                writer.writerow(metrics)
+
+            print(f"Metrics saved to {csv_file}")
+
+        except FileNotFoundError:
+            print(f"Metrics file {out_file} not found.")
+
+        except json.JSONDecodeError:
+            print(f"Failed to decode JSON from {out_file}.")
+
+    else:
+        print(f"Test failed: {result.stderr}")
+
+
+def find_checkpoint(work_dir):
+    """Find the correct checkpoint file based on naming convention."""
+    # Find any pth files in the work_dir
+    checkpoints = glob.glob(osp.join(work_dir, '*.pth'))
+    if not checkpoints:
+        raise FileNotFoundError(f"No checkpoint found in {work_dir}")
+
+    for checkpoint in checkpoints:
+        if 'best_coco_AP' in checkpoint:
+            return checkpoint
+
+    for checkpoint in checkpoints:
+        if 'latest.pth' in checkpoint:
+            return checkpoint
+
+    return checkpoints[0]
+
+
+def merge_args(cfg, args):
+    """Merge CLI arguments to config."""
+    if args.no_validate:
+        cfg.val_cfg = None
+        cfg.val_dataloader = None
+        cfg.val_evaluator = None
+
+    cfg.launcher = args.launcher
+
+    # work_dir is determined in this priority: CLI > config > default
+    if args.work_dir is not None:
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+
+    if args.amp is True:
+        from mmengine.optim import AmpOptimWrapper
+        cfg.optim_wrapper.type = 'AmpOptimWrapper'
+        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
+
+    if args.resume == 'auto':
+        cfg.resume = True
+        cfg.load_from = None
+    elif args.resume is not None:
+        cfg.resume = True
+        cfg.load_from = args.resume
+
+    if args.auto_scale_lr:
+        cfg.auto_scale_lr.enable = True
+
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    return cfg
+
+
+def main():
+    hyperparam_grid = {
+        'lr': [2e-4, 5e-4],
+        'epochs': [300],  # TODO: set valid
+        'sigma': [1.5, 2],
+        'optimizer': ['AdamW', 'Adam'],
+        'batch_size': [128],
+    }
+
+    step_dict = {
+        5: [2, 3],
+        10: [7, 9],
+        100: [70, 90],
+        200: [140, 180],
+        300: [200, 250],
+    }
+
+    hyperparam_combinations = list(itertools.product(
+        hyperparam_grid['lr'],
+        hyperparam_grid['epochs'],
+        hyperparam_grid['sigma'],
+        hyperparam_grid['optimizer'],
+        hyperparam_grid['batch_size']
+    ))
+
+    print(f"Total combinations: {len(hyperparam_combinations)}")
+
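+    # With the grid above: 2 (lr) x 1 (epochs) x 2 (sigma) x 2 (optimizer)
+    # x 1 (batch_size) = 8 runs.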
+    for lr, epochs, sigma, optimizer, batch_size in hyperparam_combinations:
+        args = parse_args()
+        cfg = Config.fromfile(args.config)
+        cfg = merge_args(cfg, args)
+
+        work_dir = f'work_dirs/grid_lr_{lr}_epochs_{epochs}_sigma_{sigma}_optim_{optimizer}_batch_{batch_size}/'
+        if check_existing_run(work_dir):
+            print(f"Skipping already completed run: {work_dir}")
+            continue
+        print(f"Start run: {work_dir}")
+
+        # Update the config with the hyperparameters
+        cfg.optim_wrapper.optimizer = dict(type=optimizer, lr=lr)
+        cfg.train_cfg.max_epochs = epochs
+        cfg.codec.sigma = sigma  # Gaussian sigma of the heatmap codec
+        cfg.train_dataloader.batch_size = batch_size
+        cfg.param_scheduler[1].milestones = step_dict[epochs]
+        cfg.auto_scale_lr.base_batch_size = batch_size
+
+        # Assign the work directory
+        cfg.work_dir = work_dir
+
+        # Create work directory
+        mkdir_or_exist(osp.abspath(cfg.work_dir))
+
+        # Dump the updated config to the work directory
+        config_dst = osp.join(cfg.work_dir, 'modified_config.py')
+        cfg.dump(config_dst)
+
+        # Set up Runner and train
+        runner = Runner.from_cfg(cfg)
+
+        try:
+            print("Pre-train")
+            runner.train()
+
+            print("Pre-test")
+            checkpoint_path = find_checkpoint(work_dir)
+            run_test(work_dir, config_dst, checkpoint_path)
+            print("Complete step")
+
+        except Exception as e:
+            print(f"Error during training: {e}")
+
+
+if __name__ == '__main__':
+    main()