# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import List, Optional

import torch
from torchrec.streamable import Multistreamable


-@dataclass
class SequenceShardingContext(EmbeddingShardingContext):
    """
    Stores KJTAllToAll context and reuses it in SequenceEmbeddingsAllToAll.
@@ -32,12 +31,43 @@ class SequenceShardingContext(EmbeddingShardingContext):
        input dist.
    """

-    features_before_input_dist: Optional[KeyedJaggedTensor] = None
-    input_splits: List[int] = field(default_factory=list)
-    output_splits: List[int] = field(default_factory=list)
-    sparse_features_recat: Optional[torch.Tensor] = None
-    unbucketize_permute_tensor: Optional[torch.Tensor] = None
-    lengths_after_input_dist: Optional[torch.Tensor] = None
+    # Torch Dynamo does not support default_factory=list:
+    # https://github.com/pytorch/pytorch/issues/120108
+    # TODO(ivankobzarev): Make this a dataclass once supported
+
+    def __init__(
+        self,
+        # Fields of EmbeddingShardingContext
+        batch_size_per_rank: Optional[List[int]] = None,
+        batch_size_per_rank_per_feature: Optional[List[List[int]]] = None,
+        batch_size_per_feature_pre_a2a: Optional[List[int]] = None,
+        variable_batch_per_feature: bool = False,
+        # Fields of SequenceShardingContext
+        features_before_input_dist: Optional[KeyedJaggedTensor] = None,
+        input_splits: Optional[List[int]] = None,
+        output_splits: Optional[List[int]] = None,
+        sparse_features_recat: Optional[torch.Tensor] = None,
+        unbucketize_permute_tensor: Optional[torch.Tensor] = None,
+        lengths_after_input_dist: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__(
+            batch_size_per_rank,
+            batch_size_per_rank_per_feature,
+            batch_size_per_feature_pre_a2a,
+            variable_batch_per_feature,
+        )
+        self.features_before_input_dist: Optional[
+            KeyedJaggedTensor
+        ] = features_before_input_dist
+        self.input_splits: List[int] = input_splits if input_splits is not None else []
+        self.output_splits: List[int] = (
+            output_splits if output_splits is not None else []
+        )
+        self.sparse_features_recat: Optional[torch.Tensor] = sparse_features_recat
+        self.unbucketize_permute_tensor: Optional[
+            torch.Tensor
+        ] = unbucketize_permute_tensor
+        self.lengths_after_input_dist: Optional[torch.Tensor] = lengths_after_input_dist


    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        if self.features_before_input_dist is not None:
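A minimal sketch of the workaround pattern this diff applies, reduced to a hypothetical class (`MyContext` and its field are illustrative, not torchrec names). The diff replaces dataclass fields using `field(default_factory=list)`, which Dynamo cannot trace per the linked issue, with an explicit `__init__` that defaults to `None` and materializes a fresh list per instance:

from typing import List, Optional

class MyContext:
    # Before (breaks Dynamo tracing when constructed in a compiled region):
    #     @dataclass
    #     class MyContext:
    #         input_splits: List[int] = field(default_factory=list)
    #
    # After: Optional[...] = None stands in for default_factory=list; the
    # mutable default is created inside __init__, so instances never share
    # the same list object.
    def __init__(self, input_splits: Optional[List[int]] = None) -> None:
        self.input_splits: List[int] = (
            input_splits if input_splits is not None else []
        )

ctx_a = MyContext()
ctx_b = MyContext()
ctx_a.input_splits.append(4)
assert ctx_b.input_splits == []  # each instance gets its own fresh list

The `None`-sentinel default is the standard Python idiom for per-instance mutable defaults, so the class keeps the exact semantics of `default_factory=list` while remaining an ordinary `__init__` that Dynamo can trace.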