Commit fdfc4e0

Ivan Kobzarev authored and facebook-github-bot committed
EmbeddingShardingContext fields no default_factory for dynamo (#1712)
Summary:
Pull Request resolved: #1712

Dynamo does not support dataclass.field default_factory with Lists; avoid it for now by specifying all arguments explicitly.

Reviewed By: Microve

Differential Revision: D53854370

fbshipit-source-id: b469f4a8acbcddbc2b9dca43765e11bd99429a3a
1 parent f1fb67a commit fdfc4e0
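
For context, here is a minimal illustrative sketch of the pattern this commit applies (the class names and the sizes field below are made up, not taken from the diff): a dataclass list field declared with field(default_factory=list) is replaced by a plain class whose __init__ accepts an Optional list and substitutes an empty list explicitly.

from dataclasses import dataclass, field
from typing import List, Optional

# Before: dataclass defaults via default_factory (not supported by Torch Dynamo,
# see https://github.com/pytorch/pytorch/issues/120108).
@dataclass
class CtxBefore:
    sizes: List[int] = field(default_factory=list)

# After: the workaround used in this commit, written as an explicit __init__
# with an Optional argument so that no default_factory is needed.
class CtxAfter:
    def __init__(self, sizes: Optional[List[int]] = None) -> None:
        self.sizes: List[int] = sizes if sizes is not None else []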

File tree

2 files changed: +65 -14 lines

torchrec/distributed/embedding_sharding.py (+27 -6)

@@ -7,7 +7,7 @@
 
 import abc
 import copy
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import torch
@@ -614,12 +614,33 @@ def _wait_impl(self) -> Awaitable[ListOfKJTList]:
 W = TypeVar("W")
 
 
-@dataclass
 class EmbeddingShardingContext(Multistreamable):
-    batch_size_per_rank: List[int] = field(default_factory=list)
-    batch_size_per_rank_per_feature: List[List[int]] = field(default_factory=list)
-    batch_size_per_feature_pre_a2a: List[int] = field(default_factory=list)
-    variable_batch_per_feature: bool = False
+    # Torch Dynamo does not support default_factory=list:
+    # https://github.com/pytorch/pytorch/issues/120108
+    # TODO(ivankobzarev) Make this a dataclass once supported
+
+    def __init__(
+        self,
+        batch_size_per_rank: Optional[List[int]] = None,
+        batch_size_per_rank_per_feature: Optional[List[List[int]]] = None,
+        batch_size_per_feature_pre_a2a: Optional[List[int]] = None,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        super().__init__()
+        self.batch_size_per_rank: List[int] = (
+            batch_size_per_rank if batch_size_per_rank is not None else []
+        )
+        self.batch_size_per_rank_per_feature: List[List[int]] = (
+            batch_size_per_rank_per_feature
+            if batch_size_per_rank_per_feature is not None
+            else []
+        )
+        self.batch_size_per_feature_pre_a2a: List[int] = (
+            batch_size_per_feature_pre_a2a
+            if batch_size_per_feature_pre_a2a is not None
+            else []
+        )
+        self.variable_batch_per_feature: bool = variable_batch_per_feature
 
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         pass
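
With the dataclass decorator removed, EmbeddingShardingContext is now constructed through an ordinary __init__. A small usage sketch (the argument values are illustrative only):

from torchrec.distributed.embedding_sharding import EmbeddingShardingContext

# Omitted list arguments default to fresh empty lists inside __init__,
# matching the old field(default_factory=list) behavior.
ctx_default = EmbeddingShardingContext()
assert ctx_default.batch_size_per_rank == []

# All fields can still be passed explicitly as keyword arguments.
ctx = EmbeddingShardingContext(
    batch_size_per_rank=[8, 8],
    batch_size_per_rank_per_feature=[[8, 8], [8, 8]],
    batch_size_per_feature_pre_a2a=[16, 16],
    variable_batch_per_feature=False,
)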

torchrec/distributed/sharding/sequence_sharding.py (+38 -8)

@@ -5,7 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import List, Optional
 
 import torch
@@ -15,7 +15,6 @@
 from torchrec.streamable import Multistreamable
 
 
-@dataclass
 class SequenceShardingContext(EmbeddingShardingContext):
     """
     Stores KJTAllToAll context and reuses it in SequenceEmbeddingsAllToAll.
@@ -32,12 +31,43 @@ class SequenceShardingContext(EmbeddingShardingContext):
     input dist.
     """
 
-    features_before_input_dist: Optional[KeyedJaggedTensor] = None
-    input_splits: List[int] = field(default_factory=list)
-    output_splits: List[int] = field(default_factory=list)
-    sparse_features_recat: Optional[torch.Tensor] = None
-    unbucketize_permute_tensor: Optional[torch.Tensor] = None
-    lengths_after_input_dist: Optional[torch.Tensor] = None
+    # Torch Dynamo does not support default_factory=list:
+    # https://github.com/pytorch/pytorch/issues/120108
+    # TODO(ivankobzarev): Make this a dataclass once supported
+
+    def __init__(
+        self,
+        # Fields of EmbeddingShardingContext
+        batch_size_per_rank: Optional[List[int]] = None,
+        batch_size_per_rank_per_feature: Optional[List[List[int]]] = None,
+        batch_size_per_feature_pre_a2a: Optional[List[int]] = None,
+        variable_batch_per_feature: bool = False,
+        # Fields of SequenceShardingContext
+        features_before_input_dist: Optional[KeyedJaggedTensor] = None,
+        input_splits: Optional[List[int]] = None,
+        output_splits: Optional[List[int]] = None,
+        sparse_features_recat: Optional[torch.Tensor] = None,
+        unbucketize_permute_tensor: Optional[torch.Tensor] = None,
+        lengths_after_input_dist: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__(
+            batch_size_per_rank,
+            batch_size_per_rank_per_feature,
+            batch_size_per_feature_pre_a2a,
+            variable_batch_per_feature,
+        )
+        self.features_before_input_dist: Optional[
+            KeyedJaggedTensor
+        ] = features_before_input_dist
+        self.input_splits: List[int] = input_splits if input_splits is not None else []
+        self.output_splits: List[int] = (
+            output_splits if output_splits is not None else []
+        )
+        self.sparse_features_recat: Optional[torch.Tensor] = sparse_features_recat
+        self.unbucketize_permute_tensor: Optional[
+            torch.Tensor
+        ] = unbucketize_permute_tensor
+        self.lengths_after_input_dist: Optional[torch.Tensor] = lengths_after_input_dist
 
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         if self.features_before_input_dist is not None:
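
SequenceShardingContext follows the same pattern, forwarding the base-class fields positionally to super().__init__. A construction sketch (values are illustrative; fields left out keep the defaults shown in the diff above):

import torch

from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext

seq_ctx = SequenceShardingContext(
    batch_size_per_rank=[4, 4],
    input_splits=[2, 2],
    output_splits=[2, 2],
    sparse_features_recat=torch.tensor([0, 2, 1, 3]),
)
# List fields that were not passed default to fresh empty lists; optional
# tensor and KeyedJaggedTensor fields default to None.
assert seq_ctx.batch_size_per_feature_pre_a2a == []
assert seq_ctx.lengths_after_input_dist is None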
