Skip to content

Commit 3b91dd2

Browse files
ZhengkaiZfacebook-github-bot
authored andcommitted
Update PositionWeightedModule to make it jit trace compitable
Summary: In ``` torch.ops.fbgemm.offsets_range(features[key].offsets().long(), torch.numel(features[key].values()) ``` This part ``` torch.numel(features[key].values() ``` will be traced into constant Reviewed By: snabelkabiya, houseroad Differential Revision: D53744703 fbshipit-source-id: bf87eeaacf295b82b4591fb53029f66868ab8d86
1 parent a8e1675 commit 3b91dd2

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

torchrec/modules/feature_processor.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def position_weighted_module_update_features(
4040
return features
4141

4242

43+
@torch.jit.script_if_tracing
44+
@torch.fx.wrap
45+
def offsets_to_range_traceble(
46+
offsets: torch.Tensor, values: torch.Tensor
47+
) -> torch.Tensor:
48+
return torch.ops.fbgemm.offsets_range(offsets.long(), torch.numel(values))
49+
50+
4351
# Will be deprecated soon, please use PositionWeightedProcessor, see full doc below
4452
class PositionWeightedModule(BaseFeatureProcessor):
4553
"""
@@ -86,8 +94,8 @@ def forward(
8694

8795
weighted_features: Dict[str, JaggedTensor] = {}
8896
for key, position_weight in self.position_weights.items():
89-
seq = torch.ops.fbgemm.offsets_range(
90-
features[key].offsets().long(), torch.numel(features[key].values())
97+
seq = offsets_to_range_traceble(
98+
features[key].offsets(), features[key].values()
9199
)
92100
weighted_features[key] = JaggedTensor(
93101
values=features[key].values(),

0 commit comments

Comments
 (0)