Set num threads available to torch.
Available as a command-line argument; the default is the number of CPUs. This makes no difference on bare metal, but improves performance substantially in containers.
jsilter committed Jun 12, 2024
1 parent 80aae8f commit 2bd7fcc
Showing 3 changed files with 32 additions and 4 deletions.
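
For orientation, a minimal sketch of the new `threads` keyword from the Python API. Only the `threads` argument comes from this commit; the no-argument `Sybil()` construction and the file names are assumptions for illustration:

```python
from sybil import Serie, Sybil

model = Sybil()  # assumed default construction; adjust to your checkpoint

# Placeholder DICOM paths for a single serie.
serie = Serie(["scan_0001.dcm", "scan_0002.dcm"])

# New in this commit: threads=0 (the default) uses all available cores,
# a positive value caps torch at that many CPU threads, and a negative
# value leaves the PyTorch default untouched.
prediction = model.predict([serie], threads=4)
print(prediction.scores)
```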
4 changes: 2 additions & 2 deletions setup.cfg
```diff
@@ -7,10 +7,10 @@ author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.3.0
+version = 1.4.0
 # url =
 project_urls =
 ; Documentation = https://.../docs
 Documentation = https://github.com/reginabarzilaygroup/sybil/wiki
 Source = https://github.com/reginabarzilaygroup/sybil
 Tracker = https://github.com/reginabarzilaygroup/sybil/issues
```
23 changes: 22 additions & 1 deletion sybil/model.py
```diff
@@ -116,6 +116,21 @@ def download_and_extract(remote_model_url: str, local_model_dir) -> List[str]:
     return all_files_and_dirs


+def _torch_set_num_threads(threads) -> int:
+    """
+    Set the number of CPU threads for torch to use.
+    Set to a negative number for no-op.
+    Set to 0 (or None) for the number of CPUs.
+    """
+    if threads is not None and threads < 0:
+        return torch.get_num_threads()
+    if threads is None or threads == 0:
+        threads = os.cpu_count()
+
+    torch.set_num_threads(threads)
+    return torch.get_num_threads()
+
+
 class Sybil:
     def __init__(
         self,
```
```diff
@@ -294,7 +309,7 @@ def _predict(
         return Prediction(scores=scores, attentions=attentions)

     def predict(
-        self, series: Union[Serie, List[Serie]], return_attentions: bool = False
+        self, series: Union[Serie, List[Serie]], return_attentions: bool = False, threads: int = 0,
     ) -> Prediction:
         """Run predictions over the given serie(s) and ensemble
```
```diff
@@ -304,6 +319,8 @@ def predict(
             One or multiple series to run predictions for.
         return_attentions : bool
             If True, returns attention scores for each serie. See README for details.
+        threads : int
+            Number of CPU threads to use for PyTorch inference. Default is 0 (use all available cores).

         Returns
         -------
```
```diff
@@ -312,6 +329,10 @@ def predict(
         """

+        # Set CPU threads available to torch
+        num_threads = _torch_set_num_threads(threads)
+        self._logger.debug(f"Using {num_threads} threads for PyTorch inference")
+
         if self._device_flexible:
            self.device = self._pick_device()
            self.to(self.device)
```
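
A quick sketch of the helper's three cases, matching the docstring above. `_torch_set_num_threads` is module-private, so importing it here is purely illustrative:

```python
import os

import torch
from sybil.model import _torch_set_num_threads

assert _torch_set_num_threads(-1) == torch.get_num_threads()  # negative: no-op
assert _torch_set_num_threads(0) == os.cpu_count()            # 0 (or None): one thread per CPU
assert _torch_set_num_threads(2) == 2                         # positive: explicit cap
```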
9 changes: 8 additions & 1 deletion sybil/predict.py
```diff
@@ -69,6 +69,11 @@ def _get_parser():
     parser.add_argument("-l", "--log", "--loglevel", "--log-level",
                         default="INFO", dest="loglevel")

+    parser.add_argument("--threads", type=int, default=0,
+                        help="Number of threads to use for PyTorch inference. "
+                             "Default is 0 (use all available cores). "
+                             "Set to a negative number to use the PyTorch default.")
+
     parser.add_argument("-v", "--version", action="version", version=__version__)

     return parser
```
```diff
@@ -81,6 +86,7 @@ def predict(
     return_attentions=False,
     write_attention_images=False,
     file_type: Literal["auto", "dicom", "png"] = "auto",
+    threads: int = 0,
 ):
     logger = sybil.utils.logging_utils.get_logger()
```
```diff
@@ -115,7 +121,7 @@ def predict(
     # Get risk scores
     serie = Serie(input_files, voxel_spacing=voxel_spacing, file_type=file_type)
     series = [serie]
-    prediction = model.predict(series, return_attentions=return_attentions)
+    prediction = model.predict(series, return_attentions=return_attentions, threads=threads)
     prediction_scores = prediction.scores[0]

     logger.debug(f"Prediction finished. Results:\n{prediction_scores}")
```
```diff
@@ -155,6 +161,7 @@ def main():
         args.return_attentions,
         args.write_attention_images,
         file_type=args.file_type,
+        threads=args.threads,
     )

     print(json.dumps(pred_dict, indent=2))
```
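The container benefit comes from the fact that `os.cpu_count()` and PyTorch's default thread pool typically reflect the host's CPU count rather than a container's CPU quota, so an explicit cap avoids thread oversubscription. From the command line, the new flag would be passed as sketched below; the `sybil-predict` entry-point name and the exam path are assumptions, so substitute whatever console script the package actually registers for `sybil.predict:main`:

```python
import subprocess

# Hypothetical invocation: cap inference at 4 CPU threads inside a
# container limited to 4 cores. "sybil-predict" and the exam path are
# placeholders, not taken from this diff.
subprocess.run(
    ["sybil-predict", "exams/case_001/", "--threads", "4"],
    check=True,
)
```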
