Skip to content

Commit 503c8f6

Browse files
author
pytorchbot
committed
2025-01-11 nightly release (a0146b5)
1 parent 04d1159 commit 503c8f6

7 files changed

+348
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from typing import Any, Dict, Optional, Type
11+
12+
import torch
13+
from torchrec.metrics.calibration import (
14+
CalibrationMetricComputation,
15+
get_calibration_states,
16+
)
17+
from torchrec.metrics.metrics_namespace import MetricNamespace
18+
from torchrec.metrics.rec_metric import (
19+
RecMetric,
20+
RecMetricComputation,
21+
RecMetricException,
22+
)
23+
24+
25+
CALIBRATION_NUM = "calibration_num"
26+
CALIBRATION_DENOM = "calibration_denom"
27+
28+
29+
class RecalibratedCalibrationMetricComputation(CalibrationMetricComputation):
30+
r"""
31+
This class implements the RecMetricComputation for Calibration that is required to correctly estimate eval NE if negative downsampling was used during training.
32+
33+
The constructor arguments are defined in RecMetricComputation.
34+
See the docstring of RecMetricComputation for more detail.
35+
"""
36+
37+
def __init__(
38+
self, *args: Any, recalibration_coefficient: float = 1.0, **kwargs: Any
39+
) -> None:
40+
self._recalibration_coefficient: float = recalibration_coefficient
41+
super().__init__(*args, **kwargs)
42+
self._add_state(
43+
CALIBRATION_NUM,
44+
torch.zeros(self._n_tasks, dtype=torch.double),
45+
add_window_state=True,
46+
dist_reduce_fx="sum",
47+
persistent=True,
48+
)
49+
self._add_state(
50+
CALIBRATION_DENOM,
51+
torch.zeros(self._n_tasks, dtype=torch.double),
52+
add_window_state=True,
53+
dist_reduce_fx="sum",
54+
persistent=True,
55+
)
56+
57+
def _recalibrate(
58+
self,
59+
predictions: torch.Tensor,
60+
calibration_coef: Optional[torch.Tensor],
61+
) -> torch.Tensor:
62+
if calibration_coef is not None:
63+
predictions = predictions / (
64+
predictions + (1.0 - predictions) / calibration_coef
65+
)
66+
return predictions
67+
68+
def update(
69+
self,
70+
*,
71+
predictions: Optional[torch.Tensor],
72+
labels: torch.Tensor,
73+
weights: Optional[torch.Tensor],
74+
**kwargs: Dict[str, Any],
75+
) -> None:
76+
if predictions is None or weights is None:
77+
raise RecMetricException(
78+
"Inputs 'predictions' and 'weights' should not be None for CalibrationMetricComputation update"
79+
)
80+
predictions = self._recalibrate(
81+
predictions, self._recalibration_coefficient * torch.ones_like(predictions)
82+
)
83+
num_samples = predictions.shape[-1]
84+
for state_name, state_value in get_calibration_states(
85+
labels, predictions, weights
86+
).items():
87+
state = getattr(self, state_name)
88+
state += state_value
89+
self._aggregate_window_state(state_name, state_value, num_samples)
90+
91+
92+
class RecalibratedCalibrationMetric(RecMetric):
93+
_namespace: MetricNamespace = MetricNamespace.RECALIBRATED_CALIBRATION
94+
_computation_class: Type[RecMetricComputation] = (
95+
RecalibratedCalibrationMetricComputation
96+
)

torchrec/metrics/metric_module.py

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from torchrec.metrics.auprc import AUPRCMetric
2424
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
2525
from torchrec.metrics.calibration import CalibrationMetric
26+
from torchrec.metrics.calibration_with_recalibration import (
27+
RecalibratedCalibrationMetric,
28+
)
2629
from torchrec.metrics.ctr import CTRMetric
2730
from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric
2831
from torchrec.metrics.mae import MAEMetric
@@ -46,6 +49,7 @@
4649
from torchrec.metrics.ndcg import NDCGMetric
4750
from torchrec.metrics.ne import NEMetric
4851
from torchrec.metrics.ne_positive import NEPositiveMetric
52+
from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric
4953
from torchrec.metrics.output import OutputMetric
5054
from torchrec.metrics.precision import PrecisionMetric
5155
from torchrec.metrics.precision_session import PrecisionSessionMetric
@@ -71,6 +75,8 @@
7175
RecMetricEnum.NE: NEMetric,
7276
RecMetricEnum.NE_POSITIVE: NEPositiveMetric,
7377
RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric,
78+
RecMetricEnum.RECALIBRATED_NE: RecalibratedNEMetric,
79+
RecMetricEnum.RECALIBRATED_CALIBRATION: RecalibratedCalibrationMetric,
7480
RecMetricEnum.CTR: CTRMetric,
7581
RecMetricEnum.CALIBRATION: CalibrationMetric,
7682
RecMetricEnum.AUC: AUCMetric,

