Skip to content

Commit

Permalink
Reduce CPU overhead of FusionDefinitionWrapper (#1416)
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk authored Nov 13, 2024
1 parent 8f5026f commit ac981eb
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from looseversion import LooseVersion
import torch
from torch import Tensor

import thunder.core.dtypes as dtypes
import thunder.torch as ltorch
Expand Down Expand Up @@ -39,7 +40,7 @@
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS
from thunder.core.profile import add_markers
from thunder.core.profile import annotate_for_profile
from thunder.core.compile_data import get_compile_option

from thunder.core.transforms import (
Expand Down Expand Up @@ -401,14 +402,11 @@ def compute_tensor_descriptor(
return compute_symbolic_shape(proxy_shape, shape), *compute_contiguity(shape, stride)


def get_tensor_descriptor(p: TensorProxy, t: torch.Tensor) -> tuple[tuple[int, ...], tuple[bool, ...], tuple[int, ...]]:
    """Build the cache descriptor for tensor *t* as seen through proxy *p*.

    Delegates to ``compute_tensor_descriptor`` with the proxy's (possibly
    symbolic) shape plus the concrete runtime shape and strides of *t*.
    Returns the (symbolic-shape, contiguity, stride-order) triple that
    ``compute_tensor_descriptor`` produces.
    """
    proxy_shape = p.shape
    runtime_shape = t.shape
    runtime_strides = t.stride()
    return compute_tensor_descriptor(proxy_shape, runtime_shape, runtime_strides)


# TODO Inline the get_tensor_descriptor call
def to_descriptors(proxy_args, args) -> tuple:
def to_descriptor(proxy_arg, arg):
if isinstance(arg, Number):
if isinstance(arg, Tensor):
return (*compute_tensor_descriptor(proxy_arg._shape, arg.shape, arg.stride()), arg.dtype)
elif isinstance(arg, Number):
return type(arg)
elif isinstance(arg, tuple):
if len(arg) != 0:
Expand All @@ -419,8 +417,6 @@ def to_descriptor(proxy_arg, arg):
exception_type=AssertionError,
)
return type(arg)
elif isinstance(arg, torch.Tensor):
return (*get_tensor_descriptor(proxy_arg, arg), arg.dtype)

raise ValueError(f"unrecognized type in arguments: {type(arg)}")

Expand Down Expand Up @@ -452,7 +448,7 @@ def __call__(self, *args):

# Set device if set in one of the "factory" methods like full, iota, or uniform
kwargs = {"device": fd._selected_device} if hasattr(fd, "_selected_device") else {}
with add_markers(self.name):
with annotate_for_profile(self.name):
return fd.execute(args, **kwargs)

def __repr__(self):
Expand Down

0 comments on commit ac981eb

Please sign in to comment.