Skip to content

Commit

Permalink
Fix error raised by latest Torchmetrics (0.11.0) (#576)
Browse files Browse the repository at this point in the history
* update default torchmetrics

* unpin torchmetrics version

* update torchmetrics requirement

* update torchmetrics version from Oliver suggestion
  • Loading branch information
sararb authored Dec 13, 2022
1 parent 215d570 commit 650f932
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion requirements/pytorch.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch>=1.0
torchmetrics==0.3.2
torchmetrics>=0.10.0
12 changes: 6 additions & 6 deletions tests/torch/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ def test_evaluate_results(torch_yoochoose_next_item_prediction_model):
(
tr.BinaryClassificationTask("click", summary_type="mean"),
[
"eval_/click/binary_classification_task/accuracy",
"eval_/click/binary_classification_task/precision",
"eval_/click/binary_classification_task/recall",
"eval_/click/binary_classification_task/binary_accuracy",
"eval_/click/binary_classification_task/binary_precision",
"eval_/click/binary_classification_task/binary_recall",
],
),
(
Expand Down Expand Up @@ -446,9 +446,9 @@ def test_trainer_with_multiple_tasks():
"eval_/next-item/avg_precision_at_20",
"eval_/next-item/recall_at_10",
"eval_/next-item/recall_at_20",
"eval_/click/binary_classification_task/accuracy",
"eval_/click/binary_classification_task/precision",
"eval_/click/binary_classification_task/recall",
"eval_/click/binary_classification_task/binary_accuracy",
"eval_/click/binary_classification_task/binary_precision",
"eval_/click/binary_classification_task/binary_recall",
"eval_/play_percentage/regression_task/mean_squared_error",
]

Expand Down
6 changes: 3 additions & 3 deletions transformers4rec/torch/model/prediction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def build(self, input_size) -> SequentialBlock:
class BinaryClassificationTask(PredictionTask):
DEFAULT_LOSS = torch.nn.BCELoss()
DEFAULT_METRICS = (
tm.Precision(num_classes=2),
tm.Recall(num_classes=2),
tm.Accuracy(),
tm.Precision(num_classes=2, task="binary"),
tm.Recall(num_classes=2, task="binary"),
tm.Accuracy(task="binary"),
# TODO: Fix this: tm.AUC()
)

Expand Down

0 comments on commit 650f932

Please sign in to comment.