Skip to content

Commit 74a3b9c

Browse files
alexbeloipytorchmergebot
authored andcommitted
[fx][acc_tracer] fix defaulted placeholder normalization (pytorch#73406)
Summary: Pull Request resolved: pytorch#73406 Placeholder defaults are stored in `node.args`, during normalization we had dropped these. This diff passes the default args through the normalization transformation. Test Plan: Added tests to cover cases with optional inputs, test covers * nothing passed to optional input * `None` passed to optional input * a tensor passed to optional input Reviewed By: jfix71 Differential Revision: D34463493 fbshipit-source-id: f0c3a4083cb3dd4a69111a758561f0d2c0609787 (cherry picked from commit 7fb482c)
1 parent bbdb758 commit 74a3b9c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torch/fx/interpreter.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._symbolic_trace import Tracer
66
from ._compatibility import compatibility
77
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
8+
import inspect
89

910
@compatibility(is_backward_compatible=True)
1011
class Interpreter:
@@ -407,7 +408,8 @@ def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D
407408
kwargs (Dict): Dict of keyword arguments for this invocation
408409
"""
409410
assert isinstance(target, str)
410-
return Proxy(self.new_graph.placeholder(target), self.tracer)
411+
default_value = next(iter(args)) if args else inspect.Signature.empty
412+
return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
411413

412414
@compatibility(is_backward_compatible=True)
413415
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:

0 commit comments

Comments
 (0)