Commit

2025-03-04 nightly release (919bbcb)

pytorchbot committed Mar 4, 2025
1 parent 9e1651e commit 09ab2c1
Showing 6 changed files with 273 additions and 20 deletions.
2 changes: 2 additions & 0 deletions torchrec/distributed/composable/tests/test_ddp.py
@@ -105,11 +105,13 @@ def _run(cls, rank: int, world_size: int, path: str) -> None:
weighted_tables=weighted_tables,
dense_device=ctx.device,
)
# pyre-ignore
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=ctx.device,
plan=column_wise(ranks=list(range(world_size))),
)
# pyre-ignore
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=ctx.device,
2 changes: 2 additions & 0 deletions torchrec/distributed/composable/tests/test_fsdp.py
@@ -83,11 +83,13 @@ def _run( # noqa
m.sparse.parameters(),
{"lr": 0.01},
)
# pyre-ignore
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=ctx.device,
plan=row_wise(),
)
# pyre-ignore
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=ctx.device,
217 changes: 210 additions & 7 deletions torchrec/distributed/test_utils/test_model.py
@@ -26,7 +26,18 @@
)
from torchrec.distributed.fused_embedding import FusedEmbeddingCollectionSharder
from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
from torchrec.distributed.types import QuantizedCommCodecs
from torchrec.distributed.mc_embedding_modules import (
BaseManagedCollisionEmbeddingCollectionSharder,
)
from torchrec.distributed.mc_embeddingbag import (
ShardedManagedCollisionEmbeddingBagCollection,
)
from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder
from torchrec.distributed.types import (
ParameterSharding,
QuantizedCommCodecs,
ShardingEnv,
)
from torchrec.distributed.utils import CopyableMixin
from torchrec.modules.activation import SwishLayerNorm
from torchrec.modules.embedding_configs import (
@@ -39,6 +50,12 @@
from torchrec.modules.feature_processor import PositionWeightedProcessor
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
ManagedCollisionCollection,
MCHManagedCollisionModule,
)
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Pipelineable
@@ -1351,6 +1368,7 @@ def __init__(
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
over_arch_clazz: Type[nn.Module] = TestOverArch,
postproc_module: Optional[nn.Module] = None,
zch: bool = False,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1362,12 +1380,20 @@ def __init__(
if weighted_tables is None:
weighted_tables = []
self.dense = TestDenseArch(num_float_features, dense_device)
self.sparse = TestSparseArch(
tables,
weighted_tables,
sparse_device,
max_feature_lengths,
)
if zch:
self.sparse: nn.Module = TestSparseArchZCH(
tables,
weighted_tables,
torch.device("meta"),
return_remapped=True,
)
else:
self.sparse = TestSparseArch(
tables,
weighted_tables,
sparse_device,
max_feature_lengths,
)

embedding_names = (
list(embedding_groups.values())[0] if embedding_groups else None
@@ -1687,6 +1713,64 @@ def compute_kernels(
return [self._kernel_type]


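# Test-only sharders for the managed-collision (ZCH) path: TestMCSharder pins the
# ManagedCollisionCollectionSharder to a single sharding type, and TestEBCSharderMCH
# pairs it with TestEBCSharder to shard a ManagedCollisionEmbeddingBagCollection.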
class TestMCSharder(ManagedCollisionCollectionSharder):
def __init__(
self,
sharding_type: str,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
) -> None:
self._sharding_type = sharding_type
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)

def sharding_types(self, compute_device_type: str) -> List[str]:
return [self._sharding_type]


class TestEBCSharderMCH(
BaseManagedCollisionEmbeddingCollectionSharder[
ManagedCollisionEmbeddingBagCollection
]
):
def __init__(
self,
sharding_type: str,
kernel_type: str,
fused_params: Optional[Dict[str, Any]] = None,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
) -> None:
super().__init__(
TestEBCSharder(
sharding_type, kernel_type, fused_params, qcomm_codecs_registry
),
TestMCSharder(sharding_type, qcomm_codecs_registry),
qcomm_codecs_registry=qcomm_codecs_registry,
)

@property
def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
return ManagedCollisionEmbeddingBagCollection

def shard(
self,
module: ManagedCollisionEmbeddingBagCollection,
params: Dict[str, ParameterSharding],
env: ShardingEnv,
device: Optional[torch.device] = None,
module_fqn: Optional[str] = None,
) -> ShardedManagedCollisionEmbeddingBagCollection:
if device is None:
device = torch.device("cuda")
return ShardedManagedCollisionEmbeddingBagCollection(
module,
params,
# pyre-ignore [6]
ebc_sharder=self._e_sharder,
mc_sharder=self._mc_sharder,
env=env,
device=device,
)


class TestFusedEBCSharder(FusedEmbeddingBagCollectionSharder):
def __init__(
self,
@@ -2188,3 +2272,122 @@ def forward(self, input: ModelInput) -> ModelInput:
modified_input = copy.deepcopy(input)
modified_input.idlist_features = self.fp_proc(modified_input.idlist_features)
return modified_input


class TestSparseArchZCH(nn.Module):
"""
Basic nn.Module for testing MCH EmbeddingBagCollection
Args:
tables
weighted_tables
device
return_remapped
Call Args:
features
weighted_features
batch_size
Returns:
KeyedTensor
Example::
TestSparseArch()
"""

def __init__(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
device: torch.device,
return_remapped: bool = False,
) -> None:
super().__init__()
self._return_remapped = return_remapped

mc_modules = {}
for table in tables:
mc_modules[table.name] = MCHManagedCollisionModule(
zch_size=table.num_embeddings,
input_hash_size=4000,
device=device,
# TODO: If eviction interval is set to
# a low number (e.g. 2), semi-sync pipeline test will
# fail with in-place modification error during
# loss.backward(). This is because during semi-sync training,
# we run embedding module forward after autograd graph
# is constructed, but if MCH eviction happens, the
# variable used in autograd will have been modified
eviction_interval=1000,
eviction_policy=DistanceLFU_EvictionPolicy(),
)

self.ebc: ManagedCollisionEmbeddingBagCollection = (
ManagedCollisionEmbeddingBagCollection(
EmbeddingBagCollection(
tables=tables,
device=device,
),
ManagedCollisionCollection(
managed_collision_modules=mc_modules,
embedding_configs=tables,
),
return_remapped_features=self._return_remapped,
)
)

self.weighted_ebc: Optional[ManagedCollisionEmbeddingBagCollection] = None
if weighted_tables:
weighted_mc_modules = {}
for table in weighted_tables:
weighted_mc_modules[table.name] = MCHManagedCollisionModule(
zch_size=table.num_embeddings,
input_hash_size=4000,
device=device,
# TODO: Support MCH evictions during semi-sync
eviction_interval=1000,
eviction_policy=DistanceLFU_EvictionPolicy(),
)
self.weighted_ebc: ManagedCollisionEmbeddingBagCollection = (
ManagedCollisionEmbeddingBagCollection(
EmbeddingBagCollection(
tables=weighted_tables,
device=device,
is_weighted=True,
),
ManagedCollisionCollection(
managed_collision_modules=weighted_mc_modules,
embedding_configs=weighted_tables,
),
return_remapped_features=self._return_remapped,
)
)

def forward(
self,
features: KeyedJaggedTensor,
weighted_features: Optional[KeyedJaggedTensor] = None,
batch_size: Optional[int] = None,
) -> KeyedTensor:
"""
Runs forward and MC EBC and optionally, weighted MC EBC,
then merges the results into one KeyedTensor
Args:
features
weighted_features
batch_size
Returns:
KeyedTensor
"""
        ebc, _ = self.ebc(features)
        ebc = _post_ebc_test_wrap_function(ebc)
        # Only unpack the weighted MC EBC output when it is present; otherwise
        # pass None through to the merge step.
        if self.weighted_ebc is not None and weighted_features is not None:
            w_ebc, _ = self.weighted_ebc(weighted_features)
        else:
            w_ebc = None
        result = _post_sparsenn_forward(ebc, None, w_ebc, batch_size)
        return result
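
A minimal usage sketch for the new TestSparseArchZCH helper (illustrative only, not part of this commit). It assumes the unsharded module can run on CPU in a single process; the table names, feature names, sizes, and id values below are made up.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.test_utils.test_model import TestSparseArchZCH

# Two tiny tables; "t1"/"f1" and "wt1"/"wf1" are placeholder names.
tables = [
    EmbeddingBagConfig(
        num_embeddings=100, embedding_dim=8, name="t1", feature_names=["f1"]
    )
]
weighted_tables = [
    EmbeddingBagConfig(
        num_embeddings=100, embedding_dim=8, name="wt1", feature_names=["wf1"]
    )
]

arch = TestSparseArchZCH(
    tables, weighted_tables, device=torch.device("cpu"), return_remapped=True
)

# Raw ids only need to fall inside input_hash_size (4000 here); the
# managed-collision modules remap them into the zch_size (100) embedding rows.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([101, 2002, 3003]),
    lengths=torch.tensor([1, 1, 1]),
)
weighted_features = KeyedJaggedTensor.from_lengths_sync(
    keys=["wf1"],
    values=torch.tensor([11, 22, 33]),
    lengths=torch.tensor([1, 1, 1]),
    weights=torch.tensor([1.0, 1.0, 1.0]),
)

# Returns a KeyedTensor of pooled embeddings for both feature groups.
out = arch(features, weighted_features, batch_size=3)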
(next changed file; path not shown in this view)
@@ -17,7 +17,7 @@
from unittest.mock import MagicMock

import torch
from hypothesis import given, settings, strategies as st, Verbosity
from hypothesis import assume, given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
@@ -1531,7 +1531,7 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
@settings(max_examples=4, deadline=None)
@settings(max_examples=8, deadline=None)
# pyre-ignore[56]
@given(
start_batch=st.sampled_from([0, 6]),
@@ -1547,17 +1547,21 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
EmbeddingComputeKernel.FUSED.value,
]
),
zch=st.booleans(),
)
def test_equal_to_non_pipelined(
self,
start_batch: int,
stash_gradients: bool,
sharding_type: str,
kernel_type: str,
zch: bool,
) -> None:
"""
Checks that pipelined training is equivalent to non-pipelined training.
"""
# ZCH only supports row-wise currently
assume(not zch or (zch and sharding_type != ShardingType.TABLE_WISE.value))
torch.autograd.set_detect_anomaly(True)
data = self._generate_data(
num_batches=12,
Expand All @@ -1572,7 +1576,7 @@ def test_equal_to_non_pipelined(
**fused_params,
}

model = self._setup_model()
model = self._setup_model(zch=zch)
sharded_model, optim = self._generate_sharded_model_and_optimizer(
model, sharding_type, kernel_type, fused_params
)
(next changed file; path not shown in this view)
@@ -21,6 +21,7 @@
from torchrec.distributed.test_utils.test_model import (
ModelInput,
TestEBCSharder,
TestEBCSharderMCH,
TestSparseNN,
)
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
@@ -96,13 +97,15 @@ def _setup_model(
model_type: Type[nn.Module] = TestSparseNN,
enable_fsdp: bool = False,
postproc_module: Optional[nn.Module] = None,
zch: bool = False,
) -> nn.Module:
unsharded_model = model_type(
tables=self.tables,
weighted_tables=self.weighted_tables,
dense_device=self.device,
sparse_device=torch.device("meta"),
postproc_module=postproc_module,
zch=zch,
)
if enable_fsdp:
unsharded_model.over.dhn_arch.linear0 = FSDP(
@@ -135,6 +138,11 @@ def _generate_sharded_model_and_optimizer(
kernel_type=kernel_type,
fused_params=fused_params,
)
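        # Sharder for ManagedCollisionEmbeddingBagCollection modules; it is only
        # exercised when the test model is built with zch=True.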
mc_sharder = TestEBCSharderMCH(
sharding_type=sharding_type,
kernel_type=kernel_type,
fused_params=fused_params,
)
sharded_model = DistributedModelParallel(
module=copy.deepcopy(model),
env=ShardingEnv.from_process_group(self.pg),
@@ -144,7 +152,11 @@ def _generate_sharded_model_and_optimizer(
cast(
ModuleSharder[nn.Module],
sharder,
)
),
cast(
ModuleSharder[nn.Module],
mc_sharder,
),
],
)
# default fused optimizer is SGD w/ lr=0.1; we need to drop params
(diff for the remaining changed file not loaded in this view)
