
Commit 0f51389

YazhiGao authored and facebook-github-bot committed
kernel row alignment correction (#1789)
Summary:
Pull Request resolved: #1789

* Sharder can inject TBE row alignment via extra params in fused params.

Reviewed By: tissue3

Differential Revision: D54842984

fbshipit-source-id: df52030274fefe3eec271b87fdbb61923eafaa66
1 parent cb6b69a · commit 0f51389
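A minimal sketch of the sharder-side injection the summary describes (not part of this commit): only the FUSED_PARAM_TBE_ROW_ALIGNMENT key below is introduced here, and the override value of 8 is a hypothetical example.

```python
# Hypothetical usage sketch. FUSED_PARAM_TBE_ROW_ALIGNMENT is the key added
# by this commit; the value 8 is an illustrative override (kernel default: 16).
from torchrec.distributed.fused_params import FUSED_PARAM_TBE_ROW_ALIGNMENT

# Extra params handed down through the sharder; the quantized TBE kernel in
# quant_embedding_kernel.py reads this key and strips it before constructing
# the TBE module.
fused_params = {
    FUSED_PARAM_TBE_ROW_ALIGNMENT: 8,
}
```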

File tree

2 files changed: +21 −1

torchrec/distributed/fused_params.py (+12)
```diff
@@ -20,6 +20,7 @@
 FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: str = (
     "__register_quant_state_dict_split_scale_bias"
 )
+FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"


 class TBEToRegisterMixIn:
@@ -47,6 +48,15 @@ def is_fused_param_register_tbe(fused_params: Optional[Dict[str, Any]]) -> bool:
     )


+def get_fused_param_tbe_row_alignment(
+    fused_params: Optional[Dict[str, Any]]
+) -> Optional[int]:
+    if fused_params is None or FUSED_PARAM_TBE_ROW_ALIGNMENT not in fused_params:
+        return None
+    else:
+        return fused_params[FUSED_PARAM_TBE_ROW_ALIGNMENT]
+
+
 def is_fused_param_quant_state_dict_split_scale_bias(
     fused_params: Optional[Dict[str, Any]]
 ) -> bool:
@@ -68,5 +78,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_REGISTER_TBE_BOOL)
     if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS)
+    if FUSED_PARAM_TBE_ROW_ALIGNMENT in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)

     return fused_params_for_tbe
```
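To make the helper semantics concrete, a small illustrative check (not part of the commit); cache_load_factor stands in for any TBE-native fused param:

```python
# Illustrative behavior check for the helpers above (not part of the commit).
from torchrec.distributed.fused_params import (
    FUSED_PARAM_TBE_ROW_ALIGNMENT,
    get_fused_param_tbe_row_alignment,
    tbe_fused_params,
)

fused_params = {FUSED_PARAM_TBE_ROW_ALIGNMENT: 8, "cache_load_factor": 0.2}

# Returns the injected alignment, or None when the key is absent.
assert get_fused_param_tbe_row_alignment(fused_params) == 8
assert get_fused_param_tbe_row_alignment(None) is None

# tbe_fused_params strips the torchrec-only control key so that only
# TBE-native kwargs reach **(tbe_fused_params(fused_params) or {}).
assert tbe_fused_params(fused_params) == {"cache_load_factor": 0.2}
```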

torchrec/distributed/quant_embedding_kernel.py (+9 −1)
```diff
@@ -32,6 +32,7 @@
     GroupedEmbeddingConfig,
 )
 from torchrec.distributed.fused_params import (
+    get_fused_param_tbe_row_alignment,
     is_fused_param_quant_state_dict_split_scale_bias,
     is_fused_param_register_tbe,
     tbe_fused_params,
@@ -318,6 +319,9 @@ def __init__(
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
+        self._tbe_row_alignment: Optional[int] = get_fused_param_tbe_row_alignment(
+            fused_params
+        )
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
             IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=[
@@ -342,7 +346,11 @@ def __init__(
                 device=device,
                 pooling_mode=PoolingMode.NONE,
                 feature_table_map=self._feature_table_map,
-                row_alignment=16,
+                row_alignment=(
+                    self._tbe_row_alignment
+                    if self._tbe_row_alignment is not None
+                    else 16
+                ),
                 uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
                 **(tbe_fused_params(fused_params) or {}),
             )
```
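Restated outside the constructor, the resolution rule the kernel now applies is simply injected-value-or-default; a standalone sketch:

```python
from typing import Optional

# Sketch of the row_alignment resolution added above: a sharder-injected
# value wins; otherwise the previously hard-coded default of 16 is kept.
def resolve_row_alignment(injected: Optional[int], default: int = 16) -> int:
    return injected if injected is not None else default

assert resolve_row_alignment(None) == 16  # no override: old behavior
assert resolve_row_alignment(8) == 8      # injected alignment takes effect
```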
