
Commit 2f83768

Chang Pan authored and facebook-github-bot committed
avoid fx Constant Folding in rw_sharding (#2777)
Summary: tensor_constant nodes are known to have issues with delta update: the predictor cannot find those tensors in the delta weights. The `tensor_cache` removed in this diff is what introduces such tensor constants in fx.

Solution: replace `tensor_cache` with `register_buffer`, and move the buffers to the runtime device and data type at the first forward (so as to keep performance parity); a standalone sketch of this pattern follows below.

Note this solution breaks a unit test

```
assertFalse(
    hasattr(
        local.ro_ec,
        "_root_mc_embedding_collection",
    )
)
```

but it still won't introduce big tables like TBEs.

Reviewed By: jingsh

Differential Revision: D70577218
1 parent 411876a · commit 2f83768
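To make the pattern described in the summary concrete, here is a minimal standalone sketch of registering values as a buffer at init and moving them to the runtime device/dtype on the first forward. The `LazyBufferModule` class and its names are illustrative only, not the actual torchrec code; the real change is in the diff below.

```python
from typing import List

import torch


class LazyBufferModule(torch.nn.Module):
    """Illustrative only: register values as a buffer at init, move them lazily."""

    def __init__(self, block_sizes: List[int]) -> None:
        super().__init__()
        # A registered buffer is part of module state (state_dict / named_buffers),
        # so fx tracing sees a get_attr on the module instead of baking the values
        # into the graph as a tensor constant.
        self.register_buffer("block_sizes", torch.tensor(block_sizes))
        self.moved: bool = False

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        if not self.moved:
            # One-time move to the runtime device/dtype; after the first call,
            # .to() returns the same tensor, so repeating this line (e.g. in a
            # traced graph where the `if` has been baked away) stays cheap.
            self.block_sizes = self.block_sizes.to(
                device=values.device, dtype=values.dtype
            )
            self.moved = True
        return values // self.block_sizes
```

Unlike tensors cached in a plain dict inside an fx-wrapped helper, a registered buffer shows up in `named_buffers()` and `state_dict()`, which is the kind of module state the delta-update path described above can actually locate.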

File tree: 1 file changed (+67 -94 lines changed)


torchrec/distributed/sharding/rw_sharding.py

+67 -94

```diff
@@ -10,7 +10,7 @@
 import logging
 import math
 
-from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Any, cast, Dict, List, Optional, overload, Tuple, TypeVar, Union
 
 import torch
 import torch.distributed as dist
@@ -101,13 +101,6 @@ def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]:
     return (embed_sharding, is_even_sharding)
 
 
-@torch.fx.wrap
-def _fx_wrap_block_bucketize_row_pos(
-    block_bucketize_row_pos: List[torch.Tensor],
-) -> Optional[List[torch.Tensor]]:
-    return block_bucketize_row_pos if block_bucketize_row_pos else None
-
-
 class BaseRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]):
     """
     Base class for row-wise sharding.
@@ -576,66 +569,25 @@ def create_output_dist(
         )
 
 
-@torch.fx.wrap
-def get_total_num_buckets_runtime_device(
-    total_num_buckets: Optional[List[int]],
-    runtime_device: torch.device,
-    tensor_cache: Dict[
-        str,
-        Tuple[torch.Tensor, List[torch.Tensor]],
-    ],
-    dtype: torch.dtype = torch.int32,
-) -> Optional[torch.Tensor]:
-    if total_num_buckets is None:
+@overload
+def convert_tensor(t: torch.Tensor, feature: KeyedJaggedTensor) -> torch.Tensor: ...
+@overload
+def convert_tensor(t: None, feature: KeyedJaggedTensor) -> None: ...
+
+
+def convert_tensor(
+    t: torch.Tensor | None,
+    feature: KeyedJaggedTensor,
+) -> torch.Tensor | None:
+    # comparing to Optional[Tensor], this solution will keep output as Tensor when input is not None
+    if t is None:
         return None
-    cache_key: str = "__total_num_buckets"
-    if cache_key not in tensor_cache:
-        tensor_cache[cache_key] = (
-            torch.tensor(
-                total_num_buckets,
-                device=runtime_device,
-                dtype=dtype,
-            ),
-            [],
-        )
-    return tensor_cache[cache_key][0]
-
-
-@torch.fx.wrap
-def get_block_sizes_runtime_device(
-    block_sizes: List[int],
-    runtime_device: torch.device,
-    tensor_cache: Dict[
-        str,
-        Tuple[torch.Tensor, List[torch.Tensor]],
-    ],
-    embedding_shard_metadata: Optional[List[List[int]]] = None,
-    dtype: torch.dtype = torch.int32,
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-    cache_key: str = "__block_sizes"
-    if cache_key not in tensor_cache:
-        tensor_cache[cache_key] = (
-            torch.tensor(
-                block_sizes,
-                device=runtime_device,
-                dtype=dtype,
-            ),
-            (
-                []
-                if embedding_shard_metadata is None
-                else [
-                    torch.tensor(
-                        row_pos,
-                        device=runtime_device,
-                        dtype=dtype,
-                    )
-                    for row_pos in embedding_shard_metadata
-                ]
-            ),
+    else:
+        return t.to(
+            device=feature.device(),
+            dtype=feature.values().dtype,
         )
 
-    return tensor_cache[cache_key]
-
 
 class InferRwSparseFeaturesDist(BaseSparseFeaturesDist[InputDistOutputs]):
     def __init__(
@@ -659,7 +611,7 @@ def __init__(
         self._world_size: int = world_size
         self._num_features = num_features
         self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets
-        self.feature_block_sizes: List[int] = []
+        feature_block_sizes: List[int] = []
         for i, hash_size in enumerate(feature_hash_sizes):
             block_divisor = self._world_size
             if (
@@ -668,12 +620,11 @@ def __init__(
             ):
                 assert feature_total_num_buckets[i] % self._world_size == 0
                 block_divisor = feature_total_num_buckets[i]
-            self.feature_block_sizes.append(
-                (hash_size + block_divisor - 1) // block_divisor
-            )
-        self.tensor_cache: Dict[
-            str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
-        ] = {}
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+        self.register_buffer(
+            "feature_block_sizes",
+            torch.tensor(feature_block_sizes),
+        )
 
         self._dist = KJTOneToAll(
             splits=self._world_size * [self._num_features],
@@ -683,44 +634,66 @@ def __init__(
         self._is_sequence = is_sequence
         self._has_feature_processor = has_feature_processor
        self._need_pos = need_pos
-
-        self._embedding_shard_metadata: Optional[List[List[int]]] = (
-            embedding_shard_metadata
-        )
+        embedding_shard_metadata = embedding_shard_metadata or []
+        for i, row_pos in enumerate(embedding_shard_metadata):
+            self.register_buffer(f"row_pos_{i}", torch.tensor(row_pos))
+        self.embedding_shard_metadata_len: int = len(embedding_shard_metadata)
         self._keep_original_indices = keep_original_indices
-
-    def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
-        block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
-            self.feature_block_sizes,
-            sparse_features.device(),
-            self.tensor_cache,
-            self._embedding_shard_metadata,
-            sparse_features.values().dtype,
+        # pyre-ignore[8]
+        self.register_buffer(
+            "feature_total_num_buckets",
+            (
+                torch.tensor(feature_total_num_buckets)
+                if feature_total_num_buckets
+                else None
+            ),
         )
-        total_num_buckets = get_total_num_buckets_runtime_device(
-            self._feature_total_num_buckets,
-            sparse_features.device(),
-            self.tensor_cache,
-            sparse_features.values().dtype,
+        self.forwarded: bool = False
+
+    def get_block_bucketize_row_pos(self) -> Optional[List[torch.Tensor]]:
+        return [
+            getattr(self, f"row_pos_{i}")
+            for i in range(self.embedding_shard_metadata_len)
+        ] or None
+
+    def move_buffer(self, sparse_features: KeyedJaggedTensor) -> None:
+        # buffer should only be moved once, even if this method being executed multiple times. as later 'to' should return same tensor after first convert
+        self.feature_block_sizes = convert_tensor(
+            t=self.feature_block_sizes, feature=sparse_features
         )
+        self.feature_total_num_buckets = convert_tensor(
+            t=self.feature_total_num_buckets, feature=sparse_features
+        )
+        for i in range(self.embedding_shard_metadata_len):
+            setattr(
+                self,
+                f"row_pos_{i}",
+                convert_tensor(
+                    t=getattr(self, f"row_pos_{i}"),
+                    feature=sparse_features,
+                ),
+            )
 
+    def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
+        if not self.forwarded:
+            # after fx tracing, 'if' will be removed, and below line will actually be called multiple times. but it's ok as 'to' will return same tensor after first convert.
+            self.move_buffer(sparse_features)
+            self.forwarded = True
         (
             bucketized_features,
             unbucketize_permute_tensor,
             bucket_mapping_tensor_opt,
        ) = bucketize_kjt_inference(
             sparse_features,
             num_buckets=self._world_size,
-            block_sizes=block_sizes,
-            total_num_buckets=total_num_buckets,
+            block_sizes=self.feature_block_sizes,
+            total_num_buckets=self.feature_total_num_buckets,
             bucketize_pos=(
                 self._has_feature_processor
                 if sparse_features.weights_or_none() is None
                 else self._need_pos
             ),
-            block_bucketize_row_pos=_fx_wrap_block_bucketize_row_pos(
-                block_bucketize_row_pos
-            ),
+            block_bucketize_row_pos=self.get_block_bucketize_row_pos(),
             is_sequence=self._is_sequence,
             keep_original_indices=self._keep_original_indices,
         )
```