Report the batch size used for the qps/throughput calculation (#2788)
Summary:

If using batch size stages, also report the current batch size that was used for the throughput calculation

Reviewed By: kausv

Differential Revision: D70892994
Ilyas Atishev authored and facebook-github-bot committed Mar 11, 2025
1 parent e1ee42c commit 5deab66
Showing 3 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchrec/metrics/metrics_namespace.py
@@ -46,6 +46,7 @@ class MetricName(MetricNameBase):
     THROUGHPUT = "throughput"
     TOTAL_EXAMPLES = "total_examples"
     ATTEMPT_EXAMPLES = "attempt_examples"
+    BATCH_SIZE = "batch_size"
     CTR = "ctr"
     CALIBRATION = "calibration"
     MSE = "mse"
2 changes: 2 additions & 0 deletions torchrec/metrics/tests/test_throughput.py
@@ -199,6 +199,7 @@ def test_batch_size_schedule(self, time_mock: Mock) -> None:
             {
                 "throughput-throughput|total_examples": total_examples,
                 "throughput-throughput|attempt_examples": total_examples,
+                "throughput-throughput|batch_size": 256,
             },
         )

@@ -209,5 +210,6 @@ def test_batch_size_schedule(self, time_mock: Mock) -> None:
             {
                 "throughput-throughput|total_examples": total_examples,
                 "throughput-throughput|attempt_examples": total_examples,
+                "throughput-throughput|batch_size": 512,
             },
         )
17 changes: 17 additions & 0 deletions torchrec/metrics/throughput.py
@@ -74,6 +74,7 @@ class ThroughputMetric(nn.Module):
     _attempt_throughput_key: str
     _total_examples_key: str
     _attempt_examples_key: str
+    _batch_size_key: str
     _steps: int

     def __init__(
@@ -167,6 +168,11 @@ def __init__(
             str(self._namespace),
             MetricName.ATTEMPT_EXAMPLES,
         )
+        self._batch_size_key = compose_metric_key(
+            self._namespace,
+            str(self._namespace),
+            MetricName.BATCH_SIZE,
+        )
         self._steps = 0

     def _get_batch_size(self) -> int:
@@ -258,4 +264,15 @@ def compute(self) -> Dict[str, torch.Tensor]:
                 self._attempt_throughput_key: attempt_throughput.clone().detach(),
             }
         )
+        # If using batch_size_stages, also report the current batch size
+        # that was used for the throughput calculation
+        if self._batch_size_stages is not None:
+            ret.update(
+                {
+                    self._batch_size_key: torch.tensor(
+                        self._get_batch_size(), dtype=torch.int32
+                    ),
+                }
+            )
+
         return ret
