diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index 92b849e461..a5a6f30b79 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -75,7 +75,7 @@ def feature_info(self, location): if location == 'expansion': # output of conv after act, same as block coutput return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels) else: # location == 'bottleneck', block output - return dict(module='', hook_type='', num_chs=self.conv.out_channels) + return dict(module='', num_chs=self.conv.out_channels) def forward(self, x): shortcut = x @@ -116,7 +116,7 @@ def feature_info(self, location): if location == 'expansion': # after SE, input to PW return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) else: # location == 'bottleneck', block output - return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) + return dict(module='', num_chs=self.conv_pw.out_channels) def forward(self, x): shortcut = x @@ -173,7 +173,7 @@ def feature_info(self, location): if location == 'expansion': # after SE, input to PWL return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) else: # location == 'bottleneck', block output - return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return dict(module='', num_chs=self.conv_pwl.out_channels) def forward(self, x): shortcut = x @@ -266,7 +266,7 @@ def feature_info(self, location): if location == 'expansion': # after SE, before PWL return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) else: # location == 'bottleneck', block output - return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return dict(module='', num_chs=self.conv_pwl.out_channels) def forward(self, x): shortcut = x diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index e6cd05ae2e..b5dbeaae3a 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -370,9 +370,7 @@ def __call__(self, in_chs, model_block_args): stages = [] if model_block_args[0][0]['stride'] > 1: # if the first block starts with a stride, we need to extract first level feat from stem - feature_info = dict( - module='act1', num_chs=in_chs, stage=0, reduction=current_stride, - hook_type='forward' if self.feature_location != 'bottleneck' else '') + feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride) self.features.append(feature_info) # outer list of block_args defines the stacks @@ -418,10 +416,16 @@ def __call__(self, in_chs, model_block_args): # stash feature module name and channel info for model feature extraction if extract_features: feature_info = dict( - stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location)) - module_name = f'blocks.{stack_idx}.{block_idx}' + stage=stack_idx + 1, + reduction=current_stride, + **block.feature_info(self.feature_location), + ) leaf_name = feature_info.get('module', '') - feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name + if leaf_name: + feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name]) + else: + assert last_block + feature_info['module'] = f'blocks.{stack_idx}' self.features.append(feature_info) total_block_idx += 1 # incr global block idx (across all stacks) diff --git a/timm/models/_features.py b/timm/models/_features.py index 7924453638..7ef51809bc 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -27,12 +27,13 @@ class FeatureInfo: def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): prev_reduction = 1 - for fi in feature_info: + for i, fi in enumerate(feature_info): # sanity check the mandatory fields, there may be additional fields depending on the model assert 'num_chs' in fi and fi['num_chs'] > 0 assert 'reduction' in fi and fi['reduction'] >= prev_reduction prev_reduction = fi['reduction'] assert 'module' in fi + fi.setdefault('index', i) self.out_indices = out_indices self.info = feature_info diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index c894d66c85..b84312a739 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -6,7 +6,7 @@ import torch from torch import nn -from ._features import _get_feature_info +from ._features import _get_feature_info, _get_return_layers try: from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor @@ -93,9 +93,7 @@ def __init__(self, model, out_indices, out_map=None): self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) - return_nodes = { - info['module']: out_map[i] if out_map is not None else info['module'] - for i, info in enumerate(self.feature_info) if i in out_indices} + return_nodes = _get_return_layers(self.feature_info, out_map) self.graph_module = create_feature_extractor(model, return_nodes) def forward(self, x): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 1854bc3058..905f8b7596 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -232,7 +232,7 @@ def __init__( ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) - self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()} efficientnet_init_weights(self) @@ -268,20 +268,28 @@ def forward(self, x) -> List[torch.Tensor]: def _create_effnet(variant, pretrained=False, **kwargs): - features_only = False + features_mode = '' model_cls = EfficientNet kwargs_filter = None if kwargs.pop('features_only', False): - features_only = True - kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') - model_cls = EfficientNetFeatures + if 'feature_cfg' in kwargs: + features_mode = 'cfg' + else: + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') + model_cls = EfficientNetFeatures + features_mode = 'cls' + model = build_model_with_cfg( - model_cls, variant, pretrained, - pretrained_strict=not features_only, + model_cls, + variant, + pretrained, + features_only=features_mode == 'cfg', + pretrained_strict=features_mode != 'cls', kwargs_filter=kwargs_filter, - **kwargs) - if features_only: - model.default_cfg = pretrained_cfg_for_features(model.default_cfg) + **kwargs, + ) + if features_mode == 'cls': + model.pretrained_cfg = model.default_cfg = pretrained_cfg_for_features(model.pretrained_cfg) return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 4a6fd80eaa..e56a550936 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -829,7 +829,7 @@ def __init__( **kwargs, ) self.feature_info = FeatureInfo(self.feature_info, out_indices) - self._out_idx = {i for i in out_indices} + self._out_idx = {f['index'] for f in self.feature_info.get_dicts()} def forward_features(self, x): assert False, 'Not supported' diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 70087b62fa..8de94f7e26 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -210,7 +210,7 @@ def __init__( ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) - self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()} efficientnet_init_weights(self) @@ -247,21 +247,27 @@ def forward(self, x) -> List[torch.Tensor]: def _create_mnv3(variant, pretrained=False, **kwargs): - features_only = False + features_mode = '' model_cls = MobileNetV3 kwargs_filter = None if kwargs.pop('features_only', False): - features_only = True - kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') - model_cls = MobileNetV3Features + if 'feature_cfg' in kwargs: + features_mode = 'cfg' + else: + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + features_mode = 'cls' + model = build_model_with_cfg( model_cls, variant, pretrained, - pretrained_strict=not features_only, + features_only=features_mode == 'cfg', + pretrained_strict=features_mode != 'cls', kwargs_filter=kwargs_filter, - **kwargs) - if features_only: + **kwargs, + ) + if features_mode == 'cls': model.default_cfg = pretrained_cfg_for_features(model.default_cfg) return model