Skip to content

Commit

Permalink
Changes to support DINOv2 in HF
Browse files Browse the repository at this point in the history
  • Loading branch information
gheinrich committed Oct 10, 2024
1 parent 662e877 commit 58cfe72
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
16 changes: 13 additions & 3 deletions hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from radio.adaptor_base import RadioOutput
from radio.adaptor_registry import adaptor_registry
from radio.adaptor_mlp import get_mlp_info_from_state
from radio.hf_model import RADIOConfig, RADIOModel
from radio.hf_model import RADIOConfig, RADIOModel, rename_all_gamma_to_weight_with_proxy
from test_hf import deterministic_grid_init


Expand Down Expand Up @@ -164,7 +164,7 @@ def main():

feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
feature_normalizer_config = None
if feat_norm_sd is not None:
if feat_norm_sd:
feature_normalizer_config = {
"embed_dim": feat_norm_sd['mean'].shape[0]
}
Expand Down Expand Up @@ -219,6 +219,10 @@ def main():
if inter_feat_norm_sd:
radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

# Rename "gamma" parameters to "weight"
rename_all_gamma_to_weight_with_proxy(radio_model.radio_model)
radio_config.rename_gamma_to_weight = True

radio_model.eval().cuda()

# Sample inference with deterministic values.
Expand All @@ -240,7 +244,7 @@ def main():
hf_summary, hf_features = v.summary, v.features

print(
f"[{k}] Sample inference on tensor shape {x.shape} returned summary ",
f"[{k}] HF inference on tensor shape {x.shape} returned summary ",
f"with shape={hf_summary.shape} and std={hf_summary.std().item():.3}, ",
f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}",
)
Expand Down Expand Up @@ -288,6 +292,12 @@ def main():
torchhub_output[k].features,
)

print(
f"[{k}] TorchHub inference on tensor shape {x.shape} returned summary ",
f"with shape={torchhub_summary.shape} and std={torchhub_summary.std().item():.3}, ",
f"features with shape={torchhub_features.shape} and std={torchhub_features.std().item():.3}",
)

# Make sure the shapes are the same.
assert (
hf_summary.shape == torchhub_summary.shape
Expand Down
7 changes: 7 additions & 0 deletions radio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ class RadioResource:
max_resolution=2048,
preferred_resolution=Resolution(512, 512),
),
# RADIO-DINOv2
"radio_dinov2-g": RadioResource(
None, # TODO: add URL for DINOv2 student.
patch_size=14,
max_resolution=2044,
preferred_resolution=Resolution(518, 518),
),
}

DEFAULT_VERSION = "radio_v2.5-h"
33 changes: 33 additions & 0 deletions radio/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,33 @@
from .extra_timm_models import *



def rename_all_gamma_to_weight_with_proxy(module):
    """
    Rename every parameter whose name contains 'gamma' to the same name with
    'gamma' replaced by 'weight', across *module* and all of its submodules.

    For each renamed parameter, a ``property`` named after the old parameter
    is installed on the owning submodule's class so that reads and writes of
    the old 'gamma' attribute transparently proxy to the new 'weight' one.

    NOTE(review): the proxy property is installed on the submodule's *class*,
    so it is shared by every instance of that class. An instance of the same
    class whose 'gamma' parameter has NOT been renamed will raise
    AttributeError when 'gamma' is accessed through the property — assumed
    acceptable because this helper is applied to a whole model at once.

    Args:
        module: the ``nn.Module`` whose parameters are renamed in place.
    """
    for _submodule_name, submodule in module.named_modules():
        # Snapshot the direct (non-recursive) parameters first: we mutate the
        # submodule's parameter registry while iterating.
        for param_name, param in list(submodule.named_parameters(recurse=False)):
            if 'gamma' not in param_name:
                continue
            new_name = param_name.replace('gamma', 'weight')

            # Re-register the *same* Parameter object under the new name.
            # (The previous code wrapped `param.data` in a fresh
            # `nn.Parameter`, which defaults to requires_grad=True — that
            # silently re-enabled gradients on frozen parameters and broke
            # any external references to the original Parameter.)
            delattr(submodule, param_name)
            setattr(submodule, new_name, param)

            # Factory so each property closes over its own new_name rather
            # than the loop variable (avoids the late-binding closure bug).
            def make_property(new_name):
                return property(
                    lambda self: getattr(self, new_name),
                    lambda self, value: setattr(self, new_name, value),
                )

            # Proxy so `obj.gamma` keeps working after the rename.
            setattr(submodule.__class__, param_name, make_property(new_name))


class RADIOConfig(PretrainedConfig):
"""Pretrained Hugging Face configuration for RADIO models."""

Expand All @@ -58,6 +85,7 @@ def __init__(
vitdet_window_size: Optional[int] = None,
feature_normalizer_config: Optional[dict] = None,
inter_feature_normalizer_config: Optional[dict] = None,
rename_gamma_to_weight: bool = False,
**kwargs,
):
self.args = args
Expand All @@ -79,9 +107,11 @@ def __init__(
self.vitdet_window_size = vitdet_window_size
self.feature_normalizer_config = feature_normalizer_config
self.inter_feature_normalizer_config = inter_feature_normalizer_config
self.rename_gamma_to_weight = rename_gamma_to_weight
super().__init__(**kwargs)



class RADIOModel(PreTrainedModel):
"""Pretrained Hugging Face model for RADIO.
Expand Down Expand Up @@ -149,6 +179,9 @@ def __init__(self, config: RADIOConfig):
inter_feature_normalizer=inter_feature_normalizer,
)

if config.rename_gamma_to_weight:
rename_all_gamma_to_weight_with_proxy(self.radio_model)

@property
def adaptors(self) -> nn.ModuleDict:
return self.radio_model.adaptors
Expand Down

0 comments on commit 58cfe72

Please sign in to comment.