From 5abe9e22d9980483943ab93e120648f3e52d527e Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Fri, 8 Nov 2024 23:49:31 +0000
Subject: [PATCH] [Dynamo] allow dynamic callables on tensor variables (#137940)

Fixes https://github.com/pytorch/pytorch/issues/134844
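
A quick sketch of the behavior this enables, adapted from the
test_tensor_dynamic_method test added below (the assert mirrors the test's
assertEqual): a plain Python callable attached to a tensor can now be called
inside a compiled region instead of falling back.

    import torch

    def add_one(x):
        return x + 1

    t = torch.nn.Parameter(torch.ones(1))
    t.add_one = add_one  # dynamic callable stored on the tensor

    @torch.compile(fullgraph=True)  # fullgraph=True: a graph break would raise
    def fn(x):
        return t.add_one(t) + x

    result = fn(torch.ones(1))
    assert torch.equal(result, torch.ones(1) + 2)  # (1 + 1) + 1 == 3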

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137940
Approved by: https://github.com/williamwen42
---
 .../pr_time_benchmarks/expected_results.csv | 31 ++++++++---------
 test/dynamo/test_misc.py                    | 14 ++++++++
 torch/_dynamo/variables/misc.py             | 10 +++++-
 torch/_dynamo/variables/tensor.py           | 33 ++++++++++++++-----
 4 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
index dcaadaeed87646..7063ae80e595d1 100644
--- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
+++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
@@ -1,20 +1,20 @@
-add_loop_eager,compile_time_instruction_count,3073000000,0.015
+add_loop_eager,compile_time_instruction_count,3077000000,0.015



-add_loop_eager_dynamic,compile_time_instruction_count,5700000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5719000000,0.025



-add_loop_inductor,compile_time_instruction_count,24580000000,0.015
+add_loop_inductor,compile_time_instruction_count,24630000000,0.015



-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40810000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40910000000,0.025



-add_loop_inductor_gpu,compile_time_instruction_count,23290000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23330000000,0.015



@@ -22,43 +22,44 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,1037000000,0.01



-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19200000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19210000000,0.015



-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15820000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15840000000,0.015



-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16890000000,0.2
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16510000000,0.2



-update_hint_regression,compile_time_instruction_count,1757000000,0.02
+update_hint_regression,compile_time_instruction_count,1753000000,0.02



-sum_floordiv_regression,compile_time_instruction_count,1171000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1241000000,0.015



-symint_sum,compile_time_instruction_count,3321000000,0.015
+symint_sum,compile_time_instruction_count,3331000000,0.015



-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2014000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2011000000,0.015



-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5826000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5827000000,0.015



-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9022000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9054000000,0.015



-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3848000000,0.015
+
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3844000000,0.015



diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 177c26d7952067..34555ee71417ee 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -1201,6 +1201,20 @@ def fn(x):
         inp = torch.ones(2, 2)
         fn(inp)
 
+    def test_tensor_dynamic_method(self):
+        def add_one(x):
+            return x + 1
+
+        t = torch.nn.Parameter(torch.ones(1))
+        t.add_one = add_one
+
+        @torch.compile(fullgraph=True)
+        def fn(x):
+            return t.add_one(t) + x
+
+        result = fn(torch.ones(1))
+        self.assertEqual(torch.ones(1) + 2, result)
+
     def test_shape_unpack(self):
         def fn(x):
             a, b = x.size()

diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 21005f9b56453a..f5c1380e09a64e 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -963,15 +963,23 @@ def call_function(
 class GetAttrVariable(VariableTracker):
     _nonvar_fields = {
         "name",
+        "py_type",
         *VariableTracker._nonvar_fields,
     }
 
-    def __init__(self, obj, name, **kwargs) -> None:
+    def __init__(self, obj, name, py_type=None, **kwargs) -> None:
         super().__init__(**kwargs)
         assert isinstance(obj, VariableTracker)
         assert isinstance(name, str)
         self.obj = obj
         self.name = name
+        self.py_type = py_type  # In some cases we know the type (ex. tensor methods)
+
+    def python_type(self):
+        if self.py_type is not None:
+            return self.py_type
+        else:
+            return super().python_type()
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.obj}, {self.name})"

diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index bf5e7ddfbeaec3..6029ffd7e875e0 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -89,6 +89,16 @@
 )
 
 
+def is_bound_tensor_method(value):
+    return (
+        callable(value)
+        and not torch._dynamo.utils.object_has_getattribute(value)
+        and hasattr(value, "__self__")
+        and isinstance(value.__self__, torch.Tensor)
+        and getattr(value.__self__, value.__name__, None)
+    )
+
+
 class TensorVariable(VariableTracker):
     """A torch.Tensor input or an intermediate value in the FX graph"""
 
@@ -273,14 +283,19 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
             raise NotImplementedError
 
         real_value = getattr(_input_associated_real_value, name)
-        if callable(real_value):
-            # Callables have more nuanced handling, and we should let the existing system delegate here.
-            # Raising was past behavior and so should always be sound to fall back.
-            # Note - at a certain point we may want to handle
-            raise NotImplementedError
 
         attr_source = AttrSource(self.source, name)
         install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
+
+        # Typically we'd want to use variable builder here
+        # but unfortunately id(real_value.__self__) is not id()
+        if is_bound_tensor_method(real_value):
+            from .misc import GetAttrVariable
+
+            return GetAttrVariable(
+                self, name, source=attr_source, py_type=type(real_value)
+            )
+
         return VariableTracker.build(tx, real_value, attr_source)
 
     def method_attr_ndim(self, tx):
@@ -522,16 +537,16 @@ def call_method(
         # Only override builtin tensor methods
         # The user can manually add override handling
         # with a decorator for other methods (e.g. a dispatch subclass with other methods)
-        has_torch_function_override = False
+        is_base_tensor_method = False
         try:
             inspect.getattr_static(torch.Tensor, name)
-            has_torch_function_override = True
+            is_base_tensor_method = True
         except AttributeError:
-            has_torch_function_override = False
+            is_base_tensor_method = False
 
         if (
             can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
-            and has_torch_function_override
+            and is_base_tensor_method
         ):
             if self.source:
                 func_var = VariableBuilder(
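
Note on the new dispatch split in dynamic_getattr: is_bound_tensor_method only
matches callables that are bound to a tensor (built-in methods expose the
tensor via __self__); those are wrapped as GetAttrVariable with a known
py_type, since VariableBuilder cannot reconstruct them. Anything else,
including plain functions stored on the instance as in the test above, still
takes the VariableTracker.build path. A standalone sketch of the distinction
the predicate relies on (plain CPython attribute semantics, no Dynamo
involved):

    import torch

    t = torch.ones(1)

    bound = t.add                          # built-in tensor method, bound to t
    print(callable(bound))                 # True
    print(bound.__self__ is t)             # True  -> is_bound_tensor_method matches

    def add_one(x):
        return x + 1

    t.add_one = add_one                    # plain function in t's __dict__
    print(callable(t.add_one))             # True
    print(hasattr(t.add_one, "__self__"))  # False -> handled by VariableTracker.build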