diff --git a/comet/models/base.py b/comet/models/base.py index 8fadd0b..a730704 100644 --- a/comet/models/base.py +++ b/comet/models/base.py @@ -21,6 +21,7 @@ import abc import logging import os +import sys import warnings from typing import Dict, List, Optional, Tuple, Union @@ -592,9 +593,16 @@ def predict( sort_ids = np.argsort([len(sample["ref"]) for sample in samples]) sampler = OrderedSampler(sort_ids) + # On Windows, only num_workers=0 is supported. + is_windows = os.name == "nt" if num_workers is None: # Guideline for workers that typically works well. - num_workers = 2 * gpus + num_workers = 0 if is_windows else 2 * gpus + elif is_windows and num_workers != 0: + logger.warning( + "Due to limits of multiprocessing on Windows, it is likely that setting num_workers > 0 will result" + " in scores of 0. It is therefore recommended to set num_workers=0 or leave it to None (default)." + ) self.eval() dataloader = DataLoader(