Skip to content

Commit

Permalink
save to weights instead of hub
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchengrex committed Aug 18, 2024
1 parent 7148431 commit b2dc761
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions cutie/utils/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
('https://github.com/hkchengrex/Cutie/releases/download/v1.0/cutie-base-mega.pth', 'a6071de6136982e396851903ab4c083a'),
]

def download_models_if_needed():
weight_dir = os.path.join(torch.hub.get_dir(), 'weights')
def download_models_if_needed(local: bool = False) -> str:
weight_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'weights')
os.makedirs(weight_dir, exist_ok=True)
for link, md5 in _links:
# download file if not exists with a progressbar
filename = link.split('/')[-1]
if not os.path.exists(os.path.join(weight_dir, filename)) or hashlib.md5(open(os.path.join(weight_dir, filename), 'rb').read()).hexdigest() != md5:
print(f'Downloading {filename}...')
print(f'Downloading {filename} to {weight_dir}...')
r = requests.get(link, stream=True)
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
Expand All @@ -29,7 +29,8 @@ def download_models_if_needed():
t.close()
if total_size != 0 and t.n != total_size:
raise RuntimeError('Error while downloading %s' % filename)
return weight_dir


if __name__ == '__main__':
download_models_if_needed()
download_models_if_needed(local=True)
4 changes: 2 additions & 2 deletions cutie/utils/get_default_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def get_default_model() -> CUTIE:
initialize(version_base='1.3.2', config_path="../config", job_name="eval_config")
cfg = compose(config_name="eval_config")

download_models_if_needed()
weight_dir = download_models_if_needed()
with open_dict(cfg):
cfg['weights'] = os.path.join(torch.hub.get_dir(), 'weights', 'cutie-base-mega.pth')
cfg['weights'] = os.path.join(weight_dir, 'cutie-base-mega.pth')
get_dataset_cfg(cfg)

# Load the network weights
Expand Down

0 comments on commit b2dc761

Please sign in to comment.