From a09c88ed0f9f085b2d0bebba9a3bb8814a728ced Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 14 Jun 2023 12:57:39 -0700 Subject: [PATCH] Support other features only modes for EfficientNet --- timm/models/_efficientnet_builder.py | 6 ++++-- timm/models/efficientnet.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index e6cd05ae2e..14eb70f174 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -419,9 +419,11 @@ def __call__(self, in_chs, model_block_args): if extract_features: feature_info = dict( stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location)) - module_name = f'blocks.{stack_idx}.{block_idx}' leaf_name = feature_info.get('module', '') - feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name + if leaf_name: + feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name]) + else: + feature_info['module'] = f'blocks.{stack_idx}' self.features.append(feature_info) total_block_idx += 1 # incr global block idx (across all stacks) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 1854bc3058..af77d7ec0c 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -269,15 +269,20 @@ def forward(self, x) -> List[torch.Tensor]: def _create_effnet(variant, pretrained=False, **kwargs): features_only = False + features_cls = False model_cls = EfficientNet kwargs_filter = None if kwargs.pop('features_only', False): - features_only = True - kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') - model_cls = EfficientNetFeatures + if 'feature_cfg' not in kwargs: + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') + model_cls = EfficientNetFeatures + features_cls = True + else: + features_only = True model = build_model_with_cfg( model_cls, variant, pretrained, - pretrained_strict=not features_only, + features_only=features_only, + pretrained_strict=not features_cls, kwargs_filter=kwargs_filter, **kwargs) if features_only: