Skip to content

Commit

Permalink
Merge pull request #63 from BloodAxe/develop
Browse files Browse the repository at this point in the history
Pytorch Toolbelt 0.4.4
  • Loading branch information
BloodAxe authored Aug 12, 2021
2 parents f3acfca + 6438a50 commit 4a24e63
Show file tree
Hide file tree
Showing 71 changed files with 1,538 additions and 1,445 deletions.
35 changes: 35 additions & 0 deletions .github/ISSUE_TEMPLATE/bug-report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
---
name: "Bug Report"
about: Submit a bug report to help us improve pytorch-toolbelt.

---

## 🐛 Bug

<!-- A clear and concise description of what the bug is. -->

## To Reproduce

Steps to reproduce the behavior:

1.
1.
1.

<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->

## Expected behavior

<!-- A clear and concise description of what you expected to happen. -->

## Environment

- Pytorch-toolbelt version (e.g., 0.4.4):
- Pytorch version (e.g., 1.8.1):
- Python version (e.g., 3.7):
- OS (e.g., Linux):
- Any other relevant information:

## Additional context

<!-- Add any other context about the problem here. -->
64 changes: 64 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: CI
on: [push, pull_request]

jobs:
test_and_lint:
name: Test and lint
runs-on: ${{ matrix.operating-system }}
strategy:
matrix:
operating-system: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
pytorch-toolbelt-version: [tests]
fail-fast: false
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install PyTorch on Linux and Windows
if: >
matrix.operating-system == 'ubuntu-latest' ||
matrix.operating-system == 'windows-latest'
run: >
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu
-f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch on MacOS
if: matrix.operating-system == 'macos-latest'
run: pip install torch==1.8.1 torchvision==0.9.1
- name: Install dependencies
run: pip install .[${{ matrix.pytorch-toolbelt-version }}]
- name: Install linters
run: pip install flake8==3.8.4 flake8-docstrings==1.5.0
- name: Run PyTest
run: pytest
- name: Run Flake8
run: flake8
# - name: Install MyPy
# run: pip install mypy==0.800
# - name: Run mypy
# run: mypy .

check_code_formatting:
name: Check code formatting with Black
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install Black
run: pip install black==20.8b1
- name: Run Black
run: black --config=black.toml --check .
26 changes: 26 additions & 0 deletions .github/workflows/upload_to_pypi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Upload release to PyPI

on:
release:
types: [published]

jobs:
upload:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
27 changes: 0 additions & 27 deletions .travis.yml

This file was deleted.

4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Pytorch-toolbelt

