
Merge pull request #38 from reginabarzilaygroup/v1.2.2_dev
V1.2.2 dev
pgmikhael authored May 23, 2024
2 parents 6618b4f + 1e06b0e commit 1fd3d38
Showing 7 changed files with 213 additions and 27 deletions.
22 changes: 7 additions & 15 deletions scripts/inference.py
@@ -5,6 +5,7 @@
 import os
 import pickle
 
+import sybil.utils.logging_utils
 from sybil import Serie, Sybil, visualize_attentions
 
 script_directory = os.path.dirname(os.path.abspath(__file__))
@@ -52,29 +53,20 @@ def _get_parser():
help="Name of the model to use for prediction. Default: sybil_ensemble",
)

parser.add_argument("-l", "--log", "--loglevel", default="INFO", dest="loglevel")
parser.add_argument("-l", "--log", "--loglevel", "--log-level",
default="INFO", dest="loglevel")

return parser


def logging_basic_config(args):
info_fmt = "[%(asctime)s] - %(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if args.loglevel.upper() == "DEBUG" else info_fmt

logging.basicConfig(
format=fmt, datefmt="%Y-%m-%d %H:%M:%S", level=args.loglevel.upper()
)


def inference(
image_dir,
output_dir,
model_name="sybil_ensemble",
return_attentions=False,
file_type="auto",
):
logger = logging.getLogger("inference")
logger = sybil.utils.logging_utils.get_logger()

input_files = os.listdir(image_dir)
input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
@@ -94,11 +86,11 @@

     num_files = len(input_files)
 
-    logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")
-
     # Load a trained model
     model = Sybil(model_name)
 
+    logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")
+
     # Get risk scores
     serie = Serie(input_files, file_type=file_type)
     series = [serie]
@@ -130,7 +122,7 @@

 def main():
     args = _get_parser().parse_args()
-    logging_basic_config(args)
+    sybil.utils.logging_utils.configure_logger(args.loglevel)
 
     os.makedirs(args.output_dir, exist_ok=True)

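Taken together, scripts/inference.py now delegates logging setup to the library instead of configuring it locally. A minimal sketch of the new flow (level and message are illustrative):

```python
import sybil.utils.logging_utils

# main() does this once, passing the parsed --log-level value:
sybil.utils.logging_utils.configure_logger("DEBUG")

# inference() then fetches a per-process logger with the same settings:
logger = sybil.utils.logging_utils.get_logger()
logger.debug("replaces the old logging_basic_config() helper")
```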
4 changes: 2 additions & 2 deletions scripts/run_inference_demo.sh
@@ -10,11 +10,11 @@ demo_scan_dir=sybil_demo_data
 if [ ! -d "$demo_scan_dir" ]; then
     # Download example data
     curl -L -o sybil_example.zip "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
-    tar -xf sybil_example.zip
+    unzip -q sybil_example.zip
 fi
 
 python3 scripts/inference.py \
     --loglevel DEBUG \
     --output-dir demo_prediction \
     --return-attentions \
-    $demo_scan_dir
\ No newline at end of file
+    $demo_scan_dir
7 changes: 5 additions & 2 deletions setup.cfg
@@ -7,7 +7,7 @@ author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.2.1
+version = 1.2.2
 # url =
 project_urls =
 ; Documentation = https://.../docs
@@ -23,6 +23,10 @@ platforms = any
 classifiers =
     Programming Language :: Python
 
+[easy_install]
+find_links =
+    https://download.pytorch.org/whl/cu117/torch_stable.html
+
 [options]
 zip_safe = False
 packages = find:
@@ -32,7 +36,6 @@ python_requires = >=3.8
 # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
 # new major versions. This works if the required packages follow Semantic Versioning.
 # For more information, check out https://semver.org/.
-# Use --find-links https://download.pytorch.org/whl/cu117/torch_stable.html for torch libraries
 install_requires =
     importlib-metadata; python_version>="3.8"
     numpy==1.24.1
1 change: 1 addition & 0 deletions sybil/__init__.py
@@ -20,5 +20,6 @@
 from sybil.model import Sybil
 from sybil.serie import Serie
 from sybil.utils.visualization import visualize_attentions
+import sybil.utils.logging_utils
 
 __all__ = ["Sybil", "Serie", "visualize_attentions"]
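Because the package __init__ now imports the submodule eagerly, a bare `import sybil` is enough to reach the logging helpers. A small sketch (message is illustrative):

```python
import sybil

# sybil.utils.logging_utils resolves without a separate submodule import.
logger = sybil.utils.logging_utils.get_logger()
logger.info("per-process logger, reachable from the top-level package")
```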
66 changes: 58 additions & 8 deletions sybil/model.py
@@ -12,6 +12,8 @@

 from sybil.serie import Serie
 from sybil.models.sybil import SybilNet
+from sybil.utils.logging_utils import get_logger
+from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
 from sybil.utils.metrics import get_survival_metrics


@@ -179,9 +181,10 @@ def __init__(
             Path to calibrator pickle file corresponding with model
         device: str
             If provided, will run inference using this device.
-            By default uses GPU, if available.
+            By default, uses GPU with the most free memory, if available.
         """
+        self._logger = get_logger()
         # Download if needed
         if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
             name_or_path, calibrator_path = download_sybil(name_or_path, cache)
@@ -197,15 +200,20 @@
         if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
             raise ValueError(f"Path not found for calibrator {calibrator_path}")
 
-        # Set device
+        # Set device.
+        # If set manually, use it and stay there.
+        # Otherwise, pick the most free GPU now and at predict time.
+        self._device_flexible = True
         if device is not None:
             self.device = device
+            self._device_flexible = False
         else:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.device = get_default_device()
 
         self.ensemble = torch.nn.ModuleList()
         for path in name_or_path:
             self.ensemble.append(self.load_model(path))
+        self.to(self.device)
 
         if calibrator_path is not None:
             self.calibrator = pickle.load(open(calibrator_path, "rb"))
@@ -235,12 +243,12 @@ def load_model(self, path):
         # Remove model from param names
         state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
         model.load_state_dict(state_dict)  # type: ignore
-        if self.device == "cuda":
-            model.to("cuda")
+        if self.device is not None:
+            model.to(self.device)
 
         # Set eval
         model.eval()
-        print(f"Loaded model from {path}")
+        self._logger.info(f"Loaded model from {path}")
         return model
 
     def _calibrate(self, scores: np.ndarray) -> np.ndarray:
@@ -305,8 +313,8 @@ def _predict(
             raise ValueError("Expected a list of Serie objects.")
 
         volume = serie.get_volume()
-        if self.device == "cuda":
-            volume = volume.cuda()
+        if self.device is not None:
+            volume = volume.to(self.device)
 
         with torch.no_grad():
             out = model(volume)
@@ -344,6 +352,12 @@ def predict(
             Output prediction. See details for :class:`~sybil.model.Prediction`".
         """
 
+        if self._device_flexible:
+            self.device = self._pick_device()
+            self.to(self.device)
+
+        self._logger.debug(f"Beginning prediction on device: {self.device}")
+
         scores = []
         attentions_ = [] if return_attentions else None
         attention_keys = None
@@ -419,3 +433,39 @@ def evaluate(
         c_index = float(out["c_index"])
 
         return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)
+
+    def to(self, device: str):
+        """Move model to device.
+
+        Parameters
+        ----------
+        device : str
+            Device to move model to.
+        """
+        self.device = device
+        self.ensemble.to(device)
+
+    def _pick_device(self):
+        """
+        Pick the device to run inference on.
+        This is based on the device with the most free memory, with a preference for remaining
+        on the current device.
+        Motivation is to enable multiprocessing without the processes needing to communicate.
+        """
+        if not torch.cuda.is_available():
+            return get_default_device()
+
+        # Get size of the model in memory (approximate)
+        model_mem = 9 * sum(p.numel() * p.element_size() for p in self.ensemble.parameters())
+
+        # Check memory available on current device.
+        # If it seems like we're the only thing on this GPU, stay.
+        free_mem, total_mem = get_device_mem_info(self.device)
+        cur_allocated = total_mem - free_mem
+        min_to_move = int(1.01 * model_mem)
+        if cur_allocated < min_to_move:
+            return self.device
+        else:
+            # Otherwise, get the most free GPU
+            return get_most_free_gpu()
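From the caller's side, the new device handling looks like this; a sketch based on the API shown above (model name is the repo default, scan paths are illustrative):

```python
from sybil import Serie, Sybil

# Pinning a device disables the flexible behavior; the model stays put.
pinned = Sybil("sybil_ensemble", device="cuda:0")

# With no device given, the model starts on the GPU with the most free
# memory and re-picks one inside each predict() call, so several worker
# processes can share a multi-GPU machine without coordinating.
model = Sybil("sybil_ensemble")
serie = Serie(["scan/slice_001.dcm", "scan/slice_002.dcm"])  # illustrative paths
prediction = model.predict([serie], return_attentions=False)
```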
72 changes: 72 additions & 0 deletions sybil/utils/device_utils.py
@@ -0,0 +1,72 @@
+import itertools
+import os
+from typing import Union
+
+import torch
+
+
+def get_default_device():
+    if torch.cuda.is_available():
+        return get_most_free_gpu()
+    elif torch.backends.mps.is_available():
+        # Not all operations implemented in MPS yet
+        use_mps = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") == "1"
+        if use_mps:
+            return torch.device('mps')
+        else:
+            return torch.device('cpu')
+    else:
+        return torch.device('cpu')
+
+
+def get_available_devices(num_devices=None, max_devices=None):
+    device = get_default_device()
+    if device.type == "cuda":
+        num_gpus = torch.cuda.device_count()
+        if max_devices is not None:
+            num_gpus = min(num_gpus, max_devices)
+        gpu_list = [get_device(i) for i in range(num_gpus)]
+        if num_devices is not None:
+            cycle_gpu_list = itertools.cycle(gpu_list)
+            gpu_list = [next(cycle_gpu_list) for _ in range(num_devices)]
+        return gpu_list
+    else:
+        num_devices = num_devices if num_devices else torch.multiprocessing.cpu_count()
+        num_devices = min(num_devices, max_devices) if max_devices is not None else num_devices
+        return [device] * num_devices
+
+
+def get_device(gpu_id: int):
+    if gpu_id is not None and torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
+        return torch.device(f'cuda:{gpu_id}')
+    else:
+        return None
+
+
+def get_device_mem_info(device: Union[int, torch.device]):
+    if not torch.cuda.is_available():
+        return None
+
+    free_mem, total_mem = torch.cuda.mem_get_info(device=device)
+    return free_mem, total_mem
+
+
+def get_most_free_gpu():
+    """
+    Get the GPU with the most free memory
+    If system has no GPUs (or CUDA not available), return None
+    """
+    if not torch.cuda.is_available():
+        return None
+
+    num_gpus = torch.cuda.device_count()
+    if num_gpus == 0:
+        return None
+
+    most_free_idx, most_free_val = -1, -1
+    for i in range(num_gpus):
+        free_mem, total_mem = get_device_mem_info(i)
+        if free_mem > most_free_val:
+            most_free_idx, most_free_val = i, free_mem
+
+    return torch.device(f'cuda:{most_free_idx}')
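A short sketch of how these helpers compose (device indices and memory figures depend on the host):

```python
import torch

from sybil.utils.device_utils import (
    get_available_devices,
    get_default_device,
    get_device_mem_info,
)

device = get_default_device()  # most-free GPU, else MPS (if enabled) or CPU
print(device)

if torch.cuda.is_available():
    free_mem, total_mem = get_device_mem_info(device)
    print(f"{free_mem / 2**30:.1f} GiB free of {total_mem / 2**30:.1f} GiB")

# One entry per worker, cycling over the GPUs if workers outnumber them.
devices = get_available_devices(num_devices=4)
```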
68 changes: 68 additions & 0 deletions sybil/utils/logging_utils.py
@@ -0,0 +1,68 @@
+import logging
+import os
+
+
+LOGGER_NAME = "sybil"
+LOGLEVEL_KEY = "LOG_LEVEL"
+
+
+def _get_formatter(loglevel="INFO"):
+    warn_fmt = "[%(asctime)s] %(levelname)s - %(message)s"
+    debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
+    fmt = debug_fmt if loglevel.upper() in {"DEBUG"} else warn_fmt
+    return logging.Formatter(
+        fmt=fmt,
+        datefmt="%Y-%b-%d %H:%M:%S %Z",
+    )
+
+
+def remove_all_handlers(logger):
+    while logger.hasHandlers():
+        logger.removeHandler(logger.handlers[0])
+
+
+def configure_logger(loglevel=None, logger_name=LOGGER_NAME, logfile=None):
+    """Do basic logger configuration and set our main logger"""
+
+    # Set as environment variable so other processes can retrieve it
+    if loglevel is None:
+        loglevel = os.environ.get(LOGLEVEL_KEY, "WARNING")
+    else:
+        os.environ[LOGLEVEL_KEY] = loglevel
+
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(loglevel)
+    remove_all_handlers(logger)
+    logger.propagate = False
+
+    formatter = _get_formatter(loglevel)
+    def _prep_handler(handler):
+        for ex_handler in logger.handlers:
+            if type(ex_handler) == type(handler):
+                # Remove old handler, don't want to double-handle
+                logger.removeHandler(ex_handler)
+        handler.setLevel(loglevel)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+
+    sh = logging.StreamHandler()
+    _prep_handler(sh)
+
+    if logfile is not None:
+        fh = logging.FileHandler(logfile, mode="a")
+        _prep_handler(fh)
+
+    return logger
+
+
+def get_logger(base_name=LOGGER_NAME):
+    """
+    Return a logger.
+    Use a different logger in each subprocess, though they should all have the same log level.
+    """
+    pid = os.getpid()
+    logger_name = f"{base_name}-process-{pid}"
+    logger = logging.getLogger(logger_name)
+    if not logger.hasHandlers():
+        configure_logger(logger_name=logger_name)
+    return logger
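A minimal sketch of the intended use: the entry point calls configure_logger() once, library code asks for per-process loggers, and child processes pick the level up again through the LOG_LEVEL environment variable:

```python
from sybil.utils.logging_utils import configure_logger, get_logger

configure_logger("DEBUG")  # also exports LOG_LEVEL=DEBUG for subprocesses

logger = get_logger()  # logger name embeds the PID, e.g. "sybil-process-1234"
logger.debug("DEBUG level switches on the file:line format")

# In a worker process, a bare get_logger() configures itself from LOG_LEVEL,
# so it inherits the verbosity chosen in the parent.
```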
