From ac981eba2cfaf45c61000abea7ef5b0c5297799f Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Wed, 13 Nov 2024 16:26:44 +0200
Subject: [PATCH] Reduce CPU overhead of FusionDefinitionWrapper (#1416)

---
 thunder/executors/nvfuserex_impl.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index a13691103f..561f838a7d 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -12,6 +12,7 @@
 from looseversion import LooseVersion
 
 import torch
+from torch import Tensor
 
 import thunder.core.dtypes as dtypes
 import thunder.torch as ltorch
@@ -39,7 +40,7 @@
 import thunder.core.codeutils as codeutils
 from thunder.core.codeutils import Printable
 from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS
-from thunder.core.profile import add_markers
+from thunder.core.profile import annotate_for_profile
 
 from thunder.core.compile_data import get_compile_option
 from thunder.core.transforms import (
@@ -401,14 +402,11 @@ def compute_tensor_descriptor(
     return compute_symbolic_shape(proxy_shape, shape), *compute_contiguity(shape, stride)
 
 
-def get_tensor_descriptor(p: TensorProxy, t: torch.Tensor) -> tuple[tuple[int, ...], tuple[bool, ...], tuple[int, ...]]:
-    return compute_tensor_descriptor(p.shape, t.shape, t.stride())
-
-
-# TODO Inline the get_tensor_descriptor call
 def to_descriptors(proxy_args, args) -> tuple:
     def to_descriptor(proxy_arg, arg):
-        if isinstance(arg, Number):
+        if isinstance(arg, Tensor):
+            return (*compute_tensor_descriptor(proxy_arg._shape, arg.shape, arg.stride()), arg.dtype)
+        elif isinstance(arg, Number):
             return type(arg)
         elif isinstance(arg, tuple):
             if len(arg) != 0:
@@ -419,8 +417,6 @@ def to_descriptor(proxy_arg, arg):
                 exception_type=AssertionError,
             )
             return type(arg)
-        elif isinstance(arg, torch.Tensor):
-            return (*get_tensor_descriptor(proxy_arg, arg), arg.dtype)
 
         raise ValueError(f"unrecognized type in arguments: {type(arg)}")
 
@@ -452,7 +448,7 @@ def __call__(self, *args):
         # Set device if set in one of the "factory" methods like full, iota, or uniform
         kwargs = {"device": fd._selected_device} if hasattr(fd, "_selected_device") else {}
 
-        with add_markers(self.name):
+        with annotate_for_profile(self.name):
             return fd.execute(args, **kwargs)
 
     def __repr__(self):
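
Note: below is a minimal, standalone sketch of the hot-path pattern this patch
adopts in to_descriptor -- dispatch on Tensor first and build the descriptor
inline rather than calling a get_tensor_descriptor helper per tensor. It is an
illustration only, not the Thunder implementation: it drops the proxy argument
and the contiguity information produced by compute_tensor_descriptor, and the
descriptor layout here is made up for the example.

    from numbers import Number

    import torch
    from torch import Tensor


    def to_descriptors(args) -> tuple:
        def to_descriptor(arg):
            # Check Tensor first: tensors dominate fusion arguments, so the
            # common case pays for only a single isinstance check.
            if isinstance(arg, Tensor):
                # Build the descriptor inline (no helper call per tensor).
                return (tuple(arg.shape), arg.stride(), arg.dtype)
            elif isinstance(arg, Number):
                return type(arg)
            elif isinstance(arg, tuple):
                return type(arg)
            raise ValueError(f"unrecognized type in arguments: {type(arg)}")

        return tuple(to_descriptor(arg) for arg in args)


    # Example: descriptors of this kind could serve as a cache key for a fusion call.
    t = torch.randn(2, 3)
    print(to_descriptors((t, 1.0, ())))
    # (((2, 3), (3, 1), torch.float32), <class 'float'>, <class 'tuple'>)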