torchrec/metrics/metrics_config.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class RecMetricEnumBase(StrValueMixin, Enum):
2121
class RecMetricEnum(RecMetricEnumBase):
2222
NE = "ne"
2323
NE_POSITIVE = "ne_positive"
24+
RECALIBRATED_NE = "recalibrated_ne"
25+
RECALIBRATED_CALIBRATION = "recalibrated_calibration"
2426
SEGMENTED_NE = "segmented_ne"
2527
LOG_LOSS = "log_loss"
2628
CTR = "ctr"

torchrec/metrics/metrics_namespace.py

+2
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ class MetricNamespace(MetricNamespaceBase):
9797
NE = "ne"
9898
NE_POSITIVE = "ne_positive"
9999
SEGMENTED_NE = "segmented_ne"
100+
RECALIBRATED_NE = "recalibrated_ne"
101+
RECALIBRATED_CALIBRATION = "recalibrated_calibration"
100102
THROUGHPUT = "throughput"
101103
CTR = "ctr"
102104
CALIBRATION = "calibration"
+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from typing import Any, Dict, Optional, Type
11+
12+
import torch
13+
14+
from torchrec.metrics.metrics_namespace import MetricNamespace
15+
from torchrec.metrics.ne import get_ne_states, NEMetricComputation
16+
from torchrec.metrics.rec_metric import (
17+
RecMetric,
18+
RecMetricComputation,
19+
RecMetricException,
20+
)
21+
22+
23+
class RecalibratedNEMetricComputation(NEMetricComputation):
24+
r"""
25+
This class implements the recalibration for NE that is required to correctly estimate eval NE if negative downsampling was used during training.
26+
27+
The constructor arguments are defined in RecMetricComputation.
28+
See the docstring of RecMetricComputation for more detail.
29+
30+
Args:
31+
include_logloss (bool): return vanilla logloss as one of metrics results, on top of NE.
32+
"""
33+
34+
def __init__(
35+
self,
36+
*args: Any,
37+
include_logloss: bool = False,
38+
allow_missing_label_with_zero_weight: bool = False,
39+
recalibration_coefficient: float = 1.0,
40+
**kwargs: Any,
41+
) -> None:
42+
self._recalibration_coefficient: float = recalibration_coefficient
43+
self._include_logloss: bool = include_logloss
44+
self._allow_missing_label_with_zero_weight: bool = (
45+
allow_missing_label_with_zero_weight
46+
)
47+
super().__init__(*args, **kwargs)
48+
self._add_state(
49+
"cross_entropy_sum",
50+
torch.zeros(self._n_tasks, dtype=torch.double),
51+
add_window_state=True,
52+
dist_reduce_fx="sum",
53+
persistent=True,
54+
)
55+
self._add_state(
56+
"weighted_num_samples",
57+
torch.zeros(self._n_tasks, dtype=torch.double),
58+
add_window_state=True,
59+
dist_reduce_fx="sum",
60+
persistent=True,
61+
)
62+
self._add_state(
63+
"pos_labels",
64+
torch.zeros(self._n_tasks, dtype=torch.double),
65+
add_window_state=True,
66+
dist_reduce_fx="sum",
67+
persistent=True,
68+
)
69+
self._add_state(
70+
"neg_labels",
71+
torch.zeros(self._n_tasks, dtype=torch.double),
72+
add_window_state=True,
73+
dist_reduce_fx="sum",
74+
persistent=True,
75+
)
76+
self.eta = 1e-12
77+
78+
def _recalibrate(
79+
self,
80+
predictions: torch.Tensor,
81+
calibration_coef: Optional[torch.Tensor],
82+
) -> torch.Tensor:
83+
if calibration_coef is not None:
84+
predictions = predictions / (
85+
predictions + (1.0 - predictions) / calibration_coef
86+
)
87+
return predictions
88+
89+
def update(
90+
self,
91+
*,
92+
predictions: Optional[torch.Tensor],
93+
labels: torch.Tensor,
94+
weights: Optional[torch.Tensor],
95+
**kwargs: Dict[str, Any],
96+
) -> None:
97+
if predictions is None or weights is None:
98+
raise RecMetricException(
99+
"Inputs 'predictions' and 'weights' should not be None for RecalibratedNEMetricComputation update"
100+
)
101+
102+
predictions = self._recalibrate(
103+
predictions, self._recalibration_coefficient * torch.ones_like(predictions)
104+
)
105+
states = get_ne_states(labels, predictions, weights, self.eta)
106+
num_samples = predictions.shape[-1]
107+
108+
for state_name, state_value in states.items():
109+
state = getattr(self, state_name)
110+
state += state_value
111+
self._aggregate_window_state(state_name, state_value, num_samples)
112+
113+
114+
class RecalibratedNEMetric(RecMetric):
115+
_namespace: MetricNamespace = MetricNamespace.RECALIBRATED_NE
116+
_computation_class: Type[RecMetricComputation] = RecalibratedNEMetricComputation
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from typing import Dict
12+
13+
import torch
14+
from torchrec.metrics.calibration_with_recalibration import (
15+
RecalibratedCalibrationMetric,
16+
)
17+
from torchrec.metrics.metrics_config import DefaultTaskInfo
18+
19+
20+
WORLD_SIZE = 4
21+
BATCH_SIZE = 10
22+
23+
24+
def generate_model_output() -> Dict[str, torch._tensor.Tensor]:
25+
return {
26+
"predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]),
27+
"labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0]]),
28+
"weights": torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0]]),
29+
"expected_recalibrated_calibration": torch.tensor([0.0837]),
30+
}
31+
32+
33+
class RecalibratedCalibrationMetricMetricTest(unittest.TestCase):
34+
def setUp(self) -> None:
35+
self.calibration_with_recalibration = RecalibratedCalibrationMetric(
36+
world_size=WORLD_SIZE,
37+
my_rank=0,
38+
batch_size=BATCH_SIZE,
39+
tasks=[DefaultTaskInfo],
40+
# pyre-ignore[6]
41+
recalibration_coefficient=0.1,
42+
)
43+
44+
def test_calibration_with_recalibration(self) -> None:
45+
model_output = generate_model_output()
46+
self.calibration_with_recalibration.update(
47+
predictions={DefaultTaskInfo.name: model_output["predictions"][0]},
48+
labels={DefaultTaskInfo.name: model_output["labels"][0]},
49+
weights={DefaultTaskInfo.name: model_output["weights"][0]},
50+
)
51+
metric = self.calibration_with_recalibration.compute()
52+
actual_metric = metric[
53+
f"recalibrated_calibration-{DefaultTaskInfo.name}|lifetime_calibration"
54+
]
55+
expected_metric = model_output["expected_recalibrated_calibration"]
56+
57+
torch.testing.assert_close(
58+
actual_metric,
59+
expected_metric,
60+
atol=1e-4,
61+
rtol=1e-4,
62+
check_dtype=False,
63+
equal_nan=True,
64+
msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
65+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from typing import Dict
12+
13+
import torch
14+
from torchrec.metrics.metrics_config import DefaultTaskInfo
15+
from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric
16+
17+
18+
WORLD_SIZE = 4
19+
BATCH_SIZE = 10
20+
21+
22+
def generate_model_output() -> Dict[str, torch._tensor.Tensor]:
23+
return {
24+
"predictions": torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]),
25+
"labels": torch.tensor([[1.0, 0.0, 1.0, 1.0, 0.0]]),
26+
"weights": torch.tensor([[1.0, 1.0, 1.0, 0.0, 1.0]]),
27+
"expected_recalibrated_ne": torch.tensor([2.8214]),
28+
}
29+
30+
31+
class RecalibratedNEMetricMetricTest(unittest.TestCase):
32+
def setUp(self) -> None:
33+
self.ne_with_recalibration = RecalibratedNEMetric(
34+
world_size=WORLD_SIZE,
35+
my_rank=0,
36+
batch_size=BATCH_SIZE,
37+
tasks=[DefaultTaskInfo],
38+
# pyre-ignore[6]
39+
recalibration_coefficient=0.1,
40+
)
41+
42+
def test_ne_with_recalibration(self) -> None:
43+
model_output = generate_model_output()
44+
self.ne_with_recalibration.update(
45+
predictions={DefaultTaskInfo.name: model_output["predictions"][0]},
46+
labels={DefaultTaskInfo.name: model_output["labels"][0]},
47+
weights={DefaultTaskInfo.name: model_output["weights"][0]},
48+
)
49+
metric = self.ne_with_recalibration.compute()
50+
actual_metric = metric[f"recalibrated_ne-{DefaultTaskInfo.name}|lifetime_ne"]
51+
expected_metric = model_output["expected_recalibrated_ne"]
52+
53+
torch.testing.assert_close(
54+
actual_metric,
55+
expected_metric,
56+
atol=1e-4,
57+
rtol=1e-4,
58+
check_dtype=False,
59+
equal_nan=True,
60+
msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
61+
)

0 commit comments

Comments
 (0)