
Commit 91d679c

henrylhtsang authored and facebook-github-bot committed
Take cache_load_factor from sharder if sharding option doesn't have it (#1644)

Summary:
Pull Request resolved: #1644

Sometimes cache_load_factor is passed through sharders. This is not the ideal way of passing cache_load_factor, but for the time being we still allow it, so the stats printout should reflect it.

Reviewed By: ge0405

Differential Revision: D52921387

fbshipit-source-id: f7c70aa9136f71148f21c63ec3257ca18cdc1b60

1 parent 98a28ad commit 91d679c

File tree

1 file changed: +10 -1 lines changed


torchrec/distributed/planner/stats.py (+10 -1)
@@ -329,7 +329,16 @@ def log(
                 or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value
                 else f"{so.tensor.shape[1]}"
             )
-            cache_load_factor = str(so.cache_load_factor)
+            sharder_cache_load_factor = (
+                sharder.fused_params.get("cache_load_factor")  # pyre-ignore[16]
+                if hasattr(sharder, "fused_params") and sharder.fused_params
+                else None
+            )
+            cache_load_factor = str(
+                so.cache_load_factor
+                if so.cache_load_factor is not None
+                else sharder_cache_load_factor
+            )
             hash_size = so.tensor.shape[0]
             param_table.append(
                 [
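
As the summary notes, cache_load_factor can also arrive on a sharder's fused_params rather than on the sharding option itself. The following sketch is not part of the commit: the SimpleSharder stand-in, the resolve_cache_load_factor helper, and the 0.2 value are illustrative only, used here to isolate the fallback logic the diff adds to the stats printout.

# A stand-in for a TorchRec sharder that carries fused_params
# (in real code the same dict is typically passed to a sharder's
# fused_params argument, e.g. an EmbeddingBagCollectionSharder).
class SimpleSharder:
    def __init__(self, fused_params=None):
        self.fused_params = fused_params


def resolve_cache_load_factor(so_cache_load_factor, sharder):
    # Mirrors the new logic in stats.py: prefer the sharding option's
    # cache_load_factor, otherwise fall back to the sharder's fused_params.
    sharder_clf = (
        sharder.fused_params.get("cache_load_factor")
        if hasattr(sharder, "fused_params") and sharder.fused_params
        else None
    )
    return so_cache_load_factor if so_cache_load_factor is not None else sharder_clf


sharder = SimpleSharder(fused_params={"cache_load_factor": 0.2})
print(resolve_cache_load_factor(None, sharder))           # 0.2 (taken from the sharder)
print(resolve_cache_load_factor(0.5, sharder))            # 0.5 (sharding option wins)
print(resolve_cache_load_factor(None, SimpleSharder()))   # None (neither set)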
