[Dynamo] allow dynamic callables on tensor variables (pytorch#137940)
Fixes pytorch#134844

Pull Request resolved: pytorch#137940
Approved by: https://github.com/williamwen42
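In user-facing terms, the change lets torch.compile trace through a callable that was attached to a tensor at runtime, rather than bailing out via the callable branch this PR removes from dynamic_getattr (see the tensor.py diff below). A minimal sketch of the pattern, mirroring the new test added in test_misc.py:

import torch

def add_one(x):
    return x + 1

t = torch.nn.Parameter(torch.ones(1))
t.add_one = add_one  # attach a plain callable to the tensor at runtime

@torch.compile(fullgraph=True)
def fn(x):
    return t.add_one(t) + x  # attribute resolved on the tensor during tracing

print(fn(torch.ones(1)))  # a tensor equal to torch.ones(1) + 2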
mlazos authored and Ryo-not-rio committed Dec 2, 2024
1 parent ab0a1ec commit 5abe9e2
Showing 4 changed files with 63 additions and 25 deletions.
31 changes: 16 additions & 15 deletions benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
@@ -1,64 +1,65 @@
-add_loop_eager,compile_time_instruction_count,3073000000,0.015
+add_loop_eager,compile_time_instruction_count,3077000000,0.015



-add_loop_eager_dynamic,compile_time_instruction_count,5700000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5719000000,0.025



-add_loop_inductor,compile_time_instruction_count,24580000000,0.015
+add_loop_inductor,compile_time_instruction_count,24630000000,0.015



-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40810000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40910000000,0.025



-add_loop_inductor_gpu,compile_time_instruction_count,23290000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23330000000,0.015



basic_modules_ListOfLinears_eager,compile_time_instruction_count,1037000000,0.015



-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19200000000,0.015

+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19210000000,0.015


-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15820000000,0.015

+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15840000000,0.015


-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16890000000,0.2

+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16510000000,0.2


-update_hint_regression,compile_time_instruction_count,1757000000,0.02

+update_hint_regression,compile_time_instruction_count,1753000000,0.02


-sum_floordiv_regression,compile_time_instruction_count,1171000000,0.015

+sum_floordiv_regression,compile_time_instruction_count,1241000000,0.015


-symint_sum,compile_time_instruction_count,3321000000,0.015

+symint_sum,compile_time_instruction_count,3331000000,0.015


-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2014000000,0.015

+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2011000000,0.015


-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5826000000,0.015

+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5827000000,0.015


-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9022000000,0.015

+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9054000000,0.015


-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3848000000,0.015

+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3844000000,0.015


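For readers unfamiliar with this benchmark file: each row is a benchmark name, a metric, the expected value, and a threshold. A hedged sketch of how such a threshold could be applied when comparing a fresh measurement against the expected value; the helper below is illustrative only, not the actual pr_time_benchmarks harness, and it assumes the last column is a relative tolerance:

def within_expected(actual: int, expected: int, rel_tol: float) -> bool:
    # Assumption: a run passes if it lands within rel_tol of the recorded count.
    return abs(actual - expected) <= rel_tol * expected

# Using the updated add_loop_eager row (3077000000 with a 0.015 threshold):
print(within_expected(3_100_000_000, 3_077_000_000, 0.015))  # True, ~0.75% drift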
14 changes: 14 additions & 0 deletions test/dynamo/test_misc.py
@@ -1201,6 +1201,20 @@ def fn(x):
        inp = torch.ones(2, 2)
        fn(inp)

+    def test_tensor_dynamic_method(self):
+        def add_one(x):
+            return x + 1
+
+        t = torch.nn.Parameter(torch.ones(1))
+        t.add_one = add_one
+
+        @torch.compile(fullgraph=True)
+        def fn(x):
+            return t.add_one(t) + x
+
+        result = fn(torch.ones(1))
+        self.assertEqual(torch.ones(1) + 2, result)

    def test_shape_unpack(self):
        def fn(x):
            a, b = x.size()
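For reference, the expected value in the new test follows from t.add_one(t) + x = (1 + 1) + 1 = 3, which is why the assertion compares the result against torch.ones(1) + 2.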
10 changes: 9 additions & 1 deletion torch/_dynamo/variables/misc.py
@@ -963,15 +963,23 @@ def call_function(
class GetAttrVariable(VariableTracker):
    _nonvar_fields = {
        "name",
+        "py_type",
        *VariableTracker._nonvar_fields,
    }

-    def __init__(self, obj, name, **kwargs) -> None:
+    def __init__(self, obj, name, py_type=None, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name
+        self.py_type = py_type  # In some cases we know the type (ex. tensor methods)

+    def python_type(self):
+        if self.py_type is not None:
+            return self.py_type
+        else:
+            super().python_type()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.obj}, {self.name})"
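A rough, standalone illustration of what the new py_type argument records in the tensor-method case this commit targets; this is plain Python rather than Dynamo internals, and the exact printed type may vary by build:

import torch

t = torch.nn.Parameter(torch.ones(1))
real_value = t.add_          # a method bound to a concrete tensor
py_type = type(real_value)   # what GetAttrVariable(..., py_type=...) stores
print(py_type)               # e.g. <class 'builtin_function_or_method'>

With py_type populated, python_type() can answer directly instead of deferring to the base VariableTracker implementation.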
33 changes: 24 additions & 9 deletions torch/_dynamo/variables/tensor.py
@@ -89,6 +89,16 @@
)


+def is_bound_tensor_method(value):
+    return (
+        callable(value)
+        and not torch._dynamo.utils.object_has_getattribute(value)
+        and hasattr(value, "__self__")
+        and isinstance(value.__self__, torch.Tensor)
+        and getattr(value.__self__, value.__name__, None)
+    )


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

@@ -273,14 +283,19 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
            raise NotImplementedError

        real_value = getattr(_input_associated_real_value, name)
-        if callable(real_value):
-            # Callables have more nuanced handling, and we should let the existing system delegate here.
-            # Raising was past behavior and so should always be sound to fall back.
-            # Note - at a certain point we may want to handle
-            raise NotImplementedError

        attr_source = AttrSource(self.source, name)
        install_guard(attr_source.make_guard(GuardBuilder.HASATTR))

+        # Typically we'd want to use variable builder here
+        # but unfortunately id(real_value.__self__) is not id(<original value>)
+        if is_bound_tensor_method(real_value):
+            from .misc import GetAttrVariable

+            return GetAttrVariable(
+                self, name, source=attr_source, py_type=type(real_value)
+            )

        return VariableTracker.build(tx, real_value, attr_source)

    def method_attr_ndim(self, tx):
@@ -522,16 +537,16 @@ def call_method(
        # Only override builtin tensor methods
        # The user can manually add override handling
        # with a decorator for other methods (e.g. a dispatch subclass with other methods)
-        has_torch_function_override = False
+        is_base_tensor_method = False
        try:
            inspect.getattr_static(torch.Tensor, name)
-            has_torch_function_override = True
+            is_base_tensor_method = True
        except AttributeError:
-            has_torch_function_override = False
+            is_base_tensor_method = False

        if (
            can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
-            and has_torch_function_override
+            and is_base_tensor_method
        ):
            if self.source:
                func_var = VariableBuilder(
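To make the new predicate concrete: a method bound to a tensor carries a __self__ that is the tensor itself, while a plain function monkey-patched onto a tensor does not, so the latter still flows through VariableTracker.build in dynamic_getattr. A small check in ordinary Python, reusing the Parameter pattern from the new test:

import torch

t = torch.nn.Parameter(torch.ones(2))

bound = t.add_                          # bound tensor method
print(hasattr(bound, "__self__"))       # True
print(bound.__self__ is t)              # True: bound to this exact tensor

def add_one(x):
    return x + 1

t.add_one = add_one                     # dynamically attached plain function
print(hasattr(t.add_one, "__self__"))   # False: not a bound method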
