From 714843120c7574b4ae3c771c8e4bb0e600d23f89 Mon Sep 17 00:00:00 2001 From: Rex Cheng Date: Sun, 18 Aug 2024 22:51:35 +0900 Subject: [PATCH] fix external model loading --- {scripts => cutie/utils}/download_models.py | 8 +++++--- cutie/utils/get_default_model.py | 5 +++-- gui/main_controller.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) rename {scripts => cutie/utils}/download_models.py (76%) diff --git a/scripts/download_models.py b/cutie/utils/download_models.py similarity index 76% rename from scripts/download_models.py rename to cutie/utils/download_models.py index a11ab12..fadb578 100644 --- a/scripts/download_models.py +++ b/cutie/utils/download_models.py @@ -2,6 +2,7 @@ import requests import hashlib from tqdm import tqdm +import torch _links = [ @@ -10,17 +11,18 @@ ] def download_models_if_needed(): - os.makedirs('weights', exist_ok=True) + weight_dir = os.path.join(torch.hub.get_dir(), 'weights') + os.makedirs(weight_dir, exist_ok=True) for link, md5 in _links: # download file if not exists with a progressbar filename = link.split('/')[-1] - if not os.path.exists(os.path.join('weights', filename)) or hashlib.md5(open(os.path.join('weights', filename), 'rb').read()).hexdigest() != md5: + if not os.path.exists(os.path.join(weight_dir, filename)) or hashlib.md5(open(os.path.join(weight_dir, filename), 'rb').read()).hexdigest() != md5: print(f'Downloading {filename}...') r = requests.get(link, stream=True) total_size = int(r.headers.get('content-length', 0)) block_size = 1024 t = tqdm(total=total_size, unit='iB', unit_scale=True) - with open(os.path.join('weights', filename), 'wb') as f: + with open(os.path.join(weight_dir, filename), 'wb') as f: for data in r.iter_content(block_size): t.update(len(data)) f.write(data) diff --git a/cutie/utils/get_default_model.py b/cutie/utils/get_default_model.py index bf61e3f..a3cba5d 100644 --- a/cutie/utils/get_default_model.py +++ b/cutie/utils/get_default_model.py @@ -1,13 +1,14 @@ """ A helper function to get a default model for quick testing """ +import os from omegaconf import open_dict from hydra import compose, initialize import torch from cutie.model.cutie import CUTIE from cutie.inference.utils.args_utils import get_dataset_cfg -from scripts.download_models import download_models_if_needed +from cutie.utils.download_models import download_models_if_needed def get_default_model() -> CUTIE: @@ -16,7 +17,7 @@ def get_default_model() -> CUTIE: download_models_if_needed() with open_dict(cfg): - cfg['weights'] = './weights/cutie-base-mega.pth' + cfg['weights'] = os.path.join(torch.hub.get_dir(), 'weights', 'cutie-base-mega.pth') get_dataset_cfg(cfg) # Load the network weights diff --git a/gui/main_controller.py b/gui/main_controller.py index 3be3e66..be06426 100644 --- a/gui/main_controller.py +++ b/gui/main_controller.py @@ -27,7 +27,7 @@ from gui.click_controller import ClickController from gui.reader import PropagationReader, get_data_loader from gui.exporter import convert_frames_to_video, convert_mask_to_binary -from scripts.download_models import download_models_if_needed +from cutie.utils.download_models import download_models_if_needed log = logging.getLogger()