Skip to content

Commit a0bffa4

Browse files
henrylhtsangfacebook-github-bot
authored andcommitted
Add fused_params to mch sharders (#1649)
Summary: This diff is not really ideal. Ideally we can deprecate the sharder fused params path. The question is which one would happen first: 1. zch + uvm caching 2. the use of per table CLF. Differential Revision: D52921362
1 parent 239c033 commit a0bffa4

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchrec/distributed/mc_embedding_modules.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import logging
9-
from typing import Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Union
9+
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
1010

1111
import torch
1212
from torch.autograd.profiler import record_function
@@ -276,3 +276,8 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
276276
set(self._mc_sharder.sharding_types(compute_device_type)),
277277
)
278278
)
279+
280+
@property
281+
def fused_params(self) -> Optional[Dict[str, Any]]:
282+
# TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints
283+
return self._e_sharder.fused_params

0 commit comments

Comments
 (0)