diff --git a/setup.cfg b/setup.cfg
index c76970a..fc57e6f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,10 +7,10 @@
 author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.3.0
+version = 1.4.0
 # url =
 project_urls =
-    ; Documentation = https://.../docs
+    Documentation = https://github.com/reginabarzilaygroup/sybil/wiki
     Source = https://github.com/reginabarzilaygroup/sybil
     Tracker = https://github.com/reginabarzilaygroup/sybil/issues
diff --git a/sybil/model.py b/sybil/model.py
index 74899e7..10fb857 100644
--- a/sybil/model.py
+++ b/sybil/model.py
@@ -116,6 +116,21 @@ def download_and_extract(remote_model_url: str, local_model_dir) -> List[str]:
     return all_files_and_dirs
 
 
+def _torch_set_num_threads(threads) -> int:
+    """
+    Set the number of CPU threads for torch to use.
+    Set to a negative number for no-op.
+    Set to None or 0 for the number of CPUs.
+    """
+    if threads is None or threads == 0:
+        threads = os.cpu_count()
+    if threads < 0:
+        return torch.get_num_threads()
+
+    torch.set_num_threads(threads)
+    return torch.get_num_threads()
+
+
 class Sybil:
     def __init__(
         self,
@@ -294,7 +309,7 @@ def _predict(
         return Prediction(scores=scores, attentions=attentions)
 
     def predict(
-        self, series: Union[Serie, List[Serie]], return_attentions: bool = False
+        self, series: Union[Serie, List[Serie]], return_attentions: bool = False, threads: int = 0,
     ) -> Prediction:
         """Run predictions over the given serie(s) and ensemble
 
@@ -304,6 +319,8 @@ def predict(
             One or multiple series to run predictions for.
         return_attentions : bool
             If True, returns attention scores for each serie. See README for details.
+        threads : int
+            Number of CPU threads to use for PyTorch inference. Default is 0 (use all available cores).
 
         Returns
         -------
@@ -312,6 +329,10 @@ def predict(
         """
+        # Set CPU threads available to torch
+        num_threads = _torch_set_num_threads(threads)
+        self._logger.debug(f"Using {num_threads} threads for PyTorch inference")
+
         if self._device_flexible:
             self.device = self._pick_device()
             self.to(self.device)
diff --git a/sybil/predict.py b/sybil/predict.py
index 703fb3c..abc2679 100644
--- a/sybil/predict.py
+++ b/sybil/predict.py
@@ -69,6 +69,11 @@ def _get_parser():
     parser.add_argument("-l", "--log", "--loglevel", "--log-level",
                         default="INFO", dest="loglevel")
 
+    parser.add_argument('--threads', type=int, default=0,
+                        help="Number of threads to use for PyTorch inference. "
+                             "Default is 0 (use all available cores). "
+                             "Set to a negative number to use PyTorch default.")
+
     parser.add_argument("-v", "--version", action="version", version=__version__)
 
     return parser
@@ -81,6 +86,7 @@ def predict(
     return_attentions=False,
     write_attention_images=False,
     file_type: Literal["auto", "dicom", "png"] = "auto",
+    threads: int = 0,
 ):
     logger = sybil.utils.logging_utils.get_logger()
 
@@ -115,7 +121,7 @@ def predict(
     # Get risk scores
     serie = Serie(input_files, voxel_spacing=voxel_spacing, file_type=file_type)
     series = [serie]
-    prediction = model.predict(series, return_attentions=return_attentions)
+    prediction = model.predict(series, return_attentions=return_attentions, threads=threads)
     prediction_scores = prediction.scores[0]
 
     logger.debug(f"Prediction finished. Results:\n{prediction_scores}")
@@ -155,6 +161,7 @@ def main():
         args.return_attentions,
         args.write_attention_images,
         file_type=args.file_type,
+        threads=args.threads,
     )
 
     print(json.dumps(pred_dict, indent=2))