Skip to content

Commit

Permalink
fix train val splits
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrii-Sheba committed Nov 28, 2024
1 parent df95779 commit 4f6275a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@
# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
# data_root = 'data/2144_split_exported_data_project_id_422/'
# data_root = 'data/2769_split_exported_data_project_id_422/'
data_root = "data/joined/"

# pipelines
Expand Down Expand Up @@ -144,7 +142,7 @@
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/forklift_keypoints_train2017.json',
ann_file='annotations/forklift_keypoints_val2017.json',
bbox_file='',
data_prefix=dict(img='val2017/'),
test_mode=True,
Expand All @@ -156,7 +154,7 @@
val_evaluator = [
dict(
type='CocoMetric',
ann_file=data_root + 'annotations/forklift_keypoints_train2017.json'
ann_file=data_root + 'annotations/forklift_keypoints_val2017.json'
),
dict(
type='EPE',
Expand Down
28 changes: 16 additions & 12 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def parse_args():
help='If specify checkpint path, resume from it, while if not '
'specify, try to auto resume from the latest checkpoint '
'in the work directory.')
parser.add_argument(
'--data-root',
type=str,
help='Root directory for dataset. This will override data_root in the config file.'
)
parser.add_argument(
'--amp',
action='store_true',
Expand Down Expand Up @@ -69,14 +74,6 @@ def parse_args():
# will pass the `--local-rank` parameter to `tools/train.py` instead
# of `--local_rank`.
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

# Add the max-epochs argument
parser.add_argument(
'--max-epochs',
type=int,
default=200,
help='Train for how many epochs.')

args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
Expand All @@ -102,6 +99,17 @@ def merge_args(cfg, args):
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])

# Update data_root
if args.data_root is not None:
cfg.train_dataloader.dataset.data_root = args.data_root
cfg.val_dataloader.dataset.data_root = args.data_root
cfg.test_dataloader.dataset.data_root = args.data_root

# Update evaluator paths if necessary
for evaluator in cfg.val_evaluator:
if 'ann_file' in evaluator:
evaluator['ann_file'] = osp.join(args.data_root, 'annotations/forklift_keypoints_val2017.json')

# enable automatic-mixed-precision training
if args.amp is True:
from mmengine.optim import AmpOptimWrapper, OptimWrapper
Expand Down Expand Up @@ -142,10 +150,6 @@ def merge_args(cfg, args):
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# Set max_epochs from CLI argument
if args.max_epochs is not None:
cfg.train_cfg['max_epochs'] = args.max_epochs

return cfg


Expand Down

0 comments on commit 4f6275a

Please sign in to comment.