Skip to content

Commit

Permalink
Fix_248 (automl#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
franchuterivera authored Jun 24, 2021
1 parent 2a9f04e commit 8237f2c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 8 deletions.
12 changes: 5 additions & 7 deletions autoPyTorch/utils/implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ class LossWeightStrategyWeighted():
def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
if isinstance(y, torch.Tensor):
y = y.detach().cpu().numpy() if y.is_cuda else y.numpy()
if isinstance(y[0], str):
y = y.astype('float64')
counts = np.sum(y, axis=0)
total_weight = y.shape[0]

if len(y.shape) > 1:
if len(y.shape) > 1 and y.shape[1] != 1:
# In this case, the second axis represents classes
weight_per_class = total_weight / y.shape[1]
weights = (np.ones(y.shape[1]) * weight_per_class) / np.maximum(counts, 1)
else:
# NumPy's unique returns the sorted classes. This is desirable as
# the weights received by PyTorch are a sorted list of classes
classes, counts = np.unique(y, axis=0, return_counts=True)
classes, counts = classes[::-1], counts[::-1]
weight_per_class = total_weight / classes.shape[0]
weights = (np.ones(classes.shape[0]) * weight_per_class) / counts

Expand All @@ -50,10 +50,8 @@ class LossWeightStrategyWeightedBinary():
def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Compute positive-class weights as the ratio of negatives to positives.

    Args:
        y: Targets encoded as 0/1, either as a NumPy array or a torch
            tensor. The first axis is reduced, so a 2-D input yields one
            weight per column and a 1-D input yields a single weight.

    Returns:
        ``(number of zeros) / max(number of ones, 1)`` computed along the
        first axis, wrapped in a NumPy array.
    """
    if isinstance(y, torch.Tensor):
        # detach() + cpu() is valid for every tensor: CPU or CUDA, with or
        # without gradient tracking (a bare .numpy() raises on the latter).
        y = y.detach().cpu().numpy()
    if isinstance(y[0], str):
        # String-encoded labels must be numeric before they can be summed.
        y = y.astype('float64')

    counts_one = np.sum(y, axis=0)
    # Every entry that is not a one is a zero, so this is never negative
    # for 0/1 targets.
    counts_zero = y.shape[0] - counts_one
    # Clamp the denominator so a column without positives does not divide
    # by zero.
    weights = counts_zero / np.maximum(counts_one, 1)

    return np.array(weights)
Expand Down
75 changes: 74 additions & 1 deletion test/test_pipeline/test_losses.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import numpy as np

import pytest

import torch
from torch import nn
from torch.nn.modules.loss import _Loss as Loss

from autoPyTorch.pipeline.components.training.losses import get_loss, losses
from autoPyTorch.utils.implementations import get_loss_weight_strategy
from autoPyTorch.utils.implementations import (
LossWeightStrategyWeighted,
LossWeightStrategyWeightedBinary,
get_loss_weight_strategy,
)


@pytest.mark.parametrize('output_type', ['multiclass',
Expand Down Expand Up @@ -66,3 +72,70 @@ def test_loss_dict():
assert isinstance(loss['module'](), Loss)
assert 'supported_output_types' in loss.keys()
assert isinstance(loss['supported_output_types'], list)


@pytest.mark.parametrize('target,expected_weights', [
    (
        # One-hot targets with 4 classes: class 0 holds 2 of the 4 samples
        # and class 1 never occurs.
        np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
        # The majority class is down-weighted to one half.
        np.array([0.5, 1., 1., 1.]),
    ),
    (
        # Two classes in one-hot / multilabel format, 3 to 1 in favour of
        # class 0.
        np.array([[1, 0], [1, 0], [1, 0], [0, 1]]),
        # Class 0 (majority) is down-weighted, class 1 is up-weighted.
        np.array([2 / 3, 2]),
    ),
    (
        # Two classes given as a single column of shape (n_samples, 1),
        # 3 to 1 in favour of class 1.
        np.array([[1], [1], [1], [0]]),
        # Weights are ordered by sorted class (0, 1), so the down-weighted
        # majority class 1 comes second.
        np.array([2, 2 / 3]),
    ),
    (
        # Two classes given as a flat 1-D label vector, 5 to 1 in favour
        # of class 1.
        np.array([1, 1, 1, 1, 1, 0]),
        # Class 0 (minority) is up-weighted, class 1 is down-weighted.
        np.array([3, 6 / 10]),
    ),
])
def test_lossweightstrategyweighted(target, expected_weights):
    """Check the class weights and that they are accepted by CrossEntropyLoss."""
    weights = LossWeightStrategyWeighted()(target)
    # The weights are ratios, so compare with a floating point tolerance.
    np.testing.assert_allclose(weights, expected_weights)

    if target.ndim > 1 and target.shape[1] > 1:
        # One-hot rows: the label is the index of the hot column.
        labels = target.argmax(axis=1)
    else:
        # 1-D vectors and single-column targets already contain the labels.
        labels = target.reshape(-1)

    criterion = nn.CrossEntropyLoss(weight=torch.Tensor(weights))
    logits = torch.zeros(target.shape[0], len(weights)).float()
    assert criterion(logits, torch.from_numpy(labels).long()) > 0


@pytest.mark.parametrize('target,expected_weights', [
    (
        # Two columns in multilabel format: column 0 has 3 ones and 1 zero,
        # column 1 has 1 one and 3 zeros.
        np.array([[1, 0], [1, 0], [1, 0], [0, 1]]),
        # Each weight is the zeros-to-ones ratio of its column.
        np.array([1 / 3, 3]),
    ),
    (
        # Single column of shape (n_samples, 1) with 3 ones and 1 zero.
        np.array([[1], [1], [1], [0]]),
        # One zero for every three ones.
        np.array([1 / 3]),
    ),
    (
        # Flat 1-D label vector with 5 ones and 1 zero.
        np.array([1, 1, 1, 1, 1, 0]),
        # One zero for every five ones.
        np.array([0.2]),
    ),
])
def test_lossweightstrategyweightedbinary(target, expected_weights):
    """Check the positive-class weights and that BCEWithLogitsLoss accepts them."""
    weights = LossWeightStrategyWeightedBinary()(target)
    # The weights are ratios, so compare with a floating point tolerance.
    np.testing.assert_allclose(weights, expected_weights)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor(weights))
    # The targets double as logits here; only a usable, positive loss matters.
    logits = torch.from_numpy(target).float()
    labels = torch.from_numpy(target).float()
    assert criterion(logits, labels) > 0

0 comments on commit 8237f2c

Please sign in to comment.