def check_config(self, config: Dict[str, Any]) -> bool:
    """Check if the pretrained configuration matches the current model.

    Args:
        config (Dict[str, Any]): Configuration shipped with the pretrained
            weights. It must provide the ``alpha``, ``gamma``, ``groups``
            and ``num_estimators`` entries.

    Returns:
        bool: ``True`` when all four entries are equal to the attributes of
            the same name on the model, ``False`` otherwise.

    Raises:
        KeyError: If one of the four entries is missing from ``config``.
    """
    keys = ("alpha", "gamma", "groups", "num_estimators")
    # Look every key up before comparing, so a missing entry is always
    # reported even when an earlier entry already mismatches.
    pretrained = {key: config[key] for key in keys}
    # ``all`` yields a real ``bool``; multiplying the comparisons together
    # would yield an ``int`` and contradict the annotated return type.
    return all(
        value == getattr(self, key) for key, value in pretrained.items()
    )
+ ) + net.load_state_dict(state_dict) return net @@ -428,10 +438,15 @@ def packed_resnet34( style=style, ) if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)][34] + weights = weight_ids[str(num_classes)]["34"] if weights is None: raise ValueError("No pretrained weights for this configuration") - net.load_state_dict(load_hf(weights)) + state_dict, config = load_hf(weights) + if not net.check_config(config): + raise ValueError( + "Pretrained weights do not match current configuration." + ) + net.load_state_dict(state_dict) return net @@ -470,10 +485,15 @@ def packed_resnet50( style=style, ) if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)][50] + weights = weight_ids[str(num_classes)]["50"] if weights is None: raise ValueError("No pretrained weights for this configuration") - net.load_state_dict(load_hf(weights)) + state_dict, config = load_hf(weights) + if not net.check_config(config): + raise ValueError( + "Pretrained weights do not match current configuration." + ) + net.load_state_dict(state_dict) return net @@ -512,10 +532,15 @@ def packed_resnet101( style=style, ) if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)][101] + weights = weight_ids[str(num_classes)]["101"] if weights is None: raise ValueError("No pretrained weights for this configuration") - net.load_state_dict(load_hf(weights)) + state_dict, config = load_hf(weights) + if not net.check_config(config): + raise ValueError( + "Pretrained weights do not match current configuration." 
def load_hf(weight_id: str) -> Tuple[Dict[str, torch.Tensor], Dict]:
    """Load a model from the huggingface hub.

    Args:
        weight_id (str): The id of the model to load. It is used both as
            the repository name under the ``torch-uncertainty`` namespace
            and as the stem of the ``.ckpt`` file in that repository.

    Returns:
        Tuple[Dict[str, torch.Tensor], Dict]: The model weights (a state
            dict) and the content of the repository's ``config.yaml``.
    """
    repo_id = f"torch-uncertainty/{weight_id}"

    # Load the weights.
    weight_path = hf_hub_download(repo_id=repo_id, filename=f"{weight_id}.ckpt")
    # Deserialize onto the CPU so that a checkpoint saved from a GPU can
    # still be loaded on a host without CUDA.
    # NOTE(review): ``torch.load`` unpickles the file; only checkpoints from
    # a trusted repository should be loaded this way.
    weight = torch.load(weight_path, map_location="cpu")
    # Some checkpoints wrap the weights under a "state_dict" entry.
    if "state_dict" in weight:
        weight = weight["state_dict"]

    # Load the config.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    config_text = Path(config_path).read_text(encoding="utf-8")
    # ``safe_load`` returns ``None`` for an empty document; keep the
    # annotated ``Dict`` return type in that case.
    config = yaml.safe_load(config_text) or {}

    return weight, config