
Commit 7fb1d62

seanx92 authored and facebook-github-bot committed
fx wrap entrypoint of EmbeddingBagCollection (#1801)
Summary: Pull Request resolved: #1801

We want to tag the part of EmbeddingBagCollection (EBC) that runs before the TBE (table-batched embedding) lookup as the INPUT_DIST net, so that the graph structure after splitting resembles that of a sharded EBC. This change uses torch.fx.wrap to mark the entrypoint of EBC as a leaf call; the splitting tag rules will then match on that node.

Reviewed By: jiayisuse, gnahzg, 842974287

Differential Revision: D55023624

fbshipit-source-id: 6d132f0b313bb8477e053f91c6be0b9e82e7b322
Parent: 66f821e
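For context on the mechanism this commit relies on: torch.fx.wrap registers a module-level function as a leaf for symbolic tracing, so the traced graph records a single call_function node targeting that function instead of tracing through its body. Below is a minimal sketch of that behavior; the names double_plus_one and Toy are illustrative, not part of TorchRec.

import torch
import torch.fx

# Registered as a leaf: symbolic_trace emits one call_function node
# for this helper instead of inlining its body into the graph.
@torch.fx.wrap
def double_plus_one(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return double_plus_one(x).relu()

gm = torch.fx.symbolic_trace(Toy())
# The first non-placeholder node is a call_function targeting
# double_plus_one, which graph passes can match on by name.
print(gm.graph)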

2 files changed (+26, -1)

torchrec/quant/embedding_modules.py (+8, -1)

@@ -230,6 +230,13 @@ def _update_embedding_configs(
         )
 
 
+@torch.fx.wrap
+def features_to_dict(
+    features: KeyedJaggedTensor,
+) -> Dict[str, JaggedTensor]:
+    return features.to_dict()
+
+
 class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
     """
     EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).

@@ -452,7 +459,7 @@ def forward(
             KeyedTensor
         """
 
-        feature_dict = features.to_dict()
+        feature_dict = features_to_dict(features)
         embeddings = []
 
         # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script.
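The commit summary says the wrapped entrypoint will be used by the splitting tag rules, but those rules are not part of this diff. The following is therefore only a hypothetical sketch of how a pass could tag everything up to features_to_dict as INPUT_DIST; tag_input_dist is an invented name and the rule itself is illustrative, though the node.tag convention matches what torch.fx.passes.split_utils.split_by_tags consumes.

import torch.fx

# Hypothetical tagging pass (not the actual TorchRec rule): mark every
# node up to and including the features_to_dict call as "INPUT_DIST",
# and the rest as "COMPUTE". split_by_tags can then split on node.tag.
def tag_input_dist(gm: torch.fx.GraphModule) -> None:
    in_input_dist = True
    for node in gm.graph.nodes:
        if node.op in ("placeholder", "output"):
            continue  # graph inputs/outputs are not assigned to a subgraph
        node.tag = "INPUT_DIST" if in_input_dist else "COMPUTE"
        if (
            node.op == "call_function"
            and getattr(node.target, "__name__", "") == "features_to_dict"
        ):
            # Everything after the entrypoint belongs to the TBE/compute side.
            in_input_dist = False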

torchrec/quant/tests/test_embedding_modules.py (+18, -0)

@@ -29,6 +29,7 @@
 from torchrec.quant.embedding_modules import (
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     EmbeddingCollection as QuantEmbeddingCollection,
+    features_to_dict,
     quant_prep_enable_quant_state_dict_split_scale_bias,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

@@ -446,6 +447,23 @@ def test_trace_and_script(self) -> None:
 
         gm = symbolic_trace(qebc)
 
+        non_placeholder_nodes = [
+            node for node in gm.graph.nodes if node.op != "placeholder"
+        ]
+        self.assertTrue(
+            len(non_placeholder_nodes) > 0, "Graph must have non-placeholder nodes"
+        )
+        self.assertEqual(
+            non_placeholder_nodes[0].op,
+            "call_function",
+            f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
+        )
+        self.assertEqual(
+            non_placeholder_nodes[0].name,
+            features_to_dict.__name__,
+            f"First non-placeholder node must be features_to_dict, got {non_placeholder_nodes[0].name} instead",
+        )
+
         features = KeyedJaggedTensor(
             keys=["f1", "f2"],
             values=torch.as_tensor([0, 1]),
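A note on the test's design: asserting on the first non-placeholder node pins the contract that features_to_dict is the traced entrypoint of the quantized EBC. A regression that inlines the helper (for example, if the torch.fx.wrap decorator were dropped) would then fail immediately with a descriptive message rather than surfacing later as a mis-tagged split.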
