Skip to content

Commit

Permalink
[FIX]Loading teacher model only
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanBlanchet committed Jan 22, 2025
1 parent e3faeaf commit 4cdf021
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
"""
Allows to load pre-trained models from the `ibot` repository.
Example:
```python
import torch
import torch.hub
model = torch.hub.load("bytedance/ibot", "vits_16")
"""

import torch
import torch.hub
import torch.nn as nn
Expand All @@ -8,21 +19,23 @@
URL = "https://lf3-nlp-opensource.bytetos.com/obj/nlp-opensource/archive/2022/ibot/"

PTHS = dict(
vit_s16="vits_16/checkpoint.pth",
swint_7="swint_7/checkpoint.pth",
swint_14="swint_14/checkpoint.pth",
vitb_16="vitb_16/checkpoint.pth",
vitb_16_rand_mask="vitb_16_rand_mask/checkpoint.pth",
vitl_16="vitl_16/checkpoint.pth",
vitl_16_rand_mask="vitl_16_rand_mask/checkpoint.pth",
vit_s16="vits_16/checkpoint_teacher.pth",
swint_7="swint_7/checkpoint_teacher.pth",
swint_14="swint_14/checkpoint_teacher.pth",
vitb_16="vitb_16/checkpoint_teacher.pth",
vitb_16_rand_mask="vitb_16_rand_mask/checkpoint_teacher.pth",
vitl_16="vitl_16/checkpoint_teacher.pth",
vitl_16_rand_mask="vitl_16_rand_mask/checkpoint_teacher.pth",
)


def _load_ckpt(pth, model: nn.Module, pretrained=True, **kwargs):
if pretrained:
pth = torch.hub.load_state_dict_from_url(url=URL + pth)
state_dict = pth["teacher"]
model.load_state_dict(state_dict, strict=False)
pth = torch.hub.load_state_dict_from_url(
url=URL + pth, model_dir=pth.split("/")[0]
)
state_dict = pth["state_dict"]
model.load_state_dict(state_dict)
return model


Expand Down Expand Up @@ -62,3 +75,6 @@ def vitl_16_rand_mask(**kwargs):


dependencies = ["torch"]

if __name__ == "__main__":
model = vits_16()

0 comments on commit 4cdf021

Please sign in to comment.