2025-03-11 nightly release (e1ee42c)
pytorchbot committed Mar 11, 2025
1 parent a290146 commit 6a87507
Showing 1 changed file with 67 additions and 94 deletions.
torchrec/distributed/sharding/rw_sharding.py — 161 changes: 67 additions & 94 deletions
@@ -10,7 +10,7 @@
 import logging
 import math
 
-from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Any, cast, Dict, List, Optional, overload, Tuple, TypeVar, Union
 
 import torch
 import torch.distributed as dist
@@ -101,13 +101,6 @@ def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]:
     return (embed_sharding, is_even_sharding)
 
 
-@torch.fx.wrap
-def _fx_wrap_block_bucketize_row_pos(
-    block_bucketize_row_pos: List[torch.Tensor],
-) -> Optional[List[torch.Tensor]]:
-    return block_bucketize_row_pos if block_bucketize_row_pos else None
-
-
 class BaseRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]):
     """
     Base class for row-wise sharding.
@@ -576,66 +569,25 @@ def create_output_dist(
         )
 
 
-@torch.fx.wrap
-def get_total_num_buckets_runtime_device(
-    total_num_buckets: Optional[List[int]],
-    runtime_device: torch.device,
-    tensor_cache: Dict[
-        str,
-        Tuple[torch.Tensor, List[torch.Tensor]],
-    ],
-    dtype: torch.dtype = torch.int32,
-) -> Optional[torch.Tensor]:
-    if total_num_buckets is None:
-        return None
-    cache_key: str = "__total_num_buckets"
-    if cache_key not in tensor_cache:
-        tensor_cache[cache_key] = (
-            torch.tensor(
-                total_num_buckets,
-                device=runtime_device,
-                dtype=dtype,
-            ),
-            [],
-        )
-    return tensor_cache[cache_key][0]
-
-
-@torch.fx.wrap
-def get_block_sizes_runtime_device(
-    block_sizes: List[int],
-    runtime_device: torch.device,
-    tensor_cache: Dict[
-        str,
-        Tuple[torch.Tensor, List[torch.Tensor]],
-    ],
-    embedding_shard_metadata: Optional[List[List[int]]] = None,
-    dtype: torch.dtype = torch.int32,
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-    cache_key: str = "__block_sizes"
-    if cache_key not in tensor_cache:
-        tensor_cache[cache_key] = (
-            torch.tensor(
-                block_sizes,
-                device=runtime_device,
-                dtype=dtype,
-            ),
-            (
-                []
-                if embedding_shard_metadata is None
-                else [
-                    torch.tensor(
-                        row_pos,
-                        device=runtime_device,
-                        dtype=dtype,
-                    )
-                    for row_pos in embedding_shard_metadata
-                ]
-            ),
-        )
-
-    return tensor_cache[cache_key]
+@overload
+def convert_tensor(t: torch.Tensor, feature: KeyedJaggedTensor) -> torch.Tensor: ...
+@overload
+def convert_tensor(t: None, feature: KeyedJaggedTensor) -> None: ...
+
+
+def convert_tensor(
+    t: torch.Tensor | None,
+    feature: KeyedJaggedTensor,
+) -> torch.Tensor | None:
+    # compared to a single Optional[Tensor] signature, the overloads keep the output typed as Tensor when the input is not None
+    if t is None:
+        return None
+    else:
+        return t.to(
+            device=feature.device(),
+            dtype=feature.values().dtype,
+        )
 
 
 class InferRwSparseFeaturesDist(BaseSparseFeaturesDist[InputDistOutputs]):
     def __init__(
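
For context on the typing change: the @overload pair above lets a type checker return a non-Optional Tensor when the argument is a Tensor, which a single Optional[torch.Tensor] signature cannot express. A minimal standalone sketch of the same pattern (the helper name scale and its values are illustrative only, not part of this commit):

from typing import Optional, overload

import torch


@overload
def scale(t: torch.Tensor, factor: float) -> torch.Tensor: ...
@overload
def scale(t: None, factor: float) -> None: ...


def scale(t: Optional[torch.Tensor], factor: float) -> Optional[torch.Tensor]:
    # the overloads let a type checker narrow the return type from the argument type,
    # so callers passing a real Tensor do not need an extra None check on the result
    if t is None:
        return None
    return t * factor


x: torch.Tensor = scale(torch.ones(3), 2.0)  # checked as Tensor, not Optional[Tensor]
y: None = scale(None, 2.0)  # checked as None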
@@ -659,7 +611,7 @@ def __init__(
         self._world_size: int = world_size
         self._num_features = num_features
         self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets
-        self.feature_block_sizes: List[int] = []
+        feature_block_sizes: List[int] = []
         for i, hash_size in enumerate(feature_hash_sizes):
             block_divisor = self._world_size
             if (
@@ -668,12 +620,11 @@ def __init__(
             ):
                 assert feature_total_num_buckets[i] % self._world_size == 0
                 block_divisor = feature_total_num_buckets[i]
-            self.feature_block_sizes.append(
-                (hash_size + block_divisor - 1) // block_divisor
-            )
-        self.tensor_cache: Dict[
-            str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
-        ] = {}
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+        self.register_buffer(
+            "feature_block_sizes",
+            torch.tensor(feature_block_sizes),
+        )
 
         self._dist = KJTOneToAll(
             splits=self._world_size * [self._num_features],
@@ -683,44 +634,66 @@ def __init__(
         self._is_sequence = is_sequence
         self._has_feature_processor = has_feature_processor
         self._need_pos = need_pos
 
-        self._embedding_shard_metadata: Optional[List[List[int]]] = (
-            embedding_shard_metadata
-        )
+        embedding_shard_metadata = embedding_shard_metadata or []
+        for i, row_pos in enumerate(embedding_shard_metadata):
+            self.register_buffer(f"row_pos_{i}", torch.tensor(row_pos))
+        self.embedding_shard_metadata_len: int = len(embedding_shard_metadata)
         self._keep_original_indices = keep_original_indices
-
-    def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
-        block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
-            self.feature_block_sizes,
-            sparse_features.device(),
-            self.tensor_cache,
-            self._embedding_shard_metadata,
-            sparse_features.values().dtype,
-        )
-        total_num_buckets = get_total_num_buckets_runtime_device(
-            self._feature_total_num_buckets,
-            sparse_features.device(),
-            self.tensor_cache,
-            sparse_features.values().dtype,
-        )
+        # pyre-ignore[8]
+        self.register_buffer(
+            "feature_total_num_buckets",
+            (
+                torch.tensor(feature_total_num_buckets)
+                if feature_total_num_buckets
+                else None
+            ),
+        )
+        self.forwarded: bool = False
+
+    def get_block_bucketize_row_pos(self) -> Optional[List[torch.Tensor]]:
+        return [
+            getattr(self, f"row_pos_{i}")
+            for i in range(self.embedding_shard_metadata_len)
+        ] or None
+
+    def move_buffer(self, sparse_features: KeyedJaggedTensor) -> None:
+        # the buffers only need to be moved once, even if this method runs multiple times: 'to' returns the same tensor after the first conversion
+        self.feature_block_sizes = convert_tensor(
+            t=self.feature_block_sizes, feature=sparse_features
+        )
+        self.feature_total_num_buckets = convert_tensor(
+            t=self.feature_total_num_buckets, feature=sparse_features
+        )
+        for i in range(self.embedding_shard_metadata_len):
+            setattr(
+                self,
+                f"row_pos_{i}",
+                convert_tensor(
+                    t=getattr(self, f"row_pos_{i}"),
+                    feature=sparse_features,
+                ),
+            )
+
+    def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
+        if not self.forwarded:
+            # after fx tracing the 'if' is removed and the line below runs on every call; that is fine because 'to' returns the same tensor after the first conversion
+            self.move_buffer(sparse_features)
+            self.forwarded = True
         (
             bucketized_features,
             unbucketize_permute_tensor,
             bucket_mapping_tensor_opt,
         ) = bucketize_kjt_inference(
             sparse_features,
             num_buckets=self._world_size,
-            block_sizes=block_sizes,
-            total_num_buckets=total_num_buckets,
+            block_sizes=self.feature_block_sizes,
+            total_num_buckets=self.feature_total_num_buckets,
             bucketize_pos=(
                 self._has_feature_processor
                 if sparse_features.weights_or_none() is None
                 else self._need_pos
             ),
-            block_bucketize_row_pos=_fx_wrap_block_bucketize_row_pos(
-                block_bucketize_row_pos
-            ),
+            block_bucketize_row_pos=self.get_block_bucketize_row_pos(),
             is_sequence=self._is_sequence,
             keep_original_indices=self._keep_original_indices,
         )
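
The InferRwSparseFeaturesDist changes above replace the fx-wrapped helpers and the Python-side tensor cache with module buffers that are moved to the input's device once. A standalone sketch of that buffer pattern (the module below is illustrative only, not the TorchRec class):

from typing import List

import torch


class BucketConstants(torch.nn.Module):
    def __init__(self, block_sizes: List[int]) -> None:
        super().__init__()
        # register_buffer makes the tensor part of module state: it follows Module.to and appears in state_dict
        self.register_buffer("block_sizes", torch.tensor(block_sizes))
        self.moved: bool = False

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        if not self.moved:
            # Tensor.to returns the same tensor once device and dtype already match, so this is safe
            # even if fx tracing drops the guard and the line runs on every call
            self.block_sizes = self.block_sizes.to(device=values.device, dtype=values.dtype)
            self.moved = True
        # use the buffer in the computation, e.g. as a divisor derived from per-feature block sizes
        return values % self.block_sizes.sum()


m = BucketConstants([4, 8])
print(m(torch.arange(6)))  # tensor([0, 1, 2, 3, 4, 5]) since every value is below the sum 12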
