Commit

Add deterministic version of the FBeta class (#167)
freud14 authored May 21, 2023
1 parent b368220 commit 003b37d
Showing 5 changed files with 228 additions and 109 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -1,6 +1,6 @@
# v1.x.x
# v1.17

-
- [`FBeta`](https://poutyne.org/metrics.html#poutyne.FBeta) uses the non-deterministic PyTorch function [`bincount`](https://pytorch.org/docs/stable/generated/torch.bincount.html). You can now make it deterministic, either by passing the `make_deterministic` argument to the [`FBeta`](https://poutyne.org/metrics.html#poutyne.FBeta) class or by using one of the PyTorch functions `torch.set_deterministic_debug_mode` or `torch.use_deterministic_algorithms`. Note that this might make your code slower.

# v1.16

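A minimal usage sketch of the two opt-in paths mentioned in the changelog entry above. The `metric` and `average` keyword arguments are the usual FBeta options and are shown only for context; only `make_deterministic` is new in this commit:

```python
import torch
from poutyne import FBeta

# Per-metric opt-in: this instance routes bincount through the deterministic fallback.
fscore = FBeta(metric='fscore', average='macro', make_deterministic=True)

# Global opt-in: either of these torch switches has the same effect on FBeta's update step.
torch.use_deterministic_algorithms(True)
# or: torch.set_deterministic_debug_mode("error")
```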
67 changes: 67 additions & 0 deletions poutyne/framework/metrics/predefined/bincount.py
@@ -0,0 +1,67 @@
"""
The source code of this file was copied from the torchmetrics project, and has been modified. All modifications
made from the original source code are under the LGPLv3 license.
Copyright (c) 2022 Poutyne and all respective contributors.
Each contributor holds copyright over their respective contributions. The project versioning (Git)
records all such contribution source information on the Poutyne and AllenNLP repository.
This file is part of Poutyne.
Poutyne is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
version.
Poutyne is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with Poutyne. If not, see
<https://www.gnu.org/licenses/>.
Copyright The PyTorch Lightning team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional

import torch
from torch import Tensor
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE


def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
    """PyTorch currently does not support ``torch.bincount`` for:
        - deterministic mode on GPU.
        - MPS devices.
    This implementation falls back to a for-loop that counts occurrences in that case.
    Args:
        x: tensor to count
        minlength: minimum length to count
    Returns:
        Number of occurrences for each unique element in x
    """
    if minlength is None:
        minlength = len(torch.unique(x))
    if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
        output = torch.zeros(minlength, device=x.device, dtype=torch.long)
        for i in range(minlength):
            output[i] = (x == i).sum()
        return output
    return torch.bincount(x, minlength=minlength)
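A small sanity check of the fallback, assuming the module path introduced in this commit. On CPU, `torch.bincount` remains deterministic, so both calls should agree; the point of the fallback is GPU and MPS tensors:

```python
import torch

from poutyne.framework.metrics.predefined.bincount import _bincount

x = torch.tensor([0, 1, 1, 3, 2, 1])
torch.use_deterministic_algorithms(True)  # forces the for-loop branch above
print(_bincount(x, minlength=4))          # tensor([1, 3, 1, 1])
print(torch.bincount(x, minlength=4))     # same counts on CPU
torch.use_deterministic_algorithms(False)
```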
157 changes: 83 additions & 74 deletions poutyne/framework/metrics/predefined/fscores.py
@@ -43,6 +43,8 @@

from poutyne.framework.metrics.base import Metric
from poutyne.framework.metrics.metrics_registering import register_metric_class
from poutyne.framework.metrics.predefined.bincount import _bincount
from poutyne.utils import set_deterministic_debug_mode


class FBeta(Metric):
@@ -115,6 +117,8 @@ class FBeta(Metric):
        names (Optional[Union[str, List[str]]]): The names associated to the metrics. It is a string when
            a single metric is requested. It is a list of 3 strings if all metrics are requested.
            (Default value = None)
        make_deterministic (Optional[bool]): Avoid non-deterministic operations in computations. This might make the
            code slower.
    """

    def __init__(
Expand All @@ -127,6 +131,7 @@ def __init__(
        ignore_index: int = -100,
        threshold: float = 0.0,
        names: Optional[Union[str, List[str]]] = None,
        make_deterministic: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self.metric_options = ('fscore', 'precision', 'recall')
@@ -154,6 +159,9 @@ def __init__(
        self.ignore_index = ignore_index
        self.threshold = threshold
        self.__name__ = self._get_names(names)
        self.deterministic_debug_mode = (
            "error" if make_deterministic is True else "default" if make_deterministic is False else None
        )

        # statistics
        # the total number of true positive instances under each class
@@ -235,80 +243,81 @@ def update(self, y_pred: torch.Tensor, y_true: Union[torch.Tensor, Tuple[torch.T

    def _update(self, y_pred: torch.Tensor, y_true: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> None:
        # pylint: disable=too-many-branches
        if isinstance(y_true, tuple):
            y_true, mask = y_true
            mask = mask.bool()
        else:
            mask = torch.ones_like(y_true).bool()

        if self.ignore_index is not None:
            mask *= y_true != self.ignore_index

        if y_pred.shape[0] == 1:
            y_pred, y_true, mask = (
                y_pred.squeeze().unsqueeze(0),
                y_true.squeeze().unsqueeze(0),
                mask.squeeze().unsqueeze(0),
            )
        else:
            y_pred, y_true, mask = y_pred.squeeze(), y_true.squeeze(), mask.squeeze()

        num_classes = 2
        if y_pred.shape != y_true.shape:
            num_classes = y_pred.size(1)

        if (y_true >= num_classes).any():
            raise ValueError(
                f"A gold label passed to FBetaMeasure contains an id >= {num_classes}, the number of classes."
            )

        if self._average == 'binary' and num_classes > 2:
            raise ValueError("When `average` is binary, the number of prediction scores must be 2.")

        # It means we call this metric at the first time
        # when `self._true_positive_sum` is None.
        if self._true_positive_sum is None:
            self._true_positive_sum = torch.zeros(num_classes, device=y_pred.device)
            self._true_sum = torch.zeros(num_classes, device=y_pred.device)
            self._pred_sum = torch.zeros(num_classes, device=y_pred.device)
            self._total_sum = torch.zeros(num_classes, device=y_pred.device)

        y_true = y_true.float()

        if y_pred.shape != y_true.shape:
            argmax_y_pred = y_pred.argmax(1).float()
        else:
            argmax_y_pred = (y_pred > self.threshold).float()
        true_positives = (y_true == argmax_y_pred) * mask
        true_positives_bins = y_true[true_positives]

        # Watch it:
        # The total numbers of true positives under all _predicted_ classes are zeros.
        if true_positives_bins.shape[0] == 0:
            true_positive_sum = torch.zeros(num_classes, device=y_pred.device)
        else:
            true_positive_sum = torch.bincount(true_positives_bins.long(), minlength=num_classes).float()

        pred_bins = argmax_y_pred[mask].long()
        # Watch it:
        # When the `mask` is all 0, we will get an _empty_ tensor.
        if pred_bins.shape[0] != 0:
            pred_sum = torch.bincount(pred_bins, minlength=num_classes).float()
        else:
            pred_sum = torch.zeros(num_classes, device=y_pred.device)

        y_true_bins = y_true[mask].long()
        if y_true.shape[0] != 0:
            true_sum = torch.bincount(y_true_bins, minlength=num_classes).float()
        else:
            true_sum = torch.zeros(num_classes, device=y_pred.device)

        self._true_positive_sum += true_positive_sum
        self._pred_sum += pred_sum
        self._true_sum += true_sum
        self._total_sum += mask.sum().to(torch.float)

        return true_positive_sum, pred_sum, true_sum
        with set_deterministic_debug_mode(self.deterministic_debug_mode):
            if isinstance(y_true, tuple):
                y_true, mask = y_true
                mask = mask.bool()
            else:
                mask = torch.ones_like(y_true).bool()

            if self.ignore_index is not None:
                mask *= y_true != self.ignore_index

            if y_pred.shape[0] == 1:
                y_pred, y_true, mask = (
                    y_pred.squeeze().unsqueeze(0),
                    y_true.squeeze().unsqueeze(0),
                    mask.squeeze().unsqueeze(0),
                )
            else:
                y_pred, y_true, mask = y_pred.squeeze(), y_true.squeeze(), mask.squeeze()

            num_classes = 2
            if y_pred.shape != y_true.shape:
                num_classes = y_pred.size(1)

            if (y_true >= num_classes).any():
                raise ValueError(
                    f"A gold label passed to FBetaMeasure contains an id >= {num_classes}, the number of classes."
                )

            if self._average == 'binary' and num_classes > 2:
                raise ValueError("When `average` is binary, the number of prediction scores must be 2.")

            # It means we call this metric at the first time
            # when `self._true_positive_sum` is None.
            if self._true_positive_sum is None:
                self._true_positive_sum = torch.zeros(num_classes, device=y_pred.device)
                self._true_sum = torch.zeros(num_classes, device=y_pred.device)
                self._pred_sum = torch.zeros(num_classes, device=y_pred.device)
                self._total_sum = torch.zeros(num_classes, device=y_pred.device)

            y_true = y_true.float()

            if y_pred.shape != y_true.shape:
                argmax_y_pred = y_pred.argmax(1).float()
            else:
                argmax_y_pred = (y_pred > self.threshold).float()
            true_positives = (y_true == argmax_y_pred) * mask
            true_positives_bins = y_true[true_positives]

            # Watch it:
            # The total numbers of true positives under all _predicted_ classes are zeros.
            if true_positives_bins.shape[0] == 0:
                true_positive_sum = torch.zeros(num_classes, device=y_pred.device)
            else:
                true_positive_sum = _bincount(true_positives_bins.long(), minlength=num_classes).float()

            pred_bins = argmax_y_pred[mask].long()
            # Watch it:
            # When the `mask` is all 0, we will get an _empty_ tensor.
            if pred_bins.shape[0] != 0:
                pred_sum = _bincount(pred_bins, minlength=num_classes).float()
            else:
                pred_sum = torch.zeros(num_classes, device=y_pred.device)

            y_true_bins = y_true[mask].long()
            if y_true.shape[0] != 0:
                true_sum = _bincount(y_true_bins, minlength=num_classes).float()
            else:
                true_sum = torch.zeros(num_classes, device=y_pred.device)

            self._true_positive_sum += true_positive_sum
            self._pred_sum += pred_sum
            self._true_sum += true_sum
            self._total_sum += mask.sum().to(torch.float)

            return true_positive_sum, pred_sum, true_sum

    def compute(self) -> Union[float, Tuple[float]]:
        """
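For reference, a short sketch of the torch debug-mode values that `make_deterministic` is mapped to in `__init__` above: "error" when True, "default" when False, and no change at all when None. These are plain PyTorch calls, not Poutyne API:

```python
import torch

torch.set_deterministic_debug_mode("error")    # what make_deterministic=True selects
print(torch.get_deterministic_debug_mode())    # 2, i.e. non-deterministic ops raise
torch.set_deterministic_debug_mode("default")  # what make_deterministic=False selects
print(torch.get_deterministic_debug_mode())    # 0, i.e. non-deterministic ops allowed
```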
15 changes: 14 additions & 1 deletion poutyne/utils.py
@@ -17,13 +17,14 @@
<https://www.gnu.org/licenses/>.
"""

import contextlib
import numbers

# -*- coding: utf-8 -*-
import os
import random
import warnings
from typing import IO, Any, BinaryIO, Union
from typing import IO, Any, BinaryIO, Optional, Union

import numpy as np
import torch
@@ -332,3 +333,15 @@ def is_torch_or_numpy(v):
"tensor or a Numpy array.\n"
)
return 1


@contextlib.contextmanager
def set_deterministic_debug_mode(mode: Optional[Union[str, int]]):
    if mode is None:
        yield
        return

    old_mode = torch.get_deterministic_debug_mode()
    torch.set_deterministic_debug_mode(mode)
    yield
    torch.set_deterministic_debug_mode(old_mode)
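A usage sketch of the context manager added above, assuming the `poutyne.utils` module path from this diff. Passing `None` is a no-op; any other mode is applied inside the block and the previous mode is restored on exit:

```python
import torch

from poutyne.utils import set_deterministic_debug_mode

with set_deterministic_debug_mode("error"):
    x = torch.tensor([0, 2, 2, 1])
    counts = torch.bincount(x, minlength=3)  # fine on CPU; the same call on CUDA may raise here
print(torch.get_deterministic_debug_mode())  # back to the previous mode
```

A more defensive variant would wrap the `yield` in try/finally so the previous mode is restored even when the body of the `with` block raises.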
