10
10
import logging
11
11
import math
12
12
13
- from typing import Any , cast , Dict , List , Optional , Tuple , TypeVar , Union
13
+ from typing import Any , cast , Dict , List , Optional , overload , Tuple , TypeVar , Union
14
14
15
15
import torch
16
16
import torch .distributed as dist
@@ -101,13 +101,6 @@ def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]:
101
101
return (embed_sharding , is_even_sharding )
102
102
103
103
104
- @torch .fx .wrap
105
- def _fx_wrap_block_bucketize_row_pos (
106
- block_bucketize_row_pos : List [torch .Tensor ],
107
- ) -> Optional [List [torch .Tensor ]]:
108
- return block_bucketize_row_pos if block_bucketize_row_pos else None
109
-
110
-
111
104
class BaseRwEmbeddingSharding (EmbeddingSharding [C , F , T , W ]):
112
105
"""
113
106
Base class for row-wise sharding.
@@ -576,66 +569,25 @@ def create_output_dist(
576
569
)
577
570
578
571
579
- @torch .fx .wrap
580
- def get_total_num_buckets_runtime_device (
581
- total_num_buckets : Optional [List [int ]],
582
- runtime_device : torch .device ,
583
- tensor_cache : Dict [
584
- str ,
585
- Tuple [torch .Tensor , List [torch .Tensor ]],
586
- ],
587
- dtype : torch .dtype = torch .int32 ,
588
- ) -> Optional [torch .Tensor ]:
589
- if total_num_buckets is None :
572
+ @overload
573
+ def convert_tensor (t : torch .Tensor , feature : KeyedJaggedTensor ) -> torch .Tensor : ...
574
+ @overload
575
+ def convert_tensor (t : None , feature : KeyedJaggedTensor ) -> None : ...
576
+
577
+
578
+ def convert_tensor (
579
+ t : torch .Tensor | None ,
580
+ feature : KeyedJaggedTensor ,
581
+ ) -> torch .Tensor | None :
582
+ # Unlike a single Optional[Tensor] signature, the overloads above keep the result typed as Tensor whenever the input is not None.
583
+ if t is None :
590
584
return None
591
- cache_key : str = "__total_num_buckets"
592
- if cache_key not in tensor_cache :
593
- tensor_cache [cache_key ] = (
594
- torch .tensor (
595
- total_num_buckets ,
596
- device = runtime_device ,
597
- dtype = dtype ,
598
- ),
599
- [],
600
- )
601
- return tensor_cache [cache_key ][0 ]
602
-
603
-
604
- @torch .fx .wrap
605
- def get_block_sizes_runtime_device (
606
- block_sizes : List [int ],
607
- runtime_device : torch .device ,
608
- tensor_cache : Dict [
609
- str ,
610
- Tuple [torch .Tensor , List [torch .Tensor ]],
611
- ],
612
- embedding_shard_metadata : Optional [List [List [int ]]] = None ,
613
- dtype : torch .dtype = torch .int32 ,
614
- ) -> Tuple [torch .Tensor , List [torch .Tensor ]]:
615
- cache_key : str = "__block_sizes"
616
- if cache_key not in tensor_cache :
617
- tensor_cache [cache_key ] = (
618
- torch .tensor (
619
- block_sizes ,
620
- device = runtime_device ,
621
- dtype = dtype ,
622
- ),
623
- (
624
- []
625
- if embedding_shard_metadata is None
626
- else [
627
- torch .tensor (
628
- row_pos ,
629
- device = runtime_device ,
630
- dtype = dtype ,
631
- )
632
- for row_pos in embedding_shard_metadata
633
- ]
634
- ),
585
+ else :
586
+ return t .to (
587
+ device = feature .device (),
588
+ dtype = feature .values ().dtype ,
635
589
)
636
590
637
- return tensor_cache [cache_key ]
638
-
639
591
640
592
class InferRwSparseFeaturesDist (BaseSparseFeaturesDist [InputDistOutputs ]):
641
593
def __init__ (
@@ -659,7 +611,7 @@ def __init__(
659
611
self ._world_size : int = world_size
660
612
self ._num_features = num_features
661
613
self ._feature_total_num_buckets : Optional [List [int ]] = feature_total_num_buckets
662
- self . feature_block_sizes : List [int ] = []
614
+ feature_block_sizes : List [int ] = []
663
615
for i , hash_size in enumerate (feature_hash_sizes ):
664
616
block_divisor = self ._world_size
665
617
if (
@@ -668,12 +620,11 @@ def __init__(
668
620
):
669
621
assert feature_total_num_buckets [i ] % self ._world_size == 0
670
622
block_divisor = feature_total_num_buckets [i ]
671
- self .feature_block_sizes .append (
672
- (hash_size + block_divisor - 1 ) // block_divisor
673
- )
674
- self .tensor_cache : Dict [
675
- str , Tuple [torch .Tensor , Optional [List [torch .Tensor ]]]
676
- ] = {}
623
+ feature_block_sizes .append ((hash_size + block_divisor - 1 ) // block_divisor )
624
+ self .register_buffer (
625
+ "feature_block_sizes" ,
626
+ torch .tensor (feature_block_sizes ),
627
+ )
677
628
678
629
self ._dist = KJTOneToAll (
679
630
splits = self ._world_size * [self ._num_features ],
@@ -683,44 +634,66 @@ def __init__(
683
634
self ._is_sequence = is_sequence
684
635
self ._has_feature_processor = has_feature_processor
685
636
self ._need_pos = need_pos
686
-
687
- self . _embedding_shard_metadata : Optional [ List [ List [ int ]]] = (
688
- embedding_shard_metadata
689
- )
637
+ embedding_shard_metadata = embedding_shard_metadata or []
638
+ for i , row_pos in enumerate ( embedding_shard_metadata ):
639
+ self . register_buffer ( f"row_pos_ { i } " , torch . tensor ( row_pos ))
640
+ self . embedding_shard_metadata_len : int = len ( embedding_shard_metadata )
690
641
self ._keep_original_indices = keep_original_indices
691
-
692
- def forward ( self , sparse_features : KeyedJaggedTensor ) -> InputDistOutputs :
693
- block_sizes , block_bucketize_row_pos = get_block_sizes_runtime_device (
694
- self . feature_block_sizes ,
695
- sparse_features . device (),
696
- self . tensor_cache ,
697
- self . _embedding_shard_metadata ,
698
- sparse_features . values (). dtype ,
642
+ # pyre-ignore[8]
643
+ self . register_buffer (
644
+ "feature_total_num_buckets" ,
645
+ (
646
+ torch . tensor ( feature_total_num_buckets )
647
+ if feature_total_num_buckets
648
+ else None
649
+ ) ,
699
650
)
700
- total_num_buckets = get_total_num_buckets_runtime_device (
701
- self ._feature_total_num_buckets ,
702
- sparse_features .device (),
703
- self .tensor_cache ,
704
- sparse_features .values ().dtype ,
651
+ self .forwarded : bool = False
652
+
653
+ def get_block_bucketize_row_pos (self ) -> Optional [List [torch .Tensor ]]:
654
+ return [
655
+ getattr (self , f"row_pos_{ i } " )
656
+ for i in range (self .embedding_shard_metadata_len )
657
+ ] or None
658
+
659
+ def move_buffer (self , sparse_features : KeyedJaggedTensor ) -> None :
660
+ # Buffers are effectively moved only once, even if this method is executed multiple times: after the first conversion, later 'to' calls with the same device and dtype should return the same tensor.
661
+ self .feature_block_sizes = convert_tensor (
662
+ t = self .feature_block_sizes , feature = sparse_features
705
663
)
664
+ self .feature_total_num_buckets = convert_tensor (
665
+ t = self .feature_total_num_buckets , feature = sparse_features
666
+ )
667
+ for i in range (self .embedding_shard_metadata_len ):
668
+ setattr (
669
+ self ,
670
+ f"row_pos_{ i } " ,
671
+ convert_tensor (
672
+ t = getattr (self , f"row_pos_{ i } " ),
673
+ feature = sparse_features ,
674
+ ),
675
+ )
706
676
677
+ def forward (self , sparse_features : KeyedJaggedTensor ) -> InputDistOutputs :
678
+ if not self .forwarded :
679
+ # After fx tracing, this 'if' is removed, so the line below will actually be called on every forward pass. That is OK because 'to' returns the same tensor after the first conversion.
680
+ self .move_buffer (sparse_features )
681
+ self .forwarded = True
707
682
(
708
683
bucketized_features ,
709
684
unbucketize_permute_tensor ,
710
685
bucket_mapping_tensor_opt ,
711
686
) = bucketize_kjt_inference (
712
687
sparse_features ,
713
688
num_buckets = self ._world_size ,
714
- block_sizes = block_sizes ,
715
- total_num_buckets = total_num_buckets ,
689
+ block_sizes = self . feature_block_sizes ,
690
+ total_num_buckets = self . feature_total_num_buckets ,
716
691
bucketize_pos = (
717
692
self ._has_feature_processor
718
693
if sparse_features .weights_or_none () is None
719
694
else self ._need_pos
720
695
),
721
- block_bucketize_row_pos = _fx_wrap_block_bucketize_row_pos (
722
- block_bucketize_row_pos
723
- ),
696
+ block_bucketize_row_pos = self .get_block_bucketize_row_pos (),
724
697
is_sequence = self ._is_sequence ,
725
698
keep_original_indices = self ._keep_original_indices ,
726
699
)
0 commit comments