Gheinrich/hf feat norm #91

Draft · wants to merge 10 commits into base: main
2 changes: 1 addition & 1 deletion examples/requirements.txt
@@ -2,7 +2,7 @@ transformers
datasets
timm
open_clip_torch
- albumentations
+ albumentations==1.3.1
opencv-python==4.8.0.74
opencv-python-headless==4.8.0.74
git+https://github.com/facebookresearch/segment-anything.git
3 changes: 2 additions & 1 deletion examples/visualize_features.py
@@ -195,9 +195,10 @@ def main(rank: int = 0, world_size: int = 1):
output_fmt='NLC',
intermediates_only=True,
aggregation=args.intermediate_aggregation,
norm_alpha_scheme="none",
)
assert args.adaptor_name is None
- all_feat = [o[1] for o in outputs]
+ all_feat = outputs
else:
output = model(p_images)
if args.adaptor_name:
63 changes: 60 additions & 3 deletions hf_hub.py
@@ -21,7 +21,7 @@
from radio.adaptor_base import RadioOutput
from radio.adaptor_registry import adaptor_registry
from radio.adaptor_mlp import get_mlp_info_from_state
- from radio.hf_model import RADIOConfig, RADIOModel
+ from radio.hf_model import RADIOConfig, RADIOModel, rename_all_gamma_to_weight_with_proxy
from test_hf import deterministic_grid_init


@@ -161,11 +161,30 @@ def main():

adaptor_configs[adaptor_name] = adaptor_config


feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
feature_normalizer_config = None
if feat_norm_sd:
feature_normalizer_config = {
"embed_dim": feat_norm_sd['mean'].shape[0]
}

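# The checkpoint shapes determine the normalizer config: 'means' is
# (num_intermediates, embed_dim), and 'rotation' is (embed_dim, embed_dim),
# or (num_intermediates, embed_dim, embed_dim) when a per-layer rotation was trained.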
inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
inter_feature_normalizer_config = None
if inter_feat_norm_sd:
inter_feature_normalizer_config = {
"num_intermediates": inter_feat_norm_sd['means'].shape[0],
"embed_dim": inter_feat_norm_sd['means'].shape[1],
"rot_per_layer": inter_feat_norm_sd['rotation'].ndim == 3,
}

radio_config = RADIOConfig(
vars(model_args),
version=args.version,
adaptor_names=adaptor_names,
adaptor_configs=adaptor_configs,
feature_normalizer_config=feature_normalizer_config,
inter_feature_normalizer_config=inter_feature_normalizer_config,
)
radio_model = RADIOModel(radio_config)

@@ -194,6 +213,16 @@ def main():
get_prefix_state_dict(state_dict, "input_conditioner.")
)

# Restore feature normalizer.
if feat_norm_sd:
radio_model.radio_model.feature_normalizer.load_state_dict(feat_norm_sd)
if inter_feat_norm_sd:
radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

# Rename "gamma" parameters to "weight": transformers rewrites checkpoint keys
# containing "gamma" at load time (a legacy TensorFlow-conversion shim), so
# keeping the original name would break weight restoration.
rename_all_gamma_to_weight_with_proxy(radio_model.radio_model)
radio_config.rename_gamma_to_weight = True

radio_model.eval().cuda()

# Sample inference with deterministic values.
@@ -215,11 +244,30 @@ def main():
hf_summary, hf_features = v.summary, v.features

print(
f"[{k}] Sample inference on tensor shape {x.shape} returned summary ",
f"[{k}] HF inference on tensor shape {x.shape} returned summary ",
f"with shape={hf_summary.shape} and std={hf_summary.std().item():.3}, ",
f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}",
)

intermediates = radio_model.radio_model.forward_intermediates(
x,
indices=[-1],
return_prefix_tokens=True,
norm=False,
stop_early=False,
output_fmt='NLC',
intermediates_only=True,
aggregation="sparse",
)
print(
f"Intermediates inference returned ",
f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}",
)
print("diff norm", (intermediates[0].features- hf_output["backbone"].features).norm())
print("std", intermediates[0].features.std().item(), hf_output["backbone"].features.std().item())
print("mean", intermediates[0].features.mean().item(), hf_output["backbone"].features.mean().item())
#assert torch.allclose(intermediates[0].features, hf_output["backbone"].features, atol=1e-4)

# Infer using TorchHub model.
print("Infer using TorchHub model...")
torchhub_model = torch.hub.load(
@@ -244,6 +292,12 @@ def main():
torchhub_output[k].features,
)

print(
f"[{k}] TorchHub inference on tensor shape {x.shape} returned summary ",
f"with shape={torchhub_summary.shape} and std={torchhub_summary.std().item():.3}, ",
f"features with shape={torchhub_features.shape} and std={torchhub_features.std().item():.3}",
)

# Make sure the shapes are the same.
assert (
hf_summary.shape == torchhub_summary.shape
@@ -262,6 +316,10 @@ def main():

print(f"{k} outputs matched!")

print("All outputs matched!")

if args.push:
# Push to HuggingFace Hub.
huggingface_repo = args.hf_repo
@@ -273,7 +331,6 @@
)
print(f"Pushed to {commit}")


if __name__ == "__main__":
"""Call the main entrypoiny."""
main()
4 changes: 4 additions & 0 deletions mmseg/radio.py
@@ -70,6 +70,10 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
# Standard ViT case.
patch_height, patch_width = self.base_model.model.patch_embed.patch_size
features = features.reshape(B, math.ceil(H/patch_height), math.ceil(W/patch_width), C).permute(0, 3, 1, 2).contiguous()
else:
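# Fallback for backbones that do not expose patch_embed: assume square 16-pixel patches.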
B, _, C = features.shape
patch_height = patch_width = 16
features = features.reshape(B, math.ceil(H/patch_height), math.ceil(W/patch_width), C).permute(0, 3, 1, 2).contiguous()

# IMPORTANT: prevent gradients from flowing back towards the backbone.
features = features.detach()
7 changes: 7 additions & 0 deletions radio/common.py
@@ -80,6 +80,13 @@ class RadioResource:
max_resolution=2048,
preferred_resolution=Resolution(512, 512),
),
# RADIO-DINOv2
"radio_dinov2-g": RadioResource(
None, # TODO: add URL for DINOv2 student.
patch_size=14,
max_resolution=2044,
preferred_resolution=Resolution(518, 518),
),
}

DEFAULT_VERSION = "radio_v2.5-h"
2 changes: 1 addition & 1 deletion radio/enable_cpe_support.py
@@ -14,7 +14,7 @@

from timm.models import VisionTransformer, checkpoint_seq

- from radio.feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
+ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer

from .extra_models import DinoWrapper
from .vit_patch_generator import ViTPatchGenerator
6 changes: 6 additions & 0 deletions radio/feature_normalizer.py
@@ -28,6 +28,8 @@ class FeatureNormalizer(nn.Module):
def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
super().__init__()

# Registered as buffers rather than frozen nn.Parameters: the statistics
# serialize with the state dict but are never updated by the optimizer.
self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype))
self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype))
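# Zero mean and identity 'tx' make the normalizer a no-op until trained
# statistics are loaded from a checkpoint.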

@@ -49,15 +51,19 @@ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):
super().__init__()
# As above, keep the per-layer scale factors as a non-trainable buffer.
self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))

rot = torch.eye(embed_dim, dtype=dtype)
if rot_per_layer:
rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)

# The rotation matrix and per-layer means are likewise non-trainable buffers.
self.register_buffer('rotation', rot.contiguous())
self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))


def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
if rot_index is None:
rot_index = index
53 changes: 53 additions & 0 deletions radio/hf_model.py
@@ -31,6 +31,7 @@
from .enable_cpe_support import enable_cpe
from .enable_spectral_reparam import configure_spectral_reparam_from_args
from .eradio_model import eradio
from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio_model import create_model_from_args
from .radio_model import RADIOModel as RADIOModelBase, Resolution
from .input_conditioner import get_default_conditioner, InputConditioner
@@ -42,6 +43,33 @@
from .extra_timm_models import *



def rename_all_gamma_to_weight_with_proxy(module):
"""
Renames all parameters named 'gamma' in a module (including submodules)
to 'weight' and sets up a property so that accesses to 'gamma' still work.
"""
# Recursively iterate through submodules
for submodule_name, submodule in module.named_modules():
# Get all parameters within the current submodule
for param_name, param in list(submodule.named_parameters(recurse=False)):
if 'gamma' in param_name:
# Generate the new name by replacing 'gamma' with 'weight'
new_name = param_name.replace('gamma', 'weight')

# Remove the old parameter and assign it with the new name
delattr(submodule, param_name)
setattr(submodule, new_name, nn.Parameter(param.data))

# Define a property to proxy access to the renamed parameter
def make_property(old_name, new_name):
return property(lambda self: getattr(self, new_name),
lambda self, value: setattr(self, new_name, value))

# Add the property to the submodule to proxy access to 'gamma'
setattr(submodule.__class__, param_name, make_property(param_name, new_name))
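# Illustrative usage (a sketch, not part of this change): for a timm ViT built
# with LayerScale (which registers a 'gamma' parameter), the rename keeps
# old-name access working through a class-level property:
#
#   model = timm.create_model('vit_base_patch16_224', init_values=1e-5)
#   rename_all_gamma_to_weight_with_proxy(model)
#   ls = model.blocks[0].ls1
#   assert 'blocks.0.ls1.weight' in model.state_dict()
#   assert ls.gamma is ls.weight  # the proxy resolves to the renamed parameter
#
# Note that the property is installed on the submodule's class, so it affects
# every instance of that module type, not just this model.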


class RADIOConfig(PretrainedConfig):
"""Pretrained Hugging Face configuration for RADIO models."""

@@ -55,6 +83,9 @@ def __init__(
adaptor_names: Union[str, List[str]] = None,
adaptor_configs: Dict[str, Dict[str, int]] = None,
vitdet_window_size: Optional[int] = None,
feature_normalizer_config: Optional[dict] = None,
inter_feature_normalizer_config: Optional[dict] = None,
rename_gamma_to_weight: bool = False,
**kwargs,
):
self.args = args
@@ -74,9 +105,13 @@ def __init__(
self.adaptor_names = adaptor_names
self.adaptor_configs = adaptor_configs
self.vitdet_window_size = vitdet_window_size
self.feature_normalizer_config = feature_normalizer_config
self.inter_feature_normalizer_config = inter_feature_normalizer_config
self.rename_gamma_to_weight = rename_gamma_to_weight
super().__init__(**kwargs)



class RADIOModel(PreTrainedModel):
"""Pretrained Hugging Face model for RADIO.

@@ -118,6 +153,19 @@ def __init__(self, config: RADIOConfig):
adaptor.head_idx = mlp_config["head_idx"]
adaptors[adaptor_name] = adaptor

feature_normalizer = None
if config.feature_normalizer_config is not None:
# Actual normalization values will be restored when loading checkpoint weights.
feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"])

inter_feature_normalizer = None
if config.inter_feature_normalizer_config is not None:
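# As with the feature normalizer above, construct with identity defaults;
# the trained statistics are restored from the checkpoint.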
inter_feature_normalizer = IntermediateFeatureNormalizer(
config.inter_feature_normalizer_config["num_intermediates"],
config.inter_feature_normalizer_config["embed_dim"],
rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"],
dtype=dtype)

self.radio_model = RADIOModelBase(
model,
input_conditioner,
@@ -127,8 +175,13 @@ def __init__(self, config: RADIOConfig):
window_size=config.vitdet_window_size,
preferred_resolution=config.preferred_resolution,
adaptors=adaptors,
feature_normalizer=feature_normalizer,
inter_feature_normalizer=inter_feature_normalizer,
)

if config.rename_gamma_to_weight:
rename_all_gamma_to_weight_with_proxy(self.radio_model)

@property
def adaptors(self) -> nn.ModuleDict:
return self.radio_model.adaptors
27 changes: 24 additions & 3 deletions test_hf.py
@@ -44,6 +44,7 @@ def main():
python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio_v2.1_bf16.pth.tar --torchhub-repo NVlabs/RADIO:dev/hf
python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio-v2.5-l_half.pth.tar --torchhub-repo NVlabs/RADIO:dev/hf
python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio-v2.5-l_half.pth.tar --adaptor-names siglip,sam
python3 -m test_hf --hf-repo gheinrich/RADIO-NORM --torchhub-version /lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/hero/n32_8-19-24_vit-h-16_hero-v4_s3/checkpoints/last_norm_release_half.pth.tar --torchhub-repo NVlabs/RADIO:mranzinger/ship_paper
"""
parser = argparse.ArgumentParser()
parser.add_argument("--hf-repo", help="Path to the HuggingFace repo", required=True)
@@ -53,6 +54,9 @@
parser.add_argument(
"--torchhub-repo", help="Path to the Torchhub repo", default="NVlabs/RADIO"
)
parser.add_argument(
"--hf-revision", help="HuggingFace revision to checkout", default="main"
)
parser.add_argument(
"--adaptor-names",
default=None,
@@ -63,13 +67,13 @@

args = parser.parse_args()

- hf_config = AutoConfig.from_pretrained(args.hf_repo, trust_remote_code=True)
+ hf_config = AutoConfig.from_pretrained(args.hf_repo, revision=args.hf_revision, trust_remote_code=True)
if args.adaptor_names is not None:
# Configure adaptors if specified on the command line.
# This needs to happen before we instantiate the model.
hf_config.adaptor_names = args.adaptor_names
hf_model = AutoModel.from_pretrained(
- args.hf_repo, trust_remote_code=True, config=hf_config
+ args.hf_repo, revision=args.hf_revision, trust_remote_code=True, config=hf_config
)
hf_model.eval().cuda()

@@ -126,10 +130,27 @@ def main():
assert torch.allclose(hf_summary, torchhub_summary, atol=1e-6)
assert torch.allclose(hf_features, torchhub_features, atol=1e-6)

intermediates = hf_model.radio_model.forward_intermediates(
hf_model.input_conditioner(x),
indices=[-1],
return_prefix_tokens=True,
norm=False,
stop_early=False,
output_fmt='NLC',
intermediates_only=True,
aggregation="sparse",
)
print(
f"Intermediates inference returned summary ",
f"with shape={intermediates[0].summary.shape} and std={intermediates[0].summary.std().item():.3}, ",
f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}",
)
#assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6)

print("All outputs matched!")

# Infer a sample image.
- image_processor = CLIPImageProcessor.from_pretrained(args.hf_repo)
+ image_processor = CLIPImageProcessor.from_pretrained(args.hf_repo, revision=args.hf_revision)

image = Image.open("./examples/image1.png").convert("RGB")
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values