diff --git a/scripts/inference.py b/scripts/inference.py
index 1ea2258..422fe38 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -5,6 +5,7 @@
 import os
 import pickle
 
+import sybil.utils.logging_utils
 from sybil import Serie, Sybil, visualize_attentions
 
 script_directory = os.path.dirname(os.path.abspath(__file__))
@@ -52,21 +53,12 @@ def _get_parser():
         help="Name of the model to use for prediction. Default: sybil_ensemble",
     )
 
-    parser.add_argument("-l", "--log", "--loglevel", default="INFO", dest="loglevel")
+    parser.add_argument("-l", "--log", "--loglevel", "--log-level",
+                        default="INFO", dest="loglevel")
 
     return parser
 
 
-def logging_basic_config(args):
-    info_fmt = "[%(asctime)s] - %(message)s"
-    debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
-    fmt = debug_fmt if args.loglevel.upper() == "DEBUG" else info_fmt
-
-    logging.basicConfig(
-        format=fmt, datefmt="%Y-%m-%d %H:%M:%S", level=args.loglevel.upper()
-    )
-
-
 def inference(
     image_dir,
     output_dir,
@@ -74,7 +66,7 @@ def inference(
     return_attentions=False,
    file_type="auto",
 ):
-    logger = logging.getLogger("inference")
+    logger = sybil.utils.logging_utils.get_logger()
 
     input_files = os.listdir(image_dir)
     input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
@@ -94,11 +86,11 @@ def inference(
 
     num_files = len(input_files)
 
+    logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")
+
     # Load a trained model
     model = Sybil(model_name)
 
-    logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")
-
     # Get risk scores
     serie = Serie(input_files, file_type=file_type)
     series = [serie]
@@ -130,7 +122,7 @@ def inference(
 
 def main():
     args = _get_parser().parse_args()
 
-    logging_basic_config(args)
+    sybil.utils.logging_utils.configure_logger(args.loglevel)
 
     os.makedirs(args.output_dir, exist_ok=True)
diff --git a/scripts/run_inference_demo.sh b/scripts/run_inference_demo.sh
index 3051239..6f1088d 100755
--- a/scripts/run_inference_demo.sh
+++ b/scripts/run_inference_demo.sh
@@ -10,11 +10,11 @@
 demo_scan_dir=sybil_demo_data
 if [ ! -d "$demo_scan_dir" ]; then
   # Download example data
   curl -L -o sybil_example.zip "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
-  tar -xf sybil_example.zip
+  unzip -q sybil_example.zip
 fi
 python3 scripts/inference.py \
     --loglevel DEBUG \
     --output-dir demo_prediction \
     --return-attentions \
-$demo_scan_dir
\ No newline at end of file
+$demo_scan_dir
diff --git a/setup.cfg b/setup.cfg
index 21b5319..fd8bb88 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,7 +7,7 @@
 author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.2.1
+version = 1.2.2
 # url =
 project_urls =
 ;    Documentation = https://.../docs
@@ -23,6 +23,10 @@
 platforms = any
 classifiers =
     Programming Language :: Python
 
+[easy_install]
+find_links =
+    https://download.pytorch.org/whl/cu117/torch_stable.html
+
 [options]
 zip_safe = False
 packages = find:
@@ -32,7 +36,6 @@
 python_requires = >=3.8
 # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
 # new major versions. This works if the required packages follow Semantic Versioning.
 # For more information, check out https://semver.org/.
-# Use --find-links https://download.pytorch.org/whl/cu117/torch_stable.html for torch libraries
 install_requires =
     importlib-metadata; python_version>="3.8"
     numpy==1.24.1
diff --git a/sybil/__init__.py b/sybil/__init__.py
index bc7b412..a37f041 100644
--- a/sybil/__init__.py
+++ b/sybil/__init__.py
@@ -20,5 +20,6 @@
 from sybil.model import Sybil
 from sybil.serie import Serie
 from sybil.utils.visualization import visualize_attentions
+import sybil.utils.logging_utils
 
 __all__ = ["Sybil", "Serie", "visualize_attentions"]
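Note: the script-local logging_basic_config() is gone; main() now delegates to the shared helpers added in sybil/utils/logging_utils.py at the bottom of this diff. A minimal sketch of the replacement flow (the message text is illustrative):

    import sybil.utils.logging_utils

    # What main() now does with the parsed --loglevel value; the level is
    # also exported via the LOG_LEVEL environment variable for subprocesses.
    sybil.utils.logging_utils.configure_logger("DEBUG")

    # What inference() now does instead of logging.getLogger("inference"):
    logger = sybil.utils.logging_utils.get_logger()
    logger.debug("Beginning prediction ...")
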
diff --git a/sybil/model.py b/sybil/model.py
index 93af540..cda2e9c 100644
--- a/sybil/model.py
+++ b/sybil/model.py
@@ -12,6 +12,8 @@
 from sybil.serie import Serie
 from sybil.models.sybil import SybilNet
 
+from sybil.utils.logging_utils import get_logger
+from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
 from sybil.utils.metrics import get_survival_metrics
 
 
@@ -179,9 +181,10 @@ def __init__(
             Path to calibrator pickle file corresponding with model
         device: str
             If provided, will run inference using this device.
-            By default uses GPU, if available.
+            By default, uses GPU with the most free memory, if available.
         """
+        self._logger = get_logger()
         # Download if needed
         if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
             name_or_path, calibrator_path = download_sybil(name_or_path, cache)
 
@@ -197,15 +200,20 @@ def __init__(
         if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
             raise ValueError(f"Path not found for calibrator {calibrator_path}")
 
-        # Set device
+        # Set device.
+        # If set manually, use it and stay there.
+        # Otherwise, pick the most free GPU now and at predict time.
+        self._device_flexible = True
         if device is not None:
             self.device = device
+            self._device_flexible = False
         else:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.device = get_default_device()
 
         self.ensemble = torch.nn.ModuleList()
         for path in name_or_path:
             self.ensemble.append(self.load_model(path))
+        self.to(self.device)
 
         if calibrator_path is not None:
             self.calibrator = pickle.load(open(calibrator_path, "rb"))
@@ -235,12 +243,12 @@ def load_model(self, path):
         # Remove model from param names
         state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
         model.load_state_dict(state_dict)  # type: ignore
-        if self.device == "cuda":
-            model.to("cuda")
+        if self.device is not None:
+            model.to(self.device)
 
         # Set eval
         model.eval()
-        print(f"Loaded model from {path}")
+        self._logger.info(f"Loaded model from {path}")
         return model
 
     def _calibrate(self, scores: np.ndarray) -> np.ndarray:
@@ -305,8 +313,8 @@ def _predict(
             raise ValueError("Expected a list of Serie objects.")
 
         volume = serie.get_volume()
-        if self.device == "cuda":
-            volume = volume.cuda()
+        if self.device is not None:
+            volume = volume.to(self.device)
 
         with torch.no_grad():
             out = model(volume)
@@ -344,6 +352,12 @@ def predict(
             Output prediction. See details for :class:`~sybil.model.Prediction`".
         """
+
+        if self._device_flexible:
+            self.device = self._pick_device()
+            self.to(self.device)
+        self._logger.debug(f"Beginning prediction on device: {self.device}")
+
         scores = []
         attentions_ = [] if return_attentions else None
         attention_keys = None
@@ -419,3 +433,39 @@ def evaluate(
         c_index = float(out["c_index"])
 
         return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)
+
+    def to(self, device: str):
+        """Move model to device.
+
+        Parameters
+        ----------
+        device : str
+            Device to move model to.
+        """
+        self.device = device
+        self.ensemble.to(device)
+
+    def _pick_device(self):
+        """
+        Pick the device to run inference on.
+        This is based on the device with the most free memory, with a preference for remaining
+        on the current device.
+
+        The motivation is to enable multiprocessing without the processes needing to communicate.
+        """
+        if not torch.cuda.is_available():
+            return get_default_device()
+
+        # Get size of the model in memory (approximate)
+        model_mem = 9 * sum(p.numel() * p.element_size() for p in self.ensemble.parameters())
+
+        # Check memory available on the current device.
+        # If it seems like we're the only thing on this GPU, stay.
+        free_mem, total_mem = get_device_mem_info(self.device)
+        cur_allocated = total_mem - free_mem
+        min_to_move = int(1.01 * model_mem)
+        if cur_allocated < min_to_move:
+            return self.device
+        else:
+            # Otherwise, get the most free GPU
+            return get_most_free_gpu()
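For review context, a sketch of how the new device handling behaves from the caller's side ("sybil_ensemble" is the default model name used elsewhere in this diff; weights are downloaded on first use):

    from sybil import Sybil

    # Explicit device: pinned; predict() never migrates the model.
    pinned = Sybil("sybil_ensemble", device="cuda:0")

    # No device argument: the model starts on the GPU with the most free
    # memory (or MPS/CPU), and _pick_device() may move it again at each
    # predict() call, e.g. when several worker processes share one machine.
    flexible = Sybil("sybil_ensemble")
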
+ """ + self.device = device + self.ensemble.to(device) + + def _pick_device(self): + """ + Pick the device to run inference on. + This is based on the device with the most free memory, with a preference for remaining + on the current device. + + Motivation is to enable multiprocessing without the processes needed to communicate. + """ + if not torch.cuda.is_available(): + return get_default_device() + + # Get size of the model in memory (approximate) + model_mem = 9*sum(p.numel() * p.element_size() for p in self.ensemble.parameters()) + + # Check memory available on current device. + # If it seems like we're the only thing on this GPU, stay. + free_mem, total_mem = get_device_mem_info(self.device) + cur_allocated = total_mem - free_mem + min_to_move = int(1.01 * model_mem) + if cur_allocated < min_to_move: + return self.device + else: + # Otherwise, get the most free GPU + return get_most_free_gpu() diff --git a/sybil/utils/device_utils.py b/sybil/utils/device_utils.py new file mode 100644 index 0000000..8184b4f --- /dev/null +++ b/sybil/utils/device_utils.py @@ -0,0 +1,72 @@ +import itertools +import os +from typing import Union + +import torch + + +def get_default_device(): + if torch.cuda.is_available(): + return get_most_free_gpu() + elif torch.backends.mps.is_available(): + # Not all operations implemented in MPS yet + use_mps = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") == "1" + if use_mps: + return torch.device('mps') + else: + return torch.device('cpu') + else: + return torch.device('cpu') + + +def get_available_devices(num_devices=None, max_devices=None): + device = get_default_device() + if device.type == "cuda": + num_gpus = torch.cuda.device_count() + if max_devices is not None: + num_gpus = min(num_gpus, max_devices) + gpu_list = [get_device(i) for i in range(num_gpus)] + if num_devices is not None: + cycle_gpu_list = itertools.cycle(gpu_list) + gpu_list = [next(cycle_gpu_list) for _ in range(num_devices)] + return gpu_list + else: + num_devices = num_devices if num_devices else torch.multiprocessing.cpu_count() + num_devices = min(num_devices, max_devices) if max_devices is not None else num_devices + return [device]*num_devices + + +def get_device(gpu_id: int): + if gpu_id is not None and torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): + return torch.device(f'cuda:{gpu_id}') + else: + return None + + +def get_device_mem_info(device: Union[int, torch.device]): + if not torch.cuda.is_available(): + return None + + free_mem, total_mem = torch.cuda.mem_get_info(device=device) + return free_mem, total_mem + + +def get_most_free_gpu(): + """ + Get the GPU with the most free memory + If system has no GPUs (or CUDA not available), return None + """ + if not torch.cuda.is_available(): + return None + + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + return None + + most_free_idx, most_free_val = -1, -1 + for i in range(num_gpus): + free_mem, total_mem = get_device_mem_info(i) + if free_mem > most_free_val: + most_free_idx, most_free_val = i, free_mem + + return torch.device(f'cuda:{most_free_idx}') diff --git a/sybil/utils/logging_utils.py b/sybil/utils/logging_utils.py new file mode 100644 index 0000000..40cf353 --- /dev/null +++ b/sybil/utils/logging_utils.py @@ -0,0 +1,68 @@ +import logging +import os + + +LOGGER_NAME = "sybil" +LOGLEVEL_KEY = "LOG_LEVEL" + + +def _get_formatter(loglevel="INFO"): + warn_fmt = "[%(asctime)s] %(levelname)s -%(message)s" + debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s" + fmt = 
diff --git a/sybil/utils/logging_utils.py b/sybil/utils/logging_utils.py
new file mode 100644
index 0000000..40cf353
--- /dev/null
+++ b/sybil/utils/logging_utils.py
@@ -0,0 +1,68 @@
+import logging
+import os
+
+
+LOGGER_NAME = "sybil"
+LOGLEVEL_KEY = "LOG_LEVEL"
+
+
+def _get_formatter(loglevel="INFO"):
+    warn_fmt = "[%(asctime)s] %(levelname)s - %(message)s"
+    debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
+    fmt = debug_fmt if loglevel.upper() == "DEBUG" else warn_fmt
+    return logging.Formatter(
+        fmt=fmt,
+        datefmt="%Y-%b-%d %H:%M:%S %Z",
+    )
+
+
+def remove_all_handlers(logger):
+    while logger.handlers:
+        logger.removeHandler(logger.handlers[0])
+
+
+def configure_logger(loglevel=None, logger_name=LOGGER_NAME, logfile=None):
+    """Do basic logger configuration and set our main logger."""
+
+    # Set as an environment variable so other processes can retrieve it
+    if loglevel is None:
+        loglevel = os.environ.get(LOGLEVEL_KEY, "WARNING")
+    else:
+        os.environ[LOGLEVEL_KEY] = loglevel
+
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(loglevel)
+    remove_all_handlers(logger)
+    logger.propagate = False
+
+    formatter = _get_formatter(loglevel)
+    def _prep_handler(handler):
+        for ex_handler in logger.handlers:
+            if type(ex_handler) == type(handler):
+                # Remove the old handler; we don't want to double-handle records
+                logger.removeHandler(ex_handler)
+        handler.setLevel(loglevel)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+
+    sh = logging.StreamHandler()
+    _prep_handler(sh)
+
+    if logfile is not None:
+        fh = logging.FileHandler(logfile, mode="a")
+        _prep_handler(fh)
+
+    return logger
+
+
+def get_logger(base_name=LOGGER_NAME):
+    """
+    Return a logger.
+    Use a different logger in each subprocess, though they should all have the same log level.
+    """
+    pid = os.getpid()
+    logger_name = f"{base_name}-process-{pid}"
+    logger = logging.getLogger(logger_name)
+    if not logger.hasHandlers():
+        configure_logger(logger_name=logger_name)
+    return logger
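Finally, the logging helpers end-to-end (the logfile path is a placeholder):

    from sybil.utils.logging_utils import configure_logger, get_logger

    # Configure once; the level is also stored in the LOG_LEVEL env var.
    configure_logger("DEBUG", logfile="sybil.log")

    # Per-process loggers named "sybil-process-<pid>" read the level back
    # from the environment, so subprocesses log consistently without wiring.
    logger = get_logger()
    logger.debug("goes to stderr with the DEBUG format (file and line info)")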