[![Build Status](https://travis-ci.org/BloodAxe/pytorch-toolbelt.svg?branch=develop)](https://travis-ci.org/BloodAxe/pytorch-toolbelt)
[![Documentation Status](https://readthedocs.org/projects/pytorch-toolbelt/badge/?version=latest)](https://pytorch-toolbelt.readthedocs.io/en/latest/?badge=latest)
[![DeepSource](https://static.deepsource.io/deepsource-badge-light-mini.svg)](https://deepsource.io/gh/BloodAxe/pytorch-toolbelt/?ref=repository-badge)

A `pytorch-toolbelt` is a Python library with a set of bells and whistles for PyTorch for fast R&D prototyping and Kaggle farming:

Expand Down Expand Up @@ -223,4 +221,4 @@ merged_mask = tiler.crop_to_orignal_size(merged_mask)
howpublished = {\url{https://github.com/BloodAxe/pytorch-toolbelt}},
commit = {cc5e9973cdb0dcbf1c6b6e1401bf44b9c69e13f3}
}
```
```
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = "0.4.3"
__version__ = "0.4.4"
9 changes: 5 additions & 4 deletions pytorch_toolbelt/datasets/classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, List
from typing import Optional, List, Union, Callable

import albumentations as A
import numpy as np
import torch
from torch.utils.data import Dataset

Expand All @@ -22,10 +23,10 @@ class ClassificationDataset(Dataset):
def __init__(
self,
image_filenames: List[str],
labels: Optional[List[str]],
labels: Optional[Union[List[int], np.ndarray]],
transform: A.Compose,
read_image_fn=read_image_rgb,
make_target_fn=label_to_tensor,
read_image_fn: Callable = read_image_rgb,
make_target_fn: Callable = label_to_tensor,
):
if labels is not None and len(image_filenames) != len(labels):
raise ValueError("Number of images does not corresponds to number of targets")
Expand Down
20 changes: 10 additions & 10 deletions pytorch_toolbelt/datasets/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,32 @@ def name_for_stride(name, stride: int):
return f"{name}_{stride}"


INPUT_INDEX_KEY = "index"
INPUT_IMAGE_KEY = "image"
INPUT_IMAGE_ID_KEY = "image_id"
INPUT_INDEX_KEY = "INPUT_INDEX_KEY"
INPUT_IMAGE_KEY = "INPUT_IMAGE_KEY"
INPUT_IMAGE_ID_KEY = "INPUT_IMAGE_ID_KEY"

TARGET_MASK_WEIGHT_KEY = "true_weights"
TARGET_CLASS_KEY = "true_class"
TARGET_LABELS_KEY = "true_labels"
TARGET_MASK_WEIGHT_KEY = "TARGET_MASK_WEIGHT_KEY"
TARGET_CLASS_KEY = "TARGET_CLASS_KEY"
TARGET_LABELS_KEY = "TARGET_LABELS_KEY"

TARGET_MASK_KEY = "true_mask"
TARGET_MASK_KEY = "TARGET_MASK_KEY"
TARGET_MASK_2_KEY = name_for_stride(TARGET_MASK_KEY, 2)
TARGET_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
TARGET_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
TARGET_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
TARGET_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
TARGET_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)

OUTPUT_MASK_KEY = "pred_mask"
OUTPUT_MASK_KEY = "OUTPUT_MASK_KEY"
OUTPUT_MASK_2_KEY = name_for_stride(OUTPUT_MASK_KEY, 2)
OUTPUT_MASK_4_KEY = name_for_stride(OUTPUT_MASK_KEY, 4)
OUTPUT_MASK_8_KEY = name_for_stride(OUTPUT_MASK_KEY, 8)
OUTPUT_MASK_16_KEY = name_for_stride(OUTPUT_MASK_KEY, 16)
OUTPUT_MASK_32_KEY = name_for_stride(OUTPUT_MASK_KEY, 32)
OUTPUT_MASK_64_KEY = name_for_stride(OUTPUT_MASK_KEY, 64)

OUTPUT_LOGITS_KEY = "pred_logits"
OUTPUT_EMBEDDINGS_KEY = "pred_embeddings"
OUTPUT_LOGITS_KEY = "OUTPUT_LOGITS_KEY"
OUTPUT_EMBEDDINGS_KEY = "OUTPUT_EMBEDDINGS_KEY"


def read_image_rgb(fname: str):
Expand Down
1 change: 0 additions & 1 deletion pytorch_toolbelt/datasets/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def read_binary_mask(mask_fname: str) -> np.ndarray:
Returns:
Numpy array with {0,1} values
"""

mask = cv2.imread(mask_fname, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise FileNotFoundError(f"Cannot find {mask_fname}")
Expand Down
14 changes: 14 additions & 0 deletions pytorch_toolbelt/datasets/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

__all__ = ["RandomSubsetDataset", "RandomSubsetWithMaskDataset"]

from torch.utils.data.dataloader import default_collate


class RandomSubsetDataset(Dataset):
"""
Expand All @@ -23,6 +25,12 @@ def __getitem__(self, _) -> Any:
index = random.randrange(len(self.dataset))
return self.dataset[index]

def get_collate_fn(self):
get_collate_fn = getattr(self.dataset, "get_collate_fn", None)
if callable(get_collate_fn):
return get_collate_fn()
return default_collate()


class RandomSubsetWithMaskDataset(Dataset):
"""
Expand Down Expand Up @@ -53,3 +61,9 @@ def __len__(self) -> int:
def __getitem__(self, _) -> Any:
index = random.choice(self.indexes)
return self.dataset[index]

def get_collate_fn(self):
get_collate_fn = getattr(self.dataset, "get_collate_fn", None)
if callable(get_collate_fn):
return get_collate_fn()
return default_collate()
5 changes: 5 additions & 0 deletions pytorch_toolbelt/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .ensembling import *
from .functional import *
from .tiles import *
from .tiles_3d import *
from .tta import *
61 changes: 49 additions & 12 deletions pytorch_toolbelt/inference/ensembling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from torch import nn, Tensor
from typing import List, Union, Iterable, Optional
from typing import List, Union, Iterable, Optional, Dict

__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput"]
__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex"]

from pytorch_toolbelt.inference.tta import _deaugment_averaging

Expand Down Expand Up @@ -54,6 +54,8 @@ def forward(self, *input, **kwargs): # skipcq: PYL-W0221


class Ensembler(nn.Module):
__slots__ = ["outputs", "reduction", "return_some_outputs"]

"""
Compute sum (or average) of outputs of several models.
"""
Expand All @@ -67,38 +69,73 @@ def __init__(self, models: List[nn.Module], reduction: str = "mean", outputs: Op
If None, all outputs from the first model will be used.
"""
super().__init__()
self.outputs = outputs
self.return_some_outputs = outputs is not None
self.outputs = tuple(outputs) if outputs else tuple()
self.models = nn.ModuleList(models)
self.reduction = reduction

def forward(self, *input, **kwargs): # skipcq: PYL-W0221
outputs = [model(*input, **kwargs) for model in self.models]

if self.outputs:
if self.return_some_outputs:
keys = self.outputs
else:
elif isinstance(outputs[0], dict):
keys = outputs[0].keys()
elif torch.is_tensor(outputs[0]):
keys = None
else:
raise RuntimeError()

averaged_output = {}
for key in keys:
predictions = [output[key] for output in outputs]
predictions = torch.stack(predictions)
if keys is None:
predictions = torch.stack(outputs)
predictions = _deaugment_averaging(predictions, self.reduction)
averaged_output[key] = predictions
averaged_output = predictions
else:
averaged_output = {}
for key in keys:
predictions = [output[key] for output in outputs]
predictions = torch.stack(predictions)
predictions = _deaugment_averaging(predictions, self.reduction)
averaged_output[key] = predictions

return averaged_output


class PickModelOutput(nn.Module):
"""
Assuming you have a model that outputs a dictionary, this module returns only a given element by it's key
Wraps a model that returns dict or list and returns only a specific element.
Usage example:
>>> model = MyAwesomeSegmentationModel() # Returns dict {"OUTPUT_MASK": Tensor, ...}
>>> net = nn.Sequential(PickModelOutput(model, "OUTPUT_MASK")), nn.Sigmoid())
"""

def __init__(self, model: nn.Module, key: str):
__slots__ = ["target_key"]

def __init__(self, model: nn.Module, key: Union[str, int]):
super().__init__()
self.model = model
self.target_key = key

def forward(self, *input, **kwargs) -> Tensor:
output = self.model(*input, **kwargs)
return output[self.target_key]


class SelectByIndex(nn.Module):
"""
Select a single Tensor from the dict or list of output tensors.
Usage example:
>>> model = MyAwesomeSegmentationModel() # Returns dict {"OUTPUT_MASK": Tensor, ...}
>>> net = nn.Sequential(model, SelectByIndex("OUTPUT_MASK"), nn.Sigmoid())
"""

__slots__ = ["target_key"]

def __init__(self, key: Union[str, int]):
super().__init__()
self.target_key = key

def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
return outputs[self.target_key]
Loading

0 comments on commit 4a24e63

Please sign in to comment.