Skip to content

Commit

Permalink
[PT2][Optimus] fix the default alpha and beta values (#139857)
Browse files Browse the repository at this point in the history
Summary:
We noticed that the default coefficient values for beta and alpha should be the int 1 instead of the float 1.0, which would otherwise cause an error when the inputs to the add are int types.

More context:

https://fb.workplace.com/groups/1075192433118967/permalink/1539142760057263/

Test Plan:
# local reproduce
```
buck2 run mode/opt scripts/shuaiyang:test -- --optimus --flow_id 660724017 2>&1 | tee ~/local_run_shuai_660724017.txt
```

trace link: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/mengluy/2024-11-05-21-18-17/trace.json.gz&bucket=gpu_traces

# E2E

before fix:
f660724017

after fix:

Differential Revision: D65521638

Pull Request resolved: #139857
Approved by: https://github.com/jackiexu1992
  • Loading branch information
mengluy0125 authored and pytorchmergebot committed Nov 7, 2024
1 parent 72d3f5b commit d0da40a
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions torch/_inductor/fx_passes/group_batch_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@

log = logging.getLogger(__name__)

DEFAULT_BETA = 1
DEFAULT_ALPHA = 1

MIN_FUSE_SET_SIZE = 5
MAX_FUSE_SET_SIZE = 300
MAX_FUSE_SEARCH_DEPTH = 5
Expand Down Expand Up @@ -178,7 +181,8 @@ class PostGradBatchLinearFusion(BatchFusion):
def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
# pyre-fixme[7]: Incompatible return type
return (
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value]
node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA
and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA # type: ignore[return-value]
)

def _is_input_2d(self, input: torch.fx.Node) -> bool:
Expand Down Expand Up @@ -303,8 +307,8 @@ def _addmm_node_can_be_fused(self, node: torch.fx.Node):
input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr]
weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr]
return (
node.kwargs.get("beta", 1.0) == 1.0
and node.kwargs.get("alpha", 1.0) == 1.0
node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA
and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA
and len(input_shape) == 2
and len(weight_shape) == 2
and all(x % 2 == 0 for x in input_shape + weight_shape)
Expand Down Expand Up @@ -411,7 +415,7 @@ def match(self, node: torch.fx.Node):
if CallFunctionVarArgs(self.op).match(
node
) and self._pointwise_node_can_be_fused(node):
alpha = node.kwargs.get("alpha", 1.0)
alpha = node.kwargs.get("alpha", DEFAULT_ALPHA)
rounding_mode = node.kwargs.get("rounding_mode", None)
input, other = node.args
shape = list(input.meta["val"].shape) # type: ignore[union-attr]
Expand Down Expand Up @@ -445,7 +449,7 @@ def match(self, node: torch.fx.Node):

def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
batch_inputs, batch_others = [], []
alpha = subset[0].kwargs.get("alpha", 1.0)
alpha = subset[0].kwargs.get("alpha", DEFAULT_ALPHA)
batch_inputs_meta, batch_others_meta = [], []

for node in subset:
Expand Down

0 comments on commit d0da40a

Please sign in to comment.