Skip to content

Commit

Permalink
Merge pull request #76 from BloodAxe/develop
Browse files Browse the repository at this point in the history
0.5.2 Release
  • Loading branch information
BloodAxe authored Aug 26, 2022
2 parents ee72463 + f1114e8 commit 46f189f
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = "0.5.1"
__version__ = "0.5.2"
16 changes: 10 additions & 6 deletions pytorch_toolbelt/inference/ensembling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@ class ApplySoftmaxTo(nn.Module):
dim: int

def __init__(
self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", dim: int = 1, temperature: float = 1
self,
model: nn.Module,
output_key: Union[str, int, Iterable[str]] = "logits",
dim: int = 1,
temperature: float = 1,
):
"""
Apply softmax activation on given output(s) of the model
:param model: Model to wrap
:param output_key: string or list of strings, indicating to what outputs softmax activation should be applied.
:param output_key: string, index or list of strings, indicating to what outputs softmax activation should be applied.
:param dim: Tensor dimension for softmax activation
:param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
"""
super().__init__()
# By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
output_key = (output_key,) if isinstance(output_key, (str, int)) else tuple(set(output_key))
self.output_keys = output_key
self.model = model
self.dim = dim
Expand All @@ -41,16 +45,16 @@ class ApplySigmoidTo(nn.Module):
output_keys: Tuple
temperature: float

def __init__(self, model: nn.Module, output_key: Union[str, int, Iterable[str]] = "logits", temperature=1):
    """
    Apply sigmoid activation on given output(s) of the model.

    :param model: Model to wrap
    :param output_key: string, index, or list of strings, indicating to what outputs sigmoid activation should be applied.
    :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
    """
    super().__init__()
    if isinstance(output_key, (str, int)):
        # A single key (name or positional index) — wrap it in a 1-tuple.
        keys = (output_key,)
    else:
        # Deduplicate via set so output_key=["logits", "logits"] cannot trigger double-activation.
        keys = tuple(set(output_key))
    self.output_keys = keys
    self.model = model
    self.temperature = temperature
Expand Down
22 changes: 18 additions & 4 deletions pytorch_toolbelt/utils/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
"auto_file",
"change_extension",
"find_images_in_dir",
"find_images_in_dir_recursive",
"find_in_dir",
"find_in_dir_glob",
"find_in_dir_with_ext",
"find_subdirectories_in_dir",
"has_ext",
"has_image_ext",
Expand Down Expand Up @@ -64,7 +66,15 @@ def find_in_dir_with_ext(dirname: str, extensions: Union[str, List[str]]) -> Lis


def find_images_in_dir(dirname: str) -> List[str]:
    """Return paths of regular files in *dirname* that have an image extension."""
    images = []
    for fname in find_in_dir(dirname):
        if has_image_ext(fname) and os.path.isfile(fname):
            images.append(fname)
    return images


def find_images_in_dir_recursive(dirname: str) -> List[str]:
    """Recursively collect image files under *dirname* (regular files with an image extension)."""
    # "**" with recursive=True walks the whole subtree, including dirname itself.
    pattern = os.path.join(dirname, "**")
    candidates = glob.glob(pattern, recursive=True)
    return [path for path in candidates if has_image_ext(path) and os.path.isfile(path)]


def find_in_dir_glob(dirname: str, recursive=False):
def id_from_fname(fname: str) -> str:
    """Return the base name of *fname* stripped of its directory and (last) extension."""
    stem, _ext = os.path.splitext(os.path.basename(fname))
    return stem


def change_extension(fname: Union[str, Path], new_ext: str) -> Union[str, Path]:
    """
    Replace the extension of a file name with *new_ext*.

    :param fname: Input file name, either a string or a pathlib.Path.
    :param new_ext: New extension, with or without the leading dot (e.g. "jpg" or ".jpg").
    :return: File name of the same type as *fname* with the extension replaced.
    :raises RuntimeError: If *fname* is neither a string nor a Path.
    """
    # Normalize once so both branches accept "jpg" and ".jpg" alike.
    # (Previously only the Path branch added a missing dot, so str inputs
    # produced e.g. "filejpg" when the dot was omitted.)
    if not new_ext.startswith("."):
        new_ext = "." + new_ext
    if isinstance(fname, str):
        return os.path.splitext(fname)[0] + new_ext
    elif isinstance(fname, Path):
        return fname.with_suffix(new_ext)
    else:
        raise RuntimeError(
            f"Received input argument `fname` for unsupported type {type(fname)}. Argument must be string or Path."
        )


def auto_file(filename: str, where: str = ".") -> str:
Expand Down
32 changes: 30 additions & 2 deletions pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"to_tensor",
"transfer_weights",
"move_to_device_non_blocking",
"describe_outputs",
]


Expand Down Expand Up @@ -102,8 +103,11 @@ def count_parameters(
parameters = {"total": total, "trainable": trainable}

for key in keys:
if hasattr(model, key) and model.__getattr__(key) is not None:
parameters[key] = int(sum(p.numel() for p in model.__getattr__(key).parameters()))
try:
if hasattr(model, key) and model.__getattr__(key) is not None:
parameters[key] = int(sum(p.numel() for p in model.__getattr__(key).parameters()))
except AttributeError:
pass

if human_friendly:
for key in parameters.keys():
Expand Down Expand Up @@ -289,3 +293,27 @@ def move_to_device_non_blocking(x: Tensor, device: torch.device) -> Tensor:


resize_as = resize_like


def describe_outputs(outputs: Union[Tensor, Dict[str, Tensor], Iterable[Tensor]]) -> Union[List[Dict], Dict[str, Any]]:
    """
    Describe outputs and return shape, mean & std for each tensor in list or dict (supports nested containers).

    Args:
        outputs: Input (usually model outputs) — a Tensor, a mapping of names to outputs,
            or an iterable of outputs; containers may be arbitrarily nested.

    Returns:
        Same structure, where each tensor is replaced by a dict with keys ``size``, ``mean`` and ``std``.

    Raises:
        NotImplementedError: If an element is neither a tensor, a mapping, nor an iterable.
    """
    if torch.is_tensor(outputs):
        # Leaf case: summarize the tensor itself.
        desc = dict(size=tuple(outputs.size()), mean=outputs.mean().item(), std=outputs.std().item())
    elif isinstance(outputs, collections.abc.Mapping):
        # collections.Mapping / collections.Iterable aliases were removed in Python 3.10;
        # isinstance checks must use collections.abc.
        desc = {key: describe_outputs(value) for key, value in outputs.items()}
    elif isinstance(outputs, collections.abc.Iterable):
        desc = [describe_outputs(output) for output in outputs]
    else:
        raise NotImplementedError(f"Unsupported type of outputs: {type(outputs)}")
    return desc

0 comments on commit 46f189f

Please sign in to comment.