-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #69 from BloodAxe/develop
Release of pytorch-toolbelt 0.5
- Loading branch information
Showing
42 changed files
with
2,129 additions
and
466 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ var/ | |
.idea/ | ||
.pytest_cache/ | ||
/tests/tta_eval.csv | ||
/tests/tmp.onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from __future__ import absolute_import | ||
|
||
__version__ = "0.4.4" | ||
__version__ = "0.5.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
from .classification import * | ||
from .segmentation import * | ||
from .wrappers import * | ||
from .mean_std import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
from typing import Optional, Tuple | ||
|
||
__all__ = ["DatasetMeanStdCalculator"] | ||
|
||
|
||
class DatasetMeanStdCalculator: | ||
__slots__ = ["global_mean", "global_var", "n_items", "num_channels", "global_max", "global_min"] | ||
|
||
""" | ||
Class to calculate running mean and std of the dataset. It helps when whole dataset does not fit entirely in RAM. | ||
""" | ||
|
||
def __init__(self, num_channels: int = 3): | ||
""" | ||
Create a new instance of DatasetMeanStdCalculator | ||
Args: | ||
num_channels: Number of channels in the image. Default value is 3 | ||
""" | ||
super(DatasetMeanStdCalculator, self).__init__() | ||
self.num_channels = num_channels | ||
self.global_mean = None | ||
self.global_var = None | ||
self.global_max = None | ||
self.global_min = None | ||
self.n_items = 0 | ||
self.reset() | ||
|
||
def reset(self): | ||
self.global_mean = np.zeros(self.num_channels, dtype=np.float64) | ||
self.global_var = np.zeros(self.num_channels, dtype=np.float64) | ||
self.global_max = np.ones_like(self.global_mean) * float("-inf") | ||
self.global_min = np.ones_like(self.global_mean) * float("+inf") | ||
self.n_items = 0 | ||
|
||
def accumulate(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> None: | ||
""" | ||
Compute mean and std of a single image and integrates it into global statistics | ||
Args: | ||
image: Input image (Must be HWC, with number of channels C equal to self.num_channels) | ||
mask: Optional mask to include only certain parts of image from statistics computation. | ||
Only non-zero elements will be included, | ||
""" | ||
if len(image.shape) == 2: | ||
image = np.expand_dims(image, axis=-1) | ||
|
||
if self.num_channels != image.shape[2]: | ||
raise RuntimeError(f"Number of channels in image must be {self.num_channels}, got {image.shape[2]}.") | ||
image = image.reshape((-1, self.num_channels)) | ||
|
||
if mask is not None: | ||
mask = mask.reshape((mask.shape[0] * mask.shape[1], 1)) | ||
image = image[mask] | ||
|
||
# In case the whole image is masked out, we exclude it entirely | ||
if len(image) == 0: | ||
return | ||
|
||
mean = np.mean(image, axis=0) | ||
std = np.std(image, axis=0) | ||
|
||
self.global_mean += np.squeeze(mean) | ||
self.global_var += np.squeeze(std) ** 2 | ||
self.global_max = np.maximum(self.global_max, np.max(image, axis=0)) | ||
self.global_min = np.minimum(self.global_min, np.min(image, axis=0)) | ||
self.n_items += 1 | ||
|
||
def compute(self) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Compute dataset-level mean & std | ||
Returns: | ||
Tuple of global [mean, std] per channel | ||
""" | ||
return self.global_mean / self.n_items, np.sqrt(self.global_var / self.n_items) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.