Skip to content

Commit f1fb67a

Browse files
Z Zhoufacebook-github-bot
Z Zhou
authored andcommitted
Add Regression AUC (RAUC) metrics
Summary: Implement regression AUC metrics. Regression AUC is an extension of classification AUC. See Section 4.1.1 in https://arxiv.org/ftp/arxiv/papers/1205/1205.2618.pdf for related discussions. On a high level, regression AUC is an extension of the traditional AUC for classification through the probabilistic interpretation: the area under the curve is equal to the probability that a classifier will rank a randomly chosen positive instance higher than a randomly chosen negative one. We utilize merge sort to optimize time complexity to O(nlog(n)). Reviewed By: zainhuda Differential Revision: D53377225 fbshipit-source-id: 7dfccbddf9f17a6881c6fd00f9614466c603514d
1 parent e719551 commit f1fb67a

File tree

6 files changed

+896
-3
lines changed

6 files changed

+896
-3
lines changed

torchrec/metrics/metric_module.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from torchrec.metrics.multiclass_recall import MulticlassRecallMetric
4040
from torchrec.metrics.ndcg import NDCGMetric
4141
from torchrec.metrics.ne import NEMetric
42+
from torchrec.metrics.rauc import RAUCMetric
4243
from torchrec.metrics.rec_metric import RecMetric, RecMetricList
4344
from torchrec.metrics.recall_session import RecallSessionMetric
4445
from torchrec.metrics.scalar import ScalarMetric
@@ -58,6 +59,7 @@
5859
RecMetricEnum.CALIBRATION: CalibrationMetric,
5960
RecMetricEnum.AUC: AUCMetric,
6061
RecMetricEnum.AUPRC: AUPRCMetric,
62+
RecMetricEnum.RAUC: RAUCMetric,
6163
RecMetricEnum.MSE: MSEMetric,
6264
RecMetricEnum.MAE: MAEMetric,
6365
RecMetricEnum.MULTICLASS_RECALL: MulticlassRecallMetric,

torchrec/metrics/metrics_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class RecMetricEnum(RecMetricEnumBase):
2323
CTR = "ctr"
2424
AUC = "auc"
2525
AUPRC = "auprc"
26+
RAUC = "rauc"
2627
CALIBRATION = "calibration"
2728
MSE = "mse"
2829
MAE = "mae"

torchrec/metrics/metrics_namespace.py

+3
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ class MetricName(MetricNameBase):
4949
RMSE = "rmse"
5050
AUC = "auc"
5151
AUPRC = "auprc"
52+
RAUC = "rauc"
5253
GROUPED_AUC = "grouped_auc"
5354
GROUPED_AUPRC = "grouped_auprc"
55+
GROUPED_RAUC = "grouped_rauc"
5456
RECALL_SESSION_LEVEL = "recall_session_level"
5557
MULTICLASS_RECALL = "multiclass_recall"
5658
WEIGHTED_AVG = "weighted_avg"
@@ -76,6 +78,7 @@ class MetricNamespace(MetricNamespaceBase):
7678
MSE = "mse"
7779
AUC = "auc"
7880
AUPRC = "auprc"
81+
RAUC = "rauc"
7982
MAE = "mae"
8083
ACCURACY = "accuracy"
8184

0 commit comments

Comments
 (0)