Skip to content

Commit

Permalink
Fixes issue #1: config key error
Browse files Browse the repository at this point in the history
  • Loading branch information
taugeren committed Oct 11, 2024
1 parent 795de0a commit b3ea477
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
34 changes: 20 additions & 14 deletions mmrotate/models/dense_heads/cpm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,28 @@ def __init__(self,
self.cls_weight = 20
self.thresh1 = 8
self.alpha = 1
if kwargs.get('train_cfg')['store_dir'] is not None:
self.store_dir = kwargs.get('train_cfg')['store_dir']
elif kwargs.get('test_cfg')['store_dir'] is not None:
self.store_dir = kwargs.get('test_cfg')['store_dir']

train_cfg = kwargs.get('train_cfg', {})
test_cfg = kwargs.get('test_cfg', {})

if 'store_dir' in train_cfg:
self.store_dir = train_cfg['store_dir']
elif 'store_dir' in test_cfg:
self.store_dir = test_cfg['store_dir']

assert self.store_dir is not None
os.makedirs(self.store_dir + "/visualize/", exist_ok=True)
if kwargs.get('train_cfg')['cls_weight'] is not None:
self.cls_weight = kwargs.get('train_cfg')['cls_weight']
if kwargs.get('train_cfg')['thresh1'] is not None:
self.thresh1 = kwargs.get('train_cfg')['thresh1']
if kwargs.get('train_cfg')['alpha'] is not None:
self.alpha = kwargs.get('train_cfg')['alpha']
if kwargs.get('train_cfg')['vis_train_duration'] is not None:
self.train_duration = kwargs.get('train_cfg')['vis_train_duration']
if kwargs.get('test_cfg')['visualize'] is not None:
self.visualize = kwargs.get('train_cfg')['visualize']

if 'cls_weight' in train_cfg:
self.cls_weight = train_cfg['cls_weight']
if 'thresh1' in train_cfg:
self.thresh1 = train_cfg['thresh1']
if 'alpha' in train_cfg:
self.alpha = train_cfg['alpha']
if 'vis_train_duration' in train_cfg:
self.train_duration = train_cfg['vis_train_duration']
if 'visualize' in train_cfg:
self.visualize = train_cfg['visualize']

def get_mask_image(self, max_probs, max_indices, thr, num_width):
PALETTE = [
Expand Down
28 changes: 16 additions & 12 deletions mmrotate/models/dense_heads/pseudo_label_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,24 @@ def __init__(self,
self.cls_weight = 20
self.thresh3 = 0.1
self.multiple_factor = 1/16
if kwargs.get('train_cfg')['store_dir'] is not None:
self.store_dir = kwargs.get('train_cfg')['store_dir']
elif kwargs.get('test_cfg')['store_dir'] is not None:
self.store_dir = kwargs.get('test_cfg')['store_dir']
if kwargs.get('train_cfg')['thresh3'] is not None:
self.thresh3 = kwargs.get('train_cfg')['thresh3']

train_cfg = kwargs.get('train_cfg', {})
test_cfg = kwargs.get('test_cfg', {})

if 'store_dir' in train_cfg:
self.store_dir = train_cfg['store_dir']
elif 'store_dir' in test_cfg:
self.store_dir = test_cfg['store_dir']
if 'thresh3' in train_cfg:
self.thresh3 = train_cfg['thresh3']
assert self.store_dir is not None
os.makedirs(self.store_dir + "/visualize/", exist_ok=True)
if kwargs.get('train_cfg')['cls_weight'] is not None:
self.cls_weight = kwargs.get('train_cfg')['cls_weight']
if kwargs.get('train_cfg')['pca_length'] is not None:
self.pca_length = kwargs.get('train_cfg')['pca_length']
if kwargs.get('train_cfg')['multiple_factor'] is not None:
self.multiple_factor = kwargs.get('train_cfg')['multiple_factor']
if 'cls_weight' in train_cfg:
self.cls_weight = train_cfg['cls_weight']
if 'pca_length' in train_cfg:
self.pca_length = train_cfg['pca_length']
if 'multiple_factor' in train_cfg:
self.multiple_factor = train_cfg['multiple_factor']
assert len(self.thresh3) == self.num_classes
self.store_ann_dir = kwargs.get('train_cfg')['store_ann_dir']
os.makedirs(self.store_ann_dir, exist_ok=True)
Expand Down

0 comments on commit b3ea477

Please sign in to comment.