2025-02-27 nightly release (a3eee19)
pytorchbot committed Feb 27, 2025
1 parent 3522a80 commit 822a5a7
Showing 8 changed files with 322 additions and 101 deletions.
345 changes: 252 additions & 93 deletions torchrec/distributed/test_utils/test_model.py

Large diffs are not rendered by default.

@@ -93,6 +93,7 @@ def _test_sharded_forward(
dedup_tables: Optional[List[EmbeddingTableConfig]] = None,
weighted_tables: Optional[List[EmbeddingTableConfig]] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
# pyre-ignore [9]
generate: ModelInputCallable = ModelInput.generate,
) -> None:
default_rank = 0
38 changes: 36 additions & 2 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -126,6 +126,10 @@ def __call__(
Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
] = None,
variable_batch_size: bool = False,
use_offsets: bool = False,
indices_dtype: torch.dtype = torch.int64,
offsets_dtype: torch.dtype = torch.int64,
lengths_dtype: torch.dtype = torch.int64,
long_indices: bool = True,
) -> Tuple["ModelInput", List["ModelInput"]]: ...

@@ -140,6 +144,10 @@ def __call__(
weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
pooling_avg: int = 10,
global_constant_batch: bool = False,
use_offsets: bool = False,
indices_dtype: torch.dtype = torch.int64,
offsets_dtype: torch.dtype = torch.int64,
lengths_dtype: torch.dtype = torch.int64,
) -> Tuple["ModelInput", List["ModelInput"]]: ...


@@ -148,6 +156,7 @@ def gen_model_and_input(
tables: List[EmbeddingTableConfig],
embedding_groups: Dict[str, List[str]],
world_size: int,
# pyre-ignore [9]
generate: Union[
ModelInputCallable, VariableBatchModelInputCallable
] = ModelInput.generate,
@@ -160,10 +169,14 @@ def gen_model_and_input(
variable_batch_size: bool = False,
batch_size: int = 4,
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
long_indices: bool = True,
use_offsets: bool = False,
indices_dtype: torch.dtype = torch.int64,
offsets_dtype: torch.dtype = torch.int64,
lengths_dtype: torch.dtype = torch.int64,
global_constant_batch: bool = False,
num_inputs: int = 1,
input_type: str = "kjt", # "kjt" or "td"
long_indices: bool = True,
) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
torch.manual_seed(0)
if dedup_feature_names:
@@ -204,6 +217,10 @@ def gen_model_and_input(
tables=tables,
weighted_tables=weighted_tables or [],
global_constant_batch=global_constant_batch,
use_offsets=use_offsets,
indices_dtype=indices_dtype,
offsets_dtype=offsets_dtype,
lengths_dtype=lengths_dtype,
)
)
elif generate == ModelInput.generate:
@@ -217,8 +234,12 @@
num_float_features=num_float_features,
variable_batch_size=variable_batch_size,
batch_size=batch_size,
long_indices=long_indices,
input_type=input_type,
use_offsets=use_offsets,
indices_dtype=indices_dtype,
offsets_dtype=offsets_dtype,
lengths_dtype=lengths_dtype,
long_indices=long_indices,
)
)
else:
@@ -232,6 +253,10 @@
num_float_features=num_float_features,
variable_batch_size=variable_batch_size,
batch_size=batch_size,
use_offsets=use_offsets,
indices_dtype=indices_dtype,
offsets_dtype=offsets_dtype,
lengths_dtype=lengths_dtype,
long_indices=long_indices,
)
)
@@ -335,6 +360,10 @@ def sharding_single_rank_test(
input_type: str = "kjt", # "kjt" or "td"
allow_zero_batch_size: bool = False,
custom_all_reduce: bool = False, # 2D parallel
use_offsets: bool = False,
indices_dtype: torch.dtype = torch.int64,
offsets_dtype: torch.dtype = torch.int64,
lengths_dtype: torch.dtype = torch.int64,
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
batch_size = (
@@ -344,6 +373,7 @@
(global_model, inputs) = gen_model_and_input(
model_class=model_class,
tables=tables,
# pyre-ignore [6]
generate=(
cast(
VariableBatchModelInputCallable,
@@ -361,6 +391,10 @@
feature_processor_modules=feature_processor_modules,
global_constant_batch=global_constant_batch,
input_type=input_type,
use_offsets=use_offsets,
indices_dtype=indices_dtype,
offsets_dtype=offsets_dtype,
lengths_dtype=lengths_dtype,
)
global_model = global_model.to(ctx.device)
global_input = inputs[0][0].to(ctx.device)
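The new use_offsets, indices_dtype, offsets_dtype, and lengths_dtype arguments are threaded from sharding_single_rank_test through gen_model_and_input down to ModelInput.generate. For context, a minimal sketch (not part of this diff) of the distinction these knobs presumably control — whether the generated KeyedJaggedTensor inputs are built from lengths or from offsets, and with which integer dtypes:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Feature ids for keys "f1" and "f2"; indices_dtype would govern this tensor's dtype.
values = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)

# lengths-based construction (use_offsets=False); lengths_dtype applies here.
kjt_from_lengths = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=values,
    lengths=torch.tensor([2, 1, 1, 1], dtype=torch.int32),
)

# offsets-based construction (use_offsets=True); offsets_dtype applies here.
kjt_from_offsets = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=values,
    offsets=torch.tensor([0, 2, 3, 4, 5], dtype=torch.int32),
)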
2 changes: 1 addition & 1 deletion torchrec/inference/inference_legacy/__init__.py
@@ -24,4 +24,4 @@
- `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model.
"""

from . import model_packager, modules # noqa # noqa
from . import model_packager # noqa
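If downstream code relied on the package __init__ re-exporting modules, it would now need to import the submodule explicitly — a hypothetical adjustment, assuming the submodule itself still exists after this change:

from torchrec.inference.inference_legacy import model_packager  # still re-exported
from torchrec.inference.inference_legacy import modules  # hypothetical: explicit import, no longer re-exported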
4 changes: 4 additions & 0 deletions torchrec/inference/tests/test_inference.py
@@ -410,3 +410,7 @@ def test_fused_params_overwrite(self) -> None:

# Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params
self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value)

# change it back to the original value because it modifies the global variable
# otherwise it will affect other tests
ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = orig_value
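The added lines restore the shared fused-params value by hand at the end of the test. An equivalent pattern (sketch only, with a plain dict standing in for the test's ebc_fused_params and a string key standing in for FUSED_PARAM_REGISTER_TBE_BOOL) is to guard the overwrite with try/finally, so the global state is restored even when an assertion fails partway through:

fused_params = {"register_tbe": True}  # stand-in for the shared fused-params dict

orig_value = fused_params["register_tbe"]
try:
    fused_params["register_tbe"] = not orig_value  # the overwrite under test
    assert fused_params["register_tbe"] != orig_value
finally:
    # Restore shared state so later tests observe the original value.
    fused_params["register_tbe"] = orig_value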
6 changes: 6 additions & 0 deletions torchrec/ir/tests/test_serializer.py
@@ -253,7 +253,13 @@ def test_serialize_deserialize_ebc(self) -> None:
self.assertEqual(deserialized.shape, orginal.shape)
self.assertTrue(torch.allclose(deserialized, orginal))

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() == 0,
"skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
)
def test_dynamic_shape_ebc(self) -> None:
# TODO: https://fb.workplace.com/groups/1028545332188949/permalink/1138699244506890/
model = self.generate_model()
feature1 = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3"],
15 changes: 15 additions & 0 deletions torchrec/models/experimental/test_transformerdlrm.py
@@ -61,6 +61,11 @@ def test_larger(self) -> None:
concat_dense = inter_arch(dense_features, sparse_features)
self.assertEqual(concat_dense.size(), (B, D * (F + 1)))

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() == 0,
"skip this test in OSS (no GPU available) because seed might be different in OSS",
)
def test_correctness(self) -> None:
D = 4
B = 3
@@ -165,6 +170,11 @@ def test_correctness(self) -> None:
)
)

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() == 0,
"skip this test in OSS (no GPU available) because seed might be different in OSS",
)
def test_numerical_stability(self) -> None:
D = 4
B = 3
@@ -194,6 +204,11 @@ def test_numerical_stability(self) -> None:


class DLRMTransformerTest(unittest.TestCase):
# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
torch.cuda.device_count() == 0,
"skip this test in OSS (no GPU available) because seed might be different in OSS",
)
def test_basic(self) -> None:
torch.manual_seed(0)
B = 2
12 changes: 7 additions & 5 deletions torchrec/modules/mc_modules.py
@@ -40,11 +40,13 @@ def apply_mc_method_to_jt_dict(
def _update(
base: Optional[Dict[str, JaggedTensor]], delta: Dict[str, JaggedTensor]
) -> Dict[str, JaggedTensor]:
if base is None:
base = delta
else:
base.update(delta)
return base
res: Dict[str, JaggedTensor] = {}
if base is not None:
for k, v in base.items():
res[k] = v
for k, v in delta.items():
res[k] = v
return res


@torch.fx.wrap
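The rewritten _update builds and returns a fresh dict instead of mutating or returning base directly, so the result never aliases either input — which also keeps the helper closer to the pure-function style that torch.fx-wrapped code tends to favor (an assumption about the motivation; the commit does not state it). A plain-dict sketch of the equivalent merge semantics:

# Plain strings stand in for JaggedTensor values in this sketch.
base = {"f1": "jt_base_f1", "f2": "jt_base_f2"}
delta = {"f2": "jt_delta_f2", "f3": "jt_delta_f3"}

res = {**(base or {}), **delta}  # same result as the new _update: delta wins on key overlap
assert res == {"f1": "jt_base_f1", "f2": "jt_delta_f2", "f3": "jt_delta_f3"}
assert res is not base and res is not delta  # previously the result could alias base (or delta when base was None)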
