Skip to content

Commit

Permalink
2025-02-21 nightly release (e00868c)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Feb 21, 2025
1 parent 1f788ab commit b86ccab
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 194 deletions.
8 changes: 8 additions & 0 deletions .github/scripts/validate_binaries.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
export PYTORCH_CUDA_PKG=""
export CONDA_ENV="build_binary"

if [[ ${MATRIX_PYTHON_VERSION} = '3.13t' ]]; then
echo "Conda doesn't support 3.13t yet, you can just try \`conda create -n test python=3.13t\`"
exit 0
fi

conda create -y -n "${CONDA_ENV}" python="${MATRIX_PYTHON_VERSION}"

conda run -n build_binary python --version
Expand Down Expand Up @@ -80,6 +85,9 @@ conda run -n "${CONDA_ENV}" pip install fbgemm-gpu --index-url "$PYTORCH_URL"
# install requirements from pypi
conda run -n "${CONDA_ENV}" pip install torchmetrics==1.0.3

# install tensordict from pypi
conda run -n "${CONDA_ENV}" pip install tensordict==0.7.1

# install torchrec
conda run -n "${CONDA_ENV}" pip install torchrec --index-url "$PYTORCH_URL"

Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/build_dynamic_embedding_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,15 @@ jobs:
- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: artifact-${{ matrix.os }}-${{ matrix.pyver }}-cu${{ matrix.cuver }}
path: wheelhouse/*.whl

merge:
runs-on: ubuntu-latest
needs: build_wheels
steps:
- name: Merge Artifacts
uses: actions/upload-artifact/merge@v4
with:
name: artifact
pattern: artifact-*
6 changes: 5 additions & 1 deletion torchrec/distributed/sharding/cw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,11 @@ def _shard(
for i, rank in enumerate(info.param_sharding.ranks):
# Remap rank by number of replica groups if 2D parallelism is enabled
rank = (
self._env.remap_rank(rank, ShardingType.COLUMN_WISE) # pyre-ignore[16]
# pyre-ignore[16]
self._env.remap_rank(
rank,
ShardingType.COLUMN_WISE,
)
if self._is_2D_parallel
else rank
)
Expand Down
3 changes: 2 additions & 1 deletion torchrec/distributed/sharding/grid_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ def _shard(
# pyre-fixme [6]
for i, rank in enumerate(info.param_sharding.ranks):
rank = (
self._env.remap_rank(rank, ShardingType.GRID_SHARD) # pyre-ignore[16]
# pyre-ignore[16]
self._env.remap_rank(rank, ShardingType.GRID_SHARD)
if self._is_2D_parallel
else rank
)
Expand Down
3 changes: 2 additions & 1 deletion torchrec/distributed/sharding/tw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def _shard(
rank = (
# pyre-ignore [16]
self._env.remap_rank(
info.param_sharding.ranks[0], ShardingType.TABLE_WISE # pyre-ignore[16]
info.param_sharding.ranks[0], # pyre-ignore[16]
ShardingType.TABLE_WISE,
)
if self._is_2D_parallel
else info.param_sharding.ranks[0]
Expand Down
97 changes: 97 additions & 0 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,103 @@ def forward(
return pred.sum(), pred


class TestModelWithPreprocCollectionArgs(nn.Module):
"""
Basic module with up to 3 postproc modules:
- postproc on idlist_features for non-weighted EBC
- postproc on idscore_features for weighted EBC
- postproc_inner on model input shared by both EBCs
- postproc_outer providing input to postproc_b (aka nested postproc)
Args:
tables,
weighted_tables,
device,
postproc_module_outer,
postproc_module_nested,
num_float_features,
Example:
>>> TestModelWithPreprocWithListArg(tables, weighted_tables, device)
Returns:
Tuple[torch.Tensor, torch.Tensor]
"""

CONST_DICT_KEY = "const"
INPUT_TENSOR_DICT_KEY = "tensor_from_input"
POSTPTOC_TENSOR_DICT_KEY = "tensor_from_postproc"

def __init__(
self,
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
device: torch.device,
postproc_module_outer: nn.Module,
postproc_module_nested: nn.Module,
num_float_features: int = 10,
) -> None:
super().__init__()
self.dense = TestDenseArch(num_float_features, device)

self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
tables=tables,
device=device,
)
self.weighted_ebc = EmbeddingBagCollection(
tables=weighted_tables,
is_weighted=True,
device=device,
)
self.postproc_nonweighted = TestPreprocNonWeighted()
self.postproc_weighted = TestPreprocWeighted()
self._postproc_module_outer = postproc_module_outer
self._postproc_module_nested = postproc_module_nested

def forward(
self,
input: ModelInput,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Runs preproc for EBC and weighted EBC, optionally runs postproc for input
Args:
input
Returns:
Tuple[torch.Tensor, torch.Tensor]
"""
modified_input = input

outer_postproc_input = self._postproc_module_outer(modified_input)

preproc_input_list = [
1,
modified_input.float_features,
outer_postproc_input,
]
preproc_input_dict = {
self.CONST_DICT_KEY: 1,
self.INPUT_TENSOR_DICT_KEY: modified_input.float_features,
self.POSTPTOC_TENSOR_DICT_KEY: outer_postproc_input,
}

modified_input = self._postproc_module_nested(
modified_input, preproc_input_list, preproc_input_dict
)

modified_idlist_features = self.postproc_nonweighted(
modified_input.idlist_features
)
modified_idscore_features = self.postproc_weighted(
modified_input.idscore_features
)
ebc_out = self.ebc(modified_idlist_features[0])
weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])

pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
return pred.sum(), pred


class TestNegSamplingModule(torch.nn.Module):
"""
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from contextlib import ExitStack
from dataclasses import dataclass
from functools import partial
from typing import cast, List, Optional, Tuple, Type, Union
from typing import cast, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import MagicMock

import torch
from hypothesis import given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
from torch.fx._symbolic_trace import is_fx_tracing
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
Expand All @@ -36,6 +37,7 @@
ModelInput,
TestEBCSharder,
TestModelWithPreproc,
TestModelWithPreprocCollectionArgs,
TestNegSamplingModule,
TestPositionWeightedPreprocModule,
TestSparseNN,
Expand Down Expand Up @@ -1448,6 +1450,81 @@ def forward(
self.assertEqual(len(pipeline._pipelined_modules), 2)
self.assertEqual(len(pipeline._pipelined_postprocs), 1)

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_pipeline_postproc_with_collection_args(self) -> None:
"""
Exercises scenario when postproc module has an argument that is a list or dict
with some elements being:
* static scalars
* static tensors (e.g. torch.ones())
* tensors derived from input batch (e.g. input.idlist_features["feature_0"])
* tensors derived from input batch and other postproc module (e.g. other_postproc(input.idlist_features["feature_0"]))
"""
test_runner = self

class PostprocOuter(nn.Module):
def __init__(
self,
) -> None:
super().__init__()

def forward(
self,
model_input: ModelInput,
) -> torch.Tensor:
return model_input.float_features * 0.1

class PostprocInner(nn.Module):
def __init__(
self,
) -> None:
super().__init__()

def forward(
self,
model_input: ModelInput,
input_list: List[Union[torch.Tensor, int]],
input_dict: Dict[str, Union[torch.Tensor, int]],
) -> ModelInput:
if not is_fx_tracing():
for idx, value in enumerate(input_list):
if isinstance(value, torch.fx.Node):
test_runner.fail(
f"input_list[{idx}] was a fx.Node: {value}"
)
model_input.float_features += value

for key, value in input_dict.items():
if isinstance(value, torch.fx.Node):
test_runner.fail(
f"input_dict[{key}] was a fx.Node: {value}"
)
model_input.float_features += value

return model_input

model = TestModelWithPreprocCollectionArgs(
tables=self.tables[:-1], # ignore last table as postproc will remove
weighted_tables=self.weighted_tables[:-1], # ignore last table
device=self.device,
postproc_module_outer=PostprocOuter(),
postproc_module_nested=PostprocInner(),
)

pipelined_model, pipeline = self._check_output_equal(
model,
self.sharding_type,
)

# both EC end EBC are pipelined
self.assertEqual(len(pipeline._pipelined_modules), 2)
# both outer and nested postproces are pipelined
self.assertEqual(len(pipeline._pipelined_postprocs), 4)


class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
@unittest.skipIf(
Expand Down
Loading

0 comments on commit b86ccab

Please sign in to comment.