
Commit c9cfe23

pytorchbot committed: 2025-01-17 nightly release (33168a1)
1 parent 29755d8

File tree: 6 files changed, +66 -17 lines


torchrec/distributed/embedding_sharding.py (+4)

@@ -132,6 +132,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
     bucketize_pos: bool = False,
     block_bucketize_pos: Optional[List[torch.Tensor]] = None,
     total_num_blocks: Optional[torch.Tensor] = None,
+    keep_original_indices: bool = False,
 ) -> Tuple[
     torch.Tensor,
     torch.Tensor,
@@ -159,6 +160,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=block_bucketize_pos,
         return_bucket_mapping=True,
+        keep_orig_idx=keep_original_indices,
     )

     return (
@@ -305,6 +307,7 @@ def bucketize_kjt_inference(
     bucketize_pos: bool = False,
     block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
     is_sequence: bool = False,
+    keep_original_indices: bool = False,
 ) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
     """
     Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets,
@@ -352,6 +355,7 @@ def bucketize_kjt_inference(
             total_num_blocks=total_num_buckets_new_type,
             bucketize_pos=bucketize_pos,
             block_bucketize_pos=block_bucketize_row_pos,
+            keep_original_indices=keep_original_indices,
         )
     else:
         (
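For intuition about what the new flag controls downstream (it is forwarded to the underlying bucketization kernel as keep_orig_idx), here is a minimal, self-contained sketch; bucketize_indices is a toy stand-in for illustration, not the TorchRec/FBGEMM implementation:

from typing import List

def bucketize_indices(
    indices: List[int],
    block_size: int,
    num_buckets: int,
    keep_original_indices: bool = False,
) -> List[List[int]]:
    """Toy block bucketization: route each index to bucket = index // block_size,
    optionally preserving the original (global) index."""
    buckets: List[List[int]] = [[] for _ in range(num_buckets)]
    for idx in indices:
        bucket = min(idx // block_size, num_buckets - 1)
        # Default behavior rebases each index to be local to its bucket;
        # keep_original_indices leaves the global id untouched.
        buckets[bucket].append(
            idx if keep_original_indices else idx - bucket * block_size
        )
    return buckets

print(bucketize_indices([0, 5, 9, 13], block_size=5, num_buckets=3))
# -> [[0], [0, 4], [3]]   (indices rebased per bucket)
print(bucketize_indices([0, 5, 9, 13], block_size=5, num_buckets=3, keep_original_indices=True))
# -> [[0], [5, 9], [13]]  (global ids preserved)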

torchrec/distributed/mc_modules.py (+1)

@@ -1171,6 +1171,7 @@ def _create_input_dists(
             has_feature_processor=sharding._has_feature_processor,
             need_pos=False,
             embedding_shard_metadata=emb_sharding,
+            keep_original_indices=True,
         )
         self._input_dists.append(input_dist)
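Note that this managed-collision call site pins keep_original_indices=True rather than exposing it, plausibly because the managed-collision module performs its own id remapping and therefore needs the global indices preserved through bucketization.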

torchrec/distributed/sharding/rw_sharding.py (+4)

@@ -649,10 +649,12 @@ def __init__(
         has_feature_processor: bool = False,
         need_pos: bool = False,
         embedding_shard_metadata: Optional[List[List[int]]] = None,
+        keep_original_indices: bool = False,
     ) -> None:
         super().__init__()
         logger.info(
             f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}"
+            f", keep_original_indices={keep_original_indices}"
         )
         self._world_size: int = world_size
         self._num_features = num_features
@@ -683,6 +685,7 @@ def __init__(
         self._embedding_shard_metadata: Optional[List[List[int]]] = (
             embedding_shard_metadata
         )
+        self._keep_original_indices = keep_original_indices

     def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
         block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
@@ -717,6 +720,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
                 block_bucketize_row_pos
             ),
             is_sequence=self._is_sequence,
+            keep_original_indices=self._keep_original_indices,
         )
         # KJTOneToAll
         dist_kjt = self._dist.forward(bucketized_features)
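End to end, the flag is stored at construction, logged, and forwarded to bucketize_kjt_inference on every forward. A hedged sketch of the call site; in practice construction happens inside TorchRec's sharding machinery, and the argument values here are illustrative:

import torch

# Illustrative only: world_size and feature_hash_sizes are made-up values.
dist = InferRwSparseFeaturesDist(
    world_size=2,
    num_features=1,
    feature_hash_sizes=[100],
    device=torch.device("cpu"),
    is_sequence=True,
    keep_original_indices=True,  # new flag: stored, logged, and forwarded
)
# forward() bucketizes the KJT across the 2 ranks; with the flag set, the
# bucketized values keep their global ids instead of bucket-local offsets.
out = dist.forward(sparse_features)  # sparse_features: a KeyedJaggedTensor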

torchrec/metrics/metrics_config.py (-5)

@@ -236,11 +236,6 @@ def validate_batch_size_stages(
     if len(batch_size_stages) == 0:
         raise ValueError("Batch size stages should not be empty")

-    for i in range(len(batch_size_stages) - 1):
-        if batch_size_stages[i].batch_size >= batch_size_stages[i + 1].batch_size:
-            raise ValueError(
-                f"Batch size should be in ascending order. Got {batch_size_stages}"
-            )
     if batch_size_stages[-1].max_iters is not None:
         raise ValueError(
             f"Batch size stages last stage should have max_iters = None, but get {batch_size_stages[-1].max_iters}"

torchrec/metrics/segmented_ne.py (+29, -9)

@@ -165,18 +165,25 @@ class SegmentedNEMetricComputation(RecMetricComputation):

     Args:
         include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE.
+        num_groups (int): number of groups to segment NE by.
+        grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64.
+        cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32.
     """

     def __init__(
         self,
         *args: Any,
         include_logloss: bool = False,  # TODO - include
         num_groups: int = 1,
+        grouping_keys: str = "grouping_keys",
+        cast_keys_to_int: bool = False,
         **kwargs: Any,
     ) -> None:
         self._include_logloss: bool = include_logloss
         super().__init__(*args, **kwargs)
         self._num_groups = num_groups  # would there be checkpointing issues with this? maybe make this state
+        self._grouping_keys = grouping_keys
+        self._cast_keys_to_int = cast_keys_to_int
         self._add_state(
             "cross_entropy_sum",
             torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
@@ -217,21 +224,30 @@ def update(
     ) -> None:
         if predictions is None or weights is None:
             raise RecMetricException(
-                "Inputs 'predictions' and 'weights' and 'grouping_keys' should not be None for NEMetricComputation update"
+                f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update"
             )
         elif (
             "required_inputs" not in kwargs
-            or kwargs["required_inputs"].get("grouping_keys") is None
+            or kwargs["required_inputs"].get(self._grouping_keys) is None
         ):
             raise RecMetricException(
-                f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}"
-            )
-        elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64:
-            raise RecMetricException(
-                f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}."
+                f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}"
             )
+        elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64:
+            if (
+                self._cast_keys_to_int
+                and kwargs["required_inputs"][self._grouping_keys].dtype
+                == torch.float32
+            ):
+                kwargs["required_inputs"][self._grouping_keys] = kwargs[
+                    "required_inputs"
+                ][self._grouping_keys].to(torch.int64)
+            else:
+                raise RecMetricException(
+                    f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}."
+                )

-        grouping_keys = kwargs["required_inputs"]["grouping_keys"]
+        grouping_keys = kwargs["required_inputs"][self._grouping_keys]
         states = get_segemented_ne_states(
             labels,
             predictions,
@@ -325,4 +341,8 @@ def __init__(
             process_group=process_group,
             **kwargs,
         )
-        self._required_inputs.add("grouping_keys")
+        if "grouping_keys" not in kwargs:
+            self._required_inputs.add("grouping_keys")
+        else:
+            # pyre-ignore[6]
+            self._required_inputs.add(kwargs["grouping_keys"])
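Taken together, the change makes the grouping tensor's name configurable and optionally tolerant of float32 keys. A hedged usage sketch modeled on the updated test helper; the constructor kwargs mirror the test file, the tensor name "custom_key" is illustrative, and the update() dict keys follow each task's label/prediction/weight names:

import torch
from torchrec.metrics.rec_metric import RecTaskInfo
from torchrec.metrics.segmented_ne import SegmentedNEMetric

task = RecTaskInfo(
    name="Task:0",
    label_name="label",
    prediction_name="prediction",
    weight_name="weight",
)
metric = SegmentedNEMetric(
    world_size=1,
    my_rank=0,
    batch_size=5,
    tasks=[task],
    num_groups=2,
    grouping_keys="custom_key",  # new: read segment ids from this required-inputs tensor
    cast_keys_to_int=True,       # new: accept float32 keys by casting them to int64
)
metric.update(
    predictions={"prediction": torch.tensor([0.2, 0.6, 0.8, 0.4, 0.9])},
    labels={"label": torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])},
    weights={"weight": torch.tensor([0.13, 0.2, 0.5, 0.8, 0.75])},
    required_inputs={"custom_key": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])},
)
print(metric.compute())  # NE reported per group (here: segments 0 and 1)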

torchrec/metrics/tests/test_segmented_ne.py (+28, -3)

@@ -8,7 +8,7 @@
 # pyre-strict

 import unittest
-from typing import Dict, Iterable, Union
+from typing import Any, Dict, Iterable, Union

 import torch
 from torch import no_grad
@@ -31,6 +31,8 @@ def _test_segemented_ne_helper(
     weights: torch.Tensor,
     expected_ne: torch.Tensor,
     grouping_keys: torch.Tensor,
+    grouping_key_tensor_name: str = "grouping_keys",
+    cast_keys_to_int: bool = False,
 ) -> None:
     num_task = labels.shape[0]
     batch_size = labels.shape[0]
@@ -41,7 +43,7 @@ def _test_segemented_ne_helper(
         "weights": {},
     }
     if grouping_keys is not None:
-        inputs["required_inputs"] = {"grouping_keys": grouping_keys}
+        inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys}
     for i in range(num_task):
         task_info = RecTaskInfo(
             name=f"Task:{i}",
@@ -64,6 +66,10 @@ def _test_segemented_ne_helper(
         tasks=task_list,
         # pyre-ignore
         num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1),
+        # pyre-ignore
+        grouping_keys=grouping_key_tensor_name,
+        # pyre-ignore
+        cast_keys_to_int=cast_keys_to_int,
     )
     ne.update(**inputs)
     actual_ne = ne.compute()
@@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None:
             raise


-def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
+def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]:
     return [
         # base condition
         {
@@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
             ),  # for this case, both tasks have same groupings
             "expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]),
         },
+        # Custom grouping key tensor name
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([0, 1, 0, 1, 1]),
+            "expected_ne": torch.tensor([[3.1615, 1.6004]]),
+            "grouping_key_tensor_name": "custom_key",
+        },
+        # Cast float32 grouping keys to int64
+        {
+            "labels": torch.tensor([[1, 0, 0, 1, 1]]),
+            "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
+            "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
+            "grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]),
+            "expected_ne": torch.tensor([[3.1615, 1.6004]]),
+            "grouping_key_tensor_name": "custom_key",
+            "cast_keys_to_int": True,
+        },
     ]
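To exercise the new cases locally (assuming a checkout where torchrec is importable):

python -m unittest torchrec.metrics.tests.test_segmented_ne -v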
