diff --git a/pytorch_toolbelt/__init__.py b/pytorch_toolbelt/__init__.py
index 586eb0a4c..820a7b0ba 100644
--- a/pytorch_toolbelt/__init__.py
+++ b/pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = "0.5.1"
+__version__ = "0.5.2"
diff --git a/pytorch_toolbelt/inference/ensembling.py b/pytorch_toolbelt/inference/ensembling.py
index 108e59f86..7422b2cd1 100644
--- a/pytorch_toolbelt/inference/ensembling.py
+++ b/pytorch_toolbelt/inference/ensembling.py
@@ -13,18 +13,22 @@ class ApplySoftmaxTo(nn.Module):
     dim: int

     def __init__(
-        self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", dim: int = 1, temperature: float = 1
+        self,
+        model: nn.Module,
+        output_key: Union[str, int, Iterable[str]] = "logits",
+        dim: int = 1,
+        temperature: float = 1,
     ):
         """
         Apply softmax activation on given output(s) of the model
         :param model: Model to wrap
-        :param output_key: string or list of strings, indicating to what outputs softmax activation should be applied.
+        :param output_key: string, index, or list of strings, indicating which outputs softmax activation should be applied to.
         :param dim: Tensor dimension for softmax activation
         :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
         """
         super().__init__()
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
-        output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
+        output_key = (output_key,) if isinstance(output_key, (str, int)) else tuple(set(output_key))
         self.output_keys = output_key
         self.model = model
         self.dim = dim
@@ -41,16 +45,16 @@ class ApplySigmoidTo(nn.Module):
     output_keys: Tuple
     temperature: float

-    def __init__(self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", temperature=1):
+    def __init__(self, model: nn.Module, output_key: Union[str, int, Iterable[str]] = "logits", temperature=1):
         """
         Apply sigmoid activation on given output(s) of the model
         :param model: Model to wrap
-        :param output_key: string or list of strings, indicating to what outputs sigmoid activation should be applied.
+        :param output_key: string, index, or list of strings, indicating which outputs sigmoid activation should be applied to.
         :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
         """
         super().__init__()
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
-        output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
+        output_key = (output_key,) if isinstance(output_key, (str, int)) else tuple(set(output_key))
         self.output_keys = output_key
         self.model = model
         self.temperature = temperature
diff --git a/pytorch_toolbelt/utils/fs.py b/pytorch_toolbelt/utils/fs.py
index f9f52d495..b40bdc331 100644
--- a/pytorch_toolbelt/utils/fs.py
+++ b/pytorch_toolbelt/utils/fs.py
@@ -14,8 +14,10 @@
     "auto_file",
     "change_extension",
     "find_images_in_dir",
+    "find_images_in_dir_recursive",
     "find_in_dir",
     "find_in_dir_glob",
+    "find_in_dir_with_ext",
     "find_subdirectories_in_dir",
     "has_ext",
     "has_image_ext",
@@ -64,7 +66,15 @@ def find_in_dir_with_ext(dirname: str, extensions: Union[str, List[str]]) -> List[str]:


 def find_images_in_dir(dirname: str) -> List[str]:
-    return [fname for fname in find_in_dir(dirname) if has_image_ext(fname)]
+    return [fname for fname in find_in_dir(dirname) if has_image_ext(fname) and os.path.isfile(fname)]
+
+
+def find_images_in_dir_recursive(dirname: str) -> List[str]:
+    return [
+        fname
+        for fname in glob.glob(os.path.join(dirname, "**"), recursive=True)
+        if has_image_ext(fname) and os.path.isfile(fname)
+    ]


 def find_in_dir_glob(dirname: str, recursive=False):
@@ -76,13 +86,17 @@ def id_from_fname(fname: str) -> str:
     return os.path.splitext(os.path.basename(fname))[0]


-def change_extension(fname: Union[str, Path], new_ext: str) -> str:
-    if type(fname) == str:
+def change_extension(fname: Union[str, Path], new_ext: str) -> Union[str, Path]:
+    if isinstance(fname, str):
         return os.path.splitext(fname)[0] + new_ext
-    else:
+    elif isinstance(fname, Path):
         if new_ext[0] != ".":
             new_ext = "." + new_ext
         return fname.with_suffix(new_ext)
+    else:
+        raise RuntimeError(
+            f"Received input argument `fname` of unsupported type {type(fname)}. Argument must be a string or Path."
+        )


 def auto_file(filename: str, where: str = ".") -> str:
diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py
index ee6dcf3f4..03380ee8d 100644
--- a/pytorch_toolbelt/utils/torch_utils.py
+++ b/pytorch_toolbelt/utils/torch_utils.py
@@ -34,6 +34,7 @@
     "to_tensor",
     "transfer_weights",
     "move_to_device_non_blocking",
+    "describe_outputs",
 ]


@@ -102,8 +103,11 @@ def count_parameters(
     parameters = {"total": total, "trainable": trainable}

     for key in keys:
-        if hasattr(model, key) and model.__getattr__(key) is not None:
-            parameters[key] = int(sum(p.numel() for p in model.__getattr__(key).parameters()))
+        try:
+            if hasattr(model, key) and model.__getattr__(key) is not None:
+                parameters[key] = int(sum(p.numel() for p in model.__getattr__(key).parameters()))
+        except AttributeError:
+            pass

     if human_friendly:
         for key in parameters.keys():
@@ -289,3 +293,27 @@ def move_to_device_non_blocking(x: Tensor, device: torch.device) -> Tensor:


 resize_as = resize_like
+
+
+def describe_outputs(outputs: Union[Tensor, Dict[str, Tensor], Iterable[Tensor]]) -> Union[List[Dict], Dict[str, Any]]:
+    """
+    Describe outputs by returning shape, mean & std for each tensor in a list or dict (supports nested structures)
+
+    Args:
+        outputs: Input (usually model outputs)
+    Returns:
+        The same structure, with each tensor replaced by a dict of its shape, mean & std
+    """
+    if torch.is_tensor(outputs):
+        desc = dict(size=tuple(outputs.size()), mean=outputs.mean().item(), std=outputs.std().item())
+    elif isinstance(outputs, collections.abc.Mapping):
+        desc = {}
+        for key, value in outputs.items():
+            desc[key] = describe_outputs(value)
+    elif isinstance(outputs, collections.abc.Iterable):
+        desc = []
+        for output in outputs:
+            desc.append(describe_outputs(output))
+    else:
+        raise NotImplementedError(f"Unsupported type {type(outputs)}")
+    return desc
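
A quick sanity check of the new `output_key` handling in `ApplySoftmaxTo` / `ApplySigmoidTo`. This is a minimal sketch, not part of the diff: it assumes the wrappers' unchanged `forward` assigns the activated tensor back via `output[key]`, so an integer key implies the wrapped model must return a mutable, indexable output such as a list (`TwoHeadModel` is a hypothetical stand-in). The new `(str, int)` check also fixes a latent bug: previously a plain string matched `isinstance(output_key, Iterable)`, so `tuple(set("logits"))` silently decomposed it into characters.

```python
import torch
from torch import nn

from pytorch_toolbelt.inference.ensembling import ApplySigmoidTo, ApplySoftmaxTo


class TwoHeadModel(nn.Module):
    """Hypothetical model returning a list: [class logits, auxiliary logits]."""

    def __init__(self):
        super().__init__()
        self.cls = nn.Linear(8, 3)
        self.aux = nn.Linear(8, 1)

    def forward(self, x):
        return [self.cls(x), self.aux(x)]


# Integer keys index into list outputs; chaining the wrappers covers both heads.
model = ApplySigmoidTo(ApplySoftmaxTo(TwoHeadModel(), output_key=0), output_key=1)
out = model(torch.randn(4, 8))

assert torch.allclose(out[0].sum(dim=1), torch.ones(4))  # softmax rows sum to 1
assert ((out[1] >= 0) & (out[1] <= 1)).all()  # sigmoid output lies in [0, 1]
```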
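
The `fs.py` changes are mechanical but easy to misuse, so here is a hedged usage snippet (paths are hypothetical). Note the asymmetry the diff keeps in `change_extension`: the `str` branch concatenates `new_ext` as-is, so callers should include the leading dot there, while the `Path` branch adds a missing dot itself.

```python
from pathlib import Path

from pytorch_toolbelt.utils.fs import change_extension, find_images_in_dir_recursive

# Walks data/train/** and keeps only regular files with an image extension.
images = find_images_in_dir_recursive("data/train")

print(change_extension("mask.png", ".jpg"))       # 'mask.jpg' (str in, str out; dot required)
print(change_extension(Path("mask.png"), "jpg"))  # Path('mask.jpg'); missing dot is added
change_extension(42, ".jpg")                      # raises RuntimeError: unsupported type
```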
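
The `count_parameters` hardening only matters for models whose attribute lookup can raise `AttributeError` (e.g. a failing property or a lazily built submodule); per-submodule counting otherwise behaves as before. A sketch under assumptions: the `keys` keyword is inferred from the hunk's `for key in keys:` loop (the full signature is not part of this diff), while the `"total"` and `"trainable"` keys come straight from the hunk.

```python
import torch
from torch import nn

from pytorch_toolbelt.utils.torch_utils import count_parameters


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)  # 8 * 16 + 16 = 144 parameters
        self.head = nn.Linear(16, 2)     # 16 * 2 + 2 = 34 parameters

    def forward(self, x):
        return self.head(self.encoder(x))


# Submodules named in `keys` get their own entries; missing/broken ones are skipped.
print(count_parameters(Net(), keys=("encoder", "head")))
# {'total': 178, 'trainable': 178, 'encoder': 144, 'head': 34}
```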
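
Finally, `describe_outputs` is easiest to understand from its output shape. This usage sketch mirrors the recursion in the new function: tensors become `{size, mean, std}` dicts, while mappings and iterables keep their structure.

```python
import torch

from pytorch_toolbelt.utils.torch_utils import describe_outputs

outputs = {
    "logits": torch.randn(4, 10),
    "aux": [torch.randn(4, 1), torch.randn(4, 1)],  # nested containers recurse
}

print(describe_outputs(outputs))
# {'logits': {'size': (4, 10), 'mean': ..., 'std': ...},
#  'aux': [{'size': (4, 1), ...}, {'size': (4, 1), ...}]}
```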