Skip to content

Commit

Permalink
fix external model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchengrex committed Aug 18, 2024
1 parent d24ab45 commit 7148431
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
8 changes: 5 additions & 3 deletions scripts/download_models.py → cutie/utils/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
import hashlib
from tqdm import tqdm
import torch


_links = [
Expand All @@ -10,17 +11,18 @@
]

def download_models_if_needed():
os.makedirs('weights', exist_ok=True)
weight_dir = os.path.join(torch.hub.get_dir(), 'weights')
os.makedirs(weight_dir, exist_ok=True)
for link, md5 in _links:
# download file if not exists with a progressbar
filename = link.split('/')[-1]
if not os.path.exists(os.path.join('weights', filename)) or hashlib.md5(open(os.path.join('weights', filename), 'rb').read()).hexdigest() != md5:
if not os.path.exists(os.path.join(weight_dir, filename)) or hashlib.md5(open(os.path.join(weight_dir, filename), 'rb').read()).hexdigest() != md5:
print(f'Downloading {filename}...')
r = requests.get(link, stream=True)
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
t = tqdm(total=total_size, unit='iB', unit_scale=True)
with open(os.path.join('weights', filename), 'wb') as f:
with open(os.path.join(weight_dir, filename), 'wb') as f:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
Expand Down
5 changes: 3 additions & 2 deletions cutie/utils/get_default_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""
A helper function to get a default model for quick testing
"""
import os
from omegaconf import open_dict
from hydra import compose, initialize

import torch
from cutie.model.cutie import CUTIE
from cutie.inference.utils.args_utils import get_dataset_cfg
from scripts.download_models import download_models_if_needed
from cutie.utils.download_models import download_models_if_needed


def get_default_model() -> CUTIE:
Expand All @@ -16,7 +17,7 @@ def get_default_model() -> CUTIE:

download_models_if_needed()
with open_dict(cfg):
cfg['weights'] = './weights/cutie-base-mega.pth'
cfg['weights'] = os.path.join(torch.hub.get_dir(), 'weights', 'cutie-base-mega.pth')
get_dataset_cfg(cfg)

# Load the network weights
Expand Down
2 changes: 1 addition & 1 deletion gui/main_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from gui.click_controller import ClickController
from gui.reader import PropagationReader, get_data_loader
from gui.exporter import convert_frames_to_video, convert_mask_to_binary
from scripts.download_models import download_models_if_needed
from cutie.utils.download_models import download_models_if_needed

log = logging.getLogger()

Expand Down

0 comments on commit 7148431

Please sign in to comment.