
Commit

Merge pull request #31 from reginabarzilaygroup/GH30_download
Download models from GitHub releases instead of Google Drive
jsilter authored Mar 11, 2024
2 parents 6453d08 + 710c193 commit 864c9d5
Showing 1 changed file with 44 additions and 4 deletions.
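In practical terms, the first model load now pulls sybil_checkpoints.zip from the v1.0.3 GitHub release instead of Google Drive. Below is a minimal sketch of the caller-facing behavior, assuming the Sybil constructor (only partially shown in the diff that follows) still resolves named models through download_sybil; this is an illustration, not code from the repository:

# Hedged sketch of the caller-facing behavior; Sybil.__init__ is truncated in this diff,
# so the assumption is that it still routes model names through download_sybil.
from sybil.model import Sybil

model = Sybil("sybil_base")  # first call downloads the release zip into the local cache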
48 changes: 44 additions & 4 deletions sybil/model.py
@@ -1,7 +1,10 @@
-from typing import NamedTuple, Union, Dict, List, Optional
-import os
 from argparse import Namespace
-import gdown
+from io import BytesIO
+import os
+from typing import NamedTuple, Union, Dict, List, Optional, Tuple
+from urllib.request import urlopen
+from zipfile import ZipFile
+# import gdown
 
 import torch
 import numpy as np
@@ -12,6 +15,7 @@
 from sybil.utils.metrics import get_survival_metrics
 
 
+# Leaving this here for a bit; these are IDs to download the models from Google Drive
 NAME_TO_FILE = {
     "sybil_base": {
         "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
@@ -62,6 +66,8 @@
     },
 }
 
+CHECKPOINT_URL = "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.0.3/sybil_checkpoints.zip"
+
 
 class Prediction(NamedTuple):
     scores: List[List[float]]
@@ -75,7 +81,7 @@ class Evaluation(NamedTuple):
     attentions: List[Dict[str, np.ndarray]] = None
 
 
-def download_sybil(name, cache):
+def download_sybil_gdrive(name, cache):
     """Download trained models and calibrator from Google Drive
 
     Parameters
@@ -118,6 +124,40 @@ def download_sybil(name, cache):
     return download_model_paths, download_calib_path
 
 
+def download_sybil(name, cache) -> Tuple[List[str], str]:
+    """Download trained models and calibrator"""
+    # Create cache folder if not exists
+    cache = os.path.expanduser(cache)
+    os.makedirs(cache, exist_ok=True)
+
+    # Download models
+    model_files = NAME_TO_FILE[name]
+    checkpoints = model_files["checkpoint"]
+    download_calib_path = os.path.join(cache, f"{name}.p")
+    have_all_files = os.path.exists(download_calib_path)
+
+    download_model_paths = []
+    for checkpoint in checkpoints:
+        cur_checkpoint_path = os.path.join(cache, f"{checkpoint}.ckpt")
+        have_all_files &= os.path.exists(cur_checkpoint_path)
+        download_model_paths.append(cur_checkpoint_path)
+
+    if not have_all_files:
+        print(f"Downloading models to {cache}")
+        download_and_extract(CHECKPOINT_URL, cache)
+
+    return download_model_paths, download_calib_path
+
+
+def download_and_extract(remote_model_url: str, local_model_dir) -> List[str]:
+    resp = urlopen(remote_model_url)
+    os.makedirs(local_model_dir, exist_ok=True)
+    with ZipFile(BytesIO(resp.read())) as zip_file:
+        all_files_and_dirs = zip_file.namelist()
+        zip_file.extractall(local_model_dir)
+    return all_files_and_dirs
+
+
 class Sybil:
     def __init__(
         self,

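A usage sketch of the new download path added above (assuming sybil/model.py is importable as sybil.model, as the file path suggests; the model name "sybil_base" comes from NAME_TO_FILE, while the cache directory is an illustrative choice, not a repository default):

# Hypothetical direct use of the helpers added in this commit; a sketch, not code from the repository.
from sybil.model import download_sybil

# "sybil_base" is one of the NAME_TO_FILE keys shown above; the cache directory is illustrative.
checkpoint_paths, calibrator_path = download_sybil("sybil_base", cache="~/.sybil")

# On a first run the zip at CHECKPOINT_URL is fetched and extracted into the cache;
# on later runs the files are found on disk and the download is skipped.
print(checkpoint_paths)  # e.g. ['<cache>/28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt']
print(calibrator_path)   # e.g. '<cache>/sybil_base.p'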