From 8237f2c5d01b5c490915e26427913c317560d035 Mon Sep 17 00:00:00 2001
From: Francisco Rivera Valverde <44504424+franchuterivera@users.noreply.github.com>
Date: Thu, 24 Jun 2021 12:28:15 +0200
Subject: [PATCH] Fix_248 (#263)

---
 autoPyTorch/utils/implementations.py | 12 ++---
 test/test_pipeline/test_losses.py    | 75 +++++++++++++++++++++++++++-
 2 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/autoPyTorch/utils/implementations.py b/autoPyTorch/utils/implementations.py
index 2130cfd6b..a0b020622 100644
--- a/autoPyTorch/utils/implementations.py
+++ b/autoPyTorch/utils/implementations.py
@@ -25,17 +25,17 @@ class LossWeightStrategyWeighted():
     def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
         if isinstance(y, torch.Tensor):
             y = y.detach().cpu().numpy() if y.is_cuda else y.numpy()
-        if isinstance(y[0], str):
-            y = y.astype('float64')
         counts = np.sum(y, axis=0)
         total_weight = y.shape[0]
 
-        if len(y.shape) > 1:
+        if len(y.shape) > 1 and y.shape[1] != 1:
+            # In this case, the second axis represents classes
             weight_per_class = total_weight / y.shape[1]
             weights = (np.ones(y.shape[1]) * weight_per_class) / np.maximum(counts, 1)
         else:
+            # Numpy unique returns the sorted classes. This is desirable as
+            # the weights received by PyTorch must follow the sorted class order
             classes, counts = np.unique(y, axis=0, return_counts=True)
-            classes, counts = classes[::-1], counts[::-1]
             weight_per_class = total_weight / classes.shape[0]
             weights = (np.ones(classes.shape[0]) * weight_per_class) / counts
 
@@ -50,10 +50,8 @@ class LossWeightStrategyWeightedBinary():
     def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
         if isinstance(y, torch.Tensor):
             y = y.detach().cpu().numpy() if y.is_cuda else y.numpy()
-        if isinstance(y[0], str):
-            y = y.astype('float64')
         counts_one = np.sum(y, axis=0)
-        counts_zero = counts_one + (-y.shape[0])
+        counts_zero = y.shape[0] - counts_one
         weights = counts_zero / np.maximum(counts_one, 1)
 
         return np.array(weights)
diff --git a/test/test_pipeline/test_losses.py b/test/test_pipeline/test_losses.py
index 9f23f3e9f..3eeba6a70 100644
--- a/test/test_pipeline/test_losses.py
+++ b/test/test_pipeline/test_losses.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 import pytest
 
 import torch
@@ -5,7 +7,11 @@
 from torch.nn.modules.loss import _Loss as Loss
 
 from autoPyTorch.pipeline.components.training.losses import get_loss, losses
-from autoPyTorch.utils.implementations import get_loss_weight_strategy
+from autoPyTorch.utils.implementations import (
+    LossWeightStrategyWeighted,
+    LossWeightStrategyWeightedBinary,
+    get_loss_weight_strategy,
+)
 
 
 @pytest.mark.parametrize('output_type', ['multiclass',
@@ -66,3 +72,70 @@ def test_loss_dict():
     assert isinstance(loss['module'](), Loss)
     assert 'supported_output_types' in loss.keys()
     assert isinstance(loss['supported_output_types'], list)
+
+
+@pytest.mark.parametrize('target,expected_weights', [
+    (
+        # Expected 4 classes where the first one is the majority class
+        np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
+        # We reduce the contribution of the first class, which has twice as many elements
+        np.array([0.5, 1., 1., 1.]),
+    ),
+    (
+        # Expected 2 classes -- multilabel format
+        np.array([[1, 0], [1, 0], [1, 0], [0, 1]]),
+        # We reduce the contribution of the first class, which has a 3 to 1 ratio
+        np.array([2 / 3, 2]),
+    ),
+    (
+        # Expected 2 classes -- (-1, 1) format
+        np.array([[1], [1], [1], [0]]),
+        # We reduce the contribution of the second class, which has a 3 to 1 ratio
+        np.array([2, 2 / 3]),
+    ),
+    (
+        # Expected 2 classes -- single column
+        # We have to reduce the contribution of the second class, which has a 5 to 1 ratio
+        np.array([1, 1, 1, 1, 1, 0]),
+        # The majority class is down-weighted while the minority class is up-weighted
+        np.array([3, 6 / 10]),
+    ),
+])
+def test_lossweightstrategyweighted(target, expected_weights):
+    weights = LossWeightStrategyWeighted()(target)
+    np.testing.assert_array_equal(weights, expected_weights)
+    assert nn.CrossEntropyLoss(weight=torch.Tensor(weights))(
+        torch.zeros(target.shape[0], len(weights)).float(),
+        torch.from_numpy(target.argmax(1)).long() if len(target.shape) > 1
+        else torch.from_numpy(target).long()
+    ) > 0
+
+
+@pytest.mark.parametrize('target,expected_weights', [
+    (
+        # Expected 2 classes -- multilabel format
+        np.array([[1, 0], [1, 0], [1, 0], [0, 1]]),
+        # We reduce the contribution of the first class, which has a 3 to 1 ratio
+        np.array([1 / 3, 3]),
+    ),
+    (
+        # Expected 2 classes -- (-1, 1) format
+        np.array([[1], [1], [1], [0]]),
+        # We reduce the contribution of the second class, which has a 3 to 1 ratio
+        np.array([1 / 3]),
+    ),
+    (
+        # Expected 2 classes -- single column
+        # We have to reduce the contribution of the second class, which has a 5 to 1 ratio
+        np.array([1, 1, 1, 1, 1, 0]),
+        # pos_weight down-weights the positive class, which outnumbers the negative class 5 to 1
+        np.array([0.2]),
+    ),
+])
+def test_lossweightstrategyweightedbinary(target, expected_weights):
+    weights = LossWeightStrategyWeightedBinary()(target)
+    np.testing.assert_array_equal(weights, expected_weights)
+    assert nn.BCEWithLogitsLoss(pos_weight=torch.Tensor(weights))(
+        torch.from_numpy(target).float(),
+        torch.from_numpy(target).float(),
+    ) > 0
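
Not part of the patch: a minimal standalone sketch of the weighting arithmetic the new tests pin down, using the single-column target np.array([1, 1, 1, 1, 1, 0]) from the parametrized cases. It only mirrors the fixed logic rather than calling the autoPyTorch classes, and the variable names are illustrative.

import numpy as np

# Single-column binary target, as in the new test cases.
y = np.array([1, 1, 1, 1, 1, 0])

# LossWeightStrategyWeighted branch for 1-D targets: np.unique returns the
# classes in sorted order, matching the class indexing that
# CrossEntropyLoss expects for its `weight` argument.
classes, counts = np.unique(y, axis=0, return_counts=True)  # classes=[0, 1], counts=[1, 5]
weight_per_class = y.shape[0] / classes.shape[0]            # 6 / 2 = 3.0
weights = weight_per_class / counts                         # [3.0, 0.6], the expected [3, 6 / 10]

# LossWeightStrategyWeightedBinary: the fix computes counts_zero directly
# as "samples minus positives", giving the BCEWithLogitsLoss pos_weight.
counts_one = np.sum(y, axis=0)                              # 5
counts_zero = y.shape[0] - counts_one                       # 1
pos_weight = counts_zero / np.maximum(counts_one, 1)        # 0.2, the expected [0.2]

print(weights, pos_weight)

Both printed values match the expectations asserted in test_lossweightstrategyweighted and test_lossweightstrategyweightedbinary above.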