Skip to content

Commit

Permalink
Clean more feature extract issues
Browse files Browse the repository at this point in the history
* EfficientNet/MobileNetV3/HRNetFeatures cls and FX mode support -ve index
* MobileNetV3 allows feature_cfg mode to bypass MobileNetV3Features
  • Loading branch information
rwightman committed Jun 14, 2023
1 parent a09c88e commit 47517db
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 35 deletions.
8 changes: 4 additions & 4 deletions timm/models/_efficientnet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def feature_info(self, location):
if location == 'expansion': # output of conv after act, same as block coutput
return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv.out_channels)
return dict(module='', num_chs=self.conv.out_channels)

def forward(self, x):
shortcut = x
Expand Down Expand Up @@ -116,7 +116,7 @@ def feature_info(self, location):
if location == 'expansion': # after SE, input to PW
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
return dict(module='', num_chs=self.conv_pw.out_channels)

def forward(self, x):
shortcut = x
Expand Down Expand Up @@ -173,7 +173,7 @@ def feature_info(self, location):
if location == 'expansion': # after SE, input to PWL
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return dict(module='', num_chs=self.conv_pwl.out_channels)

def forward(self, x):
shortcut = x
Expand Down Expand Up @@ -266,7 +266,7 @@ def feature_info(self, location):
if location == 'expansion': # after SE, before PWL
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return dict(module='', num_chs=self.conv_pwl.out_channels)

def forward(self, x):
shortcut = x
Expand Down
10 changes: 6 additions & 4 deletions timm/models/_efficientnet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,7 @@ def __call__(self, in_chs, model_block_args):
stages = []
if model_block_args[0][0]['stride'] > 1:
# if the first block starts with a stride, we need to extract first level feat from stem
feature_info = dict(
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
hook_type='forward' if self.feature_location != 'bottleneck' else '')
feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
self.features.append(feature_info)

# outer list of block_args defines the stacks
Expand Down Expand Up @@ -418,11 +416,15 @@ def __call__(self, in_chs, model_block_args):
# stash feature module name and channel info for model feature extraction
if extract_features:
feature_info = dict(
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
stage=stack_idx + 1,
reduction=current_stride,
**block.feature_info(self.feature_location),
)
leaf_name = feature_info.get('module', '')
if leaf_name:
feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
else:
assert last_block
feature_info['module'] = f'blocks.{stack_idx}'
self.features.append(feature_info)

Expand Down
3 changes: 2 additions & 1 deletion timm/models/_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ class FeatureInfo:

def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
prev_reduction = 1
for fi in feature_info:
for i, fi in enumerate(feature_info):
# sanity check the mandatory fields, there may be additional fields depending on the model
assert 'num_chs' in fi and fi['num_chs'] > 0
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
prev_reduction = fi['reduction']
assert 'module' in fi
fi.setdefault('index', i)
self.out_indices = out_indices
self.info = feature_info

Expand Down
6 changes: 2 additions & 4 deletions timm/models/_features_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import nn

from ._features import _get_feature_info
from ._features import _get_feature_info, _get_return_layers

try:
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
Expand Down Expand Up @@ -93,9 +93,7 @@ def __init__(self, model, out_indices, out_map=None):
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
return_nodes = {
info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
return_nodes = _get_return_layers(self.feature_info, out_map)
self.graph_module = create_feature_extractor(model, return_nodes)

def forward(self, x):
Expand Down
29 changes: 16 additions & 13 deletions timm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __init__(
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}

efficientnet_init_weights(self)

Expand Down Expand Up @@ -268,25 +268,28 @@ def forward(self, x) -> List[torch.Tensor]:


def _create_effnet(variant, pretrained=False, **kwargs):
features_only = False
features_cls = False
features_mode = ''
model_cls = EfficientNet
kwargs_filter = None
if kwargs.pop('features_only', False):
if 'feature_cfg' not in kwargs:
if 'feature_cfg' in kwargs:
features_mode = 'cfg'
else:
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
model_cls = EfficientNetFeatures
features_cls = True
else:
features_only = True
features_mode = 'cls'

model = build_model_with_cfg(
model_cls, variant, pretrained,
features_only=features_only,
pretrained_strict=not features_cls,
model_cls,
variant,
pretrained,
features_only=features_mode == 'cfg',
pretrained_strict=features_mode != 'cls',
kwargs_filter=kwargs_filter,
**kwargs)
if features_only:
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
**kwargs,
)
if features_mode == 'cls':
model.pretrained_cfg = model.default_cfg = pretrained_cfg_for_features(model.pretrained_cfg)
return model


Expand Down
2 changes: 1 addition & 1 deletion timm/models/hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ def __init__(
**kwargs,
)
self.feature_info = FeatureInfo(self.feature_info, out_indices)
self._out_idx = {i for i in out_indices}
self._out_idx = {f['index'] for f in self.feature_info.get_dicts()}

def forward_features(self, x):
assert False, 'Not supported'
Expand Down
22 changes: 14 additions & 8 deletions timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __init__(
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}

efficientnet_init_weights(self)

Expand Down Expand Up @@ -247,21 +247,27 @@ def forward(self, x) -> List[torch.Tensor]:


def _create_mnv3(variant, pretrained=False, **kwargs):
features_only = False
features_mode = ''
model_cls = MobileNetV3
kwargs_filter = None
if kwargs.pop('features_only', False):
features_only = True
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
model_cls = MobileNetV3Features
if 'feature_cfg' in kwargs:
features_mode = 'cfg'
else:
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
model_cls = MobileNetV3Features
features_mode = 'cls'

model = build_model_with_cfg(
model_cls,
variant,
pretrained,
pretrained_strict=not features_only,
features_only=features_mode == 'cfg',
pretrained_strict=features_mode != 'cls',
kwargs_filter=kwargs_filter,
**kwargs)
if features_only:
**kwargs,
)
if features_mode == 'cls':
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
return model

Expand Down

0 comments on commit 47517db

Please sign in to comment.