V1.2.2 dev #38

Merged: 3 commits, May 23, 2024
22 changes: 7 additions & 15 deletions scripts/inference.py
@@ -5,6 +5,7 @@
import os
import pickle

import sybil.utils.logging_utils
from sybil import Serie, Sybil, visualize_attentions

script_directory = os.path.dirname(os.path.abspath(__file__))
@@ -52,29 +53,20 @@ def _get_parser():
help="Name of the model to use for prediction. Default: sybil_ensemble",
)

parser.add_argument("-l", "--log", "--loglevel", default="INFO", dest="loglevel")
parser.add_argument("-l", "--log", "--loglevel", "--log-level",
default="INFO", dest="loglevel")

return parser


def logging_basic_config(args):
info_fmt = "[%(asctime)s] - %(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if args.loglevel.upper() == "DEBUG" else info_fmt

logging.basicConfig(
format=fmt, datefmt="%Y-%m-%d %H:%M:%S", level=args.loglevel.upper()
)


def inference(
image_dir,
output_dir,
model_name="sybil_ensemble",
return_attentions=False,
file_type="auto",
):
logger = logging.getLogger("inference")
logger = sybil.utils.logging_utils.get_logger()

input_files = os.listdir(image_dir)
input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
@@ -94,11 +86,11 @@ def inference(

num_files = len(input_files)

logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")

# Load a trained model
model = Sybil(model_name)

logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")

# Get risk scores
serie = Serie(input_files, file_type=file_type)
series = [serie]
@@ -130,7 +122,7 @@

def main():
args = _get_parser().parse_args()
logging_basic_config(args)
sybil.utils.logging_utils.configure_logger(args.loglevel)

os.makedirs(args.output_dir, exist_ok=True)

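Net effect of the inference.py changes: the script-local logging setup is gone and the shared sybil.utils.logging_utils module does the work. A minimal sketch of the new wiring, reduced to the logging pieces (the rest of the script body is elided):

```python
import argparse

import sybil.utils.logging_utils


def main():
    parser = argparse.ArgumentParser()
    # The new alias "--log-level" sits alongside the existing flags
    parser.add_argument("-l", "--log", "--loglevel", "--log-level",
                        default="INFO", dest="loglevel")
    args = parser.parse_args()

    # One configure_logger() call replaces the old local logging_basic_config()
    sybil.utils.logging_utils.configure_logger(args.loglevel)

    # inference() now fetches a per-process logger from the shared module
    logger = sybil.utils.logging_utils.get_logger()
    logger.debug("logging configured at level %s", args.loglevel)


if __name__ == "__main__":
    main()
```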
4 changes: 2 additions & 2 deletions scripts/run_inference_demo.sh
@@ -10,11 +10,11 @@ demo_scan_dir=sybil_demo_data
if [ ! -d "$demo_scan_dir" ]; then
# Download example data
curl -L -o sybil_example.zip "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
tar -xf sybil_example.zip
unzip -q sybil_example.zip
fi

python3 scripts/inference.py \
--loglevel DEBUG \
--output-dir demo_prediction \
--return-attentions \
$demo_scan_dir
$demo_scan_dir
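The demo can also be driven from Python by calling the script's inference() function directly. A hedged sketch, assuming it is run from the repository root with the demo data already downloaded (the sys.path insertion is only needed because scripts/ is not a package):

```python
import sys

sys.path.insert(0, "scripts")  # assumption: running from the repository root

import sybil.utils.logging_utils
from inference import inference

sybil.utils.logging_utils.configure_logger("DEBUG")

# Mirrors run_inference_demo.sh: demo scans in, predictions (with attentions) out
inference(
    image_dir="sybil_demo_data",
    output_dir="demo_prediction",
    model_name="sybil_ensemble",
    return_attentions=True,
)
```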
7 changes: 5 additions & 2 deletions setup.cfg
@@ -7,7 +7,7 @@ author_email =
license_file = LICENSE.txt
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
version = 1.2.1
version = 1.2.2
# url =
project_urls =
; Documentation = https://.../docs
@@ -23,6 +23,10 @@ platforms = any
classifiers =
Programming Language :: Python

[easy_install]
find_links =
https://download.pytorch.org/whl/cu117/torch_stable.html

[options]
zip_safe = False
packages = find:
@@ -32,7 +36,6 @@ python_requires = >=3.8
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
# Use --find-links https://download.pytorch.org/whl/cu117/torch_stable.html for torch libraries
install_requires =
importlib-metadata; python_version>="3.8"
numpy==1.24.1
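For context: the removed install_requires comment told users to pass --find-links https://download.pytorch.org/whl/cu117/torch_stable.html to pip by hand, and the new [easy_install] section records that same index in setup.cfg. Note that [easy_install] settings only affect setuptools' own installer, so pip users would presumably still pass --find-links on the command line.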
1 change: 1 addition & 0 deletions sybil/__init__.py
@@ -20,5 +20,6 @@
from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions"]
66 changes: 58 additions & 8 deletions sybil/model.py
@@ -12,6 +12,8 @@

from sybil.serie import Serie
from sybil.models.sybil import SybilNet
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
from sybil.utils.metrics import get_survival_metrics


@@ -179,9 +181,10 @@ def __init__(
Path to calibrator pickle file corresponding with model
device: str
If provided, will run inference using this device.
By default uses GPU, if available.
By default, uses GPU with the most free memory, if available.

"""
self._logger = get_logger()
# Download if needed
if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
name_or_path, calibrator_path = download_sybil(name_or_path, cache)
@@ -197,15 +200,20 @@
if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
raise ValueError(f"Path not found for calibrator {calibrator_path}")

# Set device
# Set device.
# If set manually, use it and stay there.
# Otherwise, pick the most free GPU now and at predict time.
self._device_flexible = True
if device is not None:
self.device = device
self._device_flexible = False
else:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = get_default_device()

self.ensemble = torch.nn.ModuleList()
for path in name_or_path:
self.ensemble.append(self.load_model(path))
self.to(self.device)

if calibrator_path is not None:
self.calibrator = pickle.load(open(calibrator_path, "rb"))
@@ -235,12 +243,12 @@ def load_model(self, path):
# Remove model from param names
state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
model.load_state_dict(state_dict) # type: ignore
if self.device == "cuda":
model.to("cuda")
if self.device is not None:
model.to(self.device)

# Set eval
model.eval()
print(f"Loaded model from {path}")
self._logger.info(f"Loaded model from {path}")
return model

def _calibrate(self, scores: np.ndarray) -> np.ndarray:
@@ -305,8 +313,8 @@ def _predict(
raise ValueError("Expected a list of Serie objects.")

volume = serie.get_volume()
if self.device == "cuda":
volume = volume.cuda()
if self.device is not None:
volume = volume.to(self.device)

with torch.no_grad():
out = model(volume)
@@ -344,6 +352,12 @@ def predict(
Output prediction. See details for :class:`~sybil.model.Prediction`".

"""

if self._device_flexible:
self.device = self._pick_device()
self.to(self.device)
self._logger.debug(f"Beginning prediction on device: {self.device}")

scores = []
attentions_ = [] if return_attentions else None
attention_keys = None
@@ -419,3 +433,39 @@ def evaluate(
c_index = float(out["c_index"])

return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)

def to(self, device: str):
"""Move model to device.

Parameters
----------
device : str
Device to move model to.
"""
self.device = device
self.ensemble.to(device)

def _pick_device(self):
"""
Pick the device to run inference on.
This is based on the device with the most free memory, with a preference for remaining
on the current device.

        The motivation is to enable multiprocessing without the processes needing to communicate.
"""
if not torch.cuda.is_available():
return get_default_device()

# Get size of the model in memory (approximate)
model_mem = 9*sum(p.numel() * p.element_size() for p in self.ensemble.parameters())

# Check memory available on current device.
# If it seems like we're the only thing on this GPU, stay.
free_mem, total_mem = get_device_mem_info(self.device)
cur_allocated = total_mem - free_mem
min_to_move = int(1.01 * model_mem)
if cur_allocated < min_to_move:
return self.device
else:
# Otherwise, get the most free GPU
return get_most_free_gpu()
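In short: a Sybil constructed without an explicit device now re-picks the most-free GPU at every predict() call (staying put when it appears to be alone on its current GPU), while an explicitly pinned device is never second-guessed. A usage sketch; the scan path is a placeholder, and the .scores field is assumed from the Prediction docstring referenced above:

```python
from sybil import Serie, Sybil

# Flexible placement: picks the GPU with the most free memory now,
# and again at each predict() call (falls back to CPU without CUDA).
model = Sybil("sybil_ensemble")

# Pinned placement: stays on cuda:1 and is never moved.
pinned = Sybil("sybil_ensemble", device="cuda:1")

serie = Serie(["/path/to/scan/"])  # placeholder input
prediction = model.predict([serie])
print(prediction.scores)  # assumption: Prediction exposes per-serie scores
```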
72 changes: 72 additions & 0 deletions sybil/utils/device_utils.py
@@ -0,0 +1,72 @@
import itertools
import os
from typing import Union

import torch


def get_default_device():
if torch.cuda.is_available():
return get_most_free_gpu()
elif torch.backends.mps.is_available():
# Not all operations implemented in MPS yet
use_mps = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") == "1"
if use_mps:
return torch.device('mps')
else:
return torch.device('cpu')
else:
return torch.device('cpu')


def get_available_devices(num_devices=None, max_devices=None):
device = get_default_device()
if device.type == "cuda":
num_gpus = torch.cuda.device_count()
if max_devices is not None:
num_gpus = min(num_gpus, max_devices)
gpu_list = [get_device(i) for i in range(num_gpus)]
if num_devices is not None:
cycle_gpu_list = itertools.cycle(gpu_list)
gpu_list = [next(cycle_gpu_list) for _ in range(num_devices)]
return gpu_list
else:
num_devices = num_devices if num_devices else torch.multiprocessing.cpu_count()
num_devices = min(num_devices, max_devices) if max_devices is not None else num_devices
return [device]*num_devices


def get_device(gpu_id: int):
if gpu_id is not None and torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
return torch.device(f'cuda:{gpu_id}')
else:
return None


def get_device_mem_info(device: Union[int, torch.device]):
if not torch.cuda.is_available():
return None

free_mem, total_mem = torch.cuda.mem_get_info(device=device)
return free_mem, total_mem


def get_most_free_gpu():
"""
    Get the GPU with the most free memory.
    If the system has no GPUs (or CUDA is not available), return None.
"""
if not torch.cuda.is_available():
return None

num_gpus = torch.cuda.device_count()
if num_gpus == 0:
return None

most_free_idx, most_free_val = -1, -1
for i in range(num_gpus):
free_mem, total_mem = get_device_mem_info(i)
if free_mem > most_free_val:
most_free_idx, most_free_val = i, free_mem

return torch.device(f'cuda:{most_free_idx}')
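These helpers are what predict() leans on, and they also compose for multi-process fan-out. A sketch under stated assumptions (the worker function is hypothetical; get_available_devices() cycles GPUs when tasks outnumber them and returns repeated CPU devices on a CUDA-less machine):

```python
import torch.multiprocessing as mp

from sybil.utils.device_utils import get_available_devices


def worker(device, task_id):
    # Hypothetical worker: would move a model to `device` and run inference
    print(f"task {task_id} assigned to {device}")


if __name__ == "__main__":
    # Four slots: available GPUs are cycled across them
    devices = get_available_devices(num_devices=4)
    with mp.Pool(processes=len(devices)) as pool:
        pool.starmap(worker, [(d, i) for i, d in enumerate(devices)])
```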
68 changes: 68 additions & 0 deletions sybil/utils/logging_utils.py
@@ -0,0 +1,68 @@
import logging
import os


LOGGER_NAME = "sybil"
LOGLEVEL_KEY = "LOG_LEVEL"


def _get_formatter(loglevel="INFO"):
    warn_fmt = "[%(asctime)s] %(levelname)s - %(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if loglevel.upper() in {"DEBUG"} else warn_fmt
return logging.Formatter(
fmt=fmt,
datefmt="%Y-%b-%d %H:%M:%S %Z",
)


def remove_all_handlers(logger):
    # Use this logger's own handler list; hasHandlers() also consults ancestor
    # loggers, so it can report True while logger.handlers is already empty.
    while logger.handlers:
        logger.removeHandler(logger.handlers[0])


def configure_logger(loglevel=None, logger_name=LOGGER_NAME, logfile=None):
"""Do basic logger configuration and set our main logger"""

# Set as environment variable so other processes can retrieve it
if loglevel is None:
loglevel = os.environ.get(LOGLEVEL_KEY, "WARNING")
else:
os.environ[LOGLEVEL_KEY] = loglevel

logger = logging.getLogger(logger_name)
logger.setLevel(loglevel)
remove_all_handlers(logger)
logger.propagate = False

formatter = _get_formatter(loglevel)
def _prep_handler(handler):
for ex_handler in logger.handlers:
if type(ex_handler) == type(handler):
# Remove old handler, don't want to double-handle
logger.removeHandler(ex_handler)
handler.setLevel(loglevel)
handler.setFormatter(formatter)
logger.addHandler(handler)

sh = logging.StreamHandler()
_prep_handler(sh)

if logfile is not None:
fh = logging.FileHandler(logfile, mode="a")
_prep_handler(fh)

return logger


def get_logger(base_name=LOGGER_NAME):
"""
Return a logger.
Use a different logger in each subprocess, though they should all have the same log level.
"""
pid = os.getpid()
logger_name = f"{base_name}-process-{pid}"
logger = logging.getLogger(logger_name)
if not logger.hasHandlers():
configure_logger(logger_name=logger_name)
return logger
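Because configure_logger() exports the chosen level through the LOG_LEVEL environment variable, workers spawned later can self-configure to match the parent. A short sketch of that pattern (the worker body is illustrative):

```python
import multiprocessing as mp

from sybil.utils.logging_utils import configure_logger, get_logger


def worker(n):
    # get_logger() names the logger "sybil-process-<pid>" and, on first use,
    # configures it from the LOG_LEVEL value exported by the parent.
    log = get_logger()
    log.info(f"worker {n} starting")


if __name__ == "__main__":
    configure_logger("DEBUG")  # also sets LOG_LEVEL=DEBUG for child processes
    with mp.Pool(2) as pool:
        pool.map(worker, range(4))
```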