diff --git a/thunder/__init__.py b/thunder/__init__.py
index c09c1cc9b5..5f9bd9f521 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -32,7 +32,7 @@ from thunder.core.transform_common import (
     dce,
     Transform,
-    wrap_return_value_together_with_argments,
+    wrap_return_value_together_with_arguments,
     unwrap_return_value,
     remove_context_manager_prims_from_trace,
 )
@@ -535,7 +535,7 @@ def get_computation_and_inputs(*args, **kwargs):
             prologue_traces = [prologue_trc]
             computation_traces = [computation_trc]
 
-            computation_trc = wrap_return_value_together_with_argments(computation_trc)
+            computation_trc = wrap_return_value_together_with_arguments(computation_trc)
             computation_traces.append(computation_trc)
 
             computation_trc = remove_context_manager_prims_from_trace(computation_trc)
diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py
index aac6f15026..86859e6428 100644
--- a/thunder/core/functionalization.py
+++ b/thunder/core/functionalization.py
@@ -41,7 +41,6 @@ def bsym_of_to_return_self(bsym: BoundSymbol):
 
 def check_inplace_to_views(computation_trace: Trace) -> dict[VariableInterface, TensorProxy]:
     """Error out if in-place op that outputs of different number of elements from the input and the input has other consumers."""
-    from thunder.core import utils
     import thunder.torch as ltorch
 
     producer_bsyms = producers(computation_trace)
@@ -58,7 +57,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
         return bsym.sym.tags and tag in bsym.sym.tags
 
     swap_map: dict[VariableInterface, TensorProxy] = {}
-    consumers = utils.consumers(computation_trace)
+    consumer_map = consumers(computation_trace)
     bsym: BoundSymbol
     for bsym in filter(lambda b: has_tag(b, prims.OpTags.IN_PLACE), computation_trace.bound_symbols):
         in_tensor: TensorProxy = list(filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))[0]
@@ -72,7 +71,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
             # assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones`
             continue
         orig_tensor = flat_tensor_proxy_args[0]
-        consumer_of_orig_tensor = consumers[orig_tensor]
+        consumer_of_orig_tensor = consumer_map[orig_tensor]
         # When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.
         # Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original
         # is an arg or a kwarg.
@@ -604,7 +603,7 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl
             if bsym in bsym_to_copy_bsyms:
                 functionalized_bsyms.extend(bsym_to_copy_bsyms[bsym])
                 copy_bsym = functionalized_bsyms[-1]
-                # wrap_return_value_together_with_argments places all the arguments in the return value
+                # wrap_return_value_together_with_arguments places all the arguments in the return value
                 # We swap these arguments in the return value with the outputs of copies onto them
                 # This prevents subsequent transforms from ordering the return statement before those copies
                 swap_map_for_return[variableify(copy_bsym.flat_proxy_args[0])] = copy_bsym.flat_proxy_outs[0]
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index e21ca4b28a..a21d1b207c 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -17,6 +17,7 @@ from thunder.core.utils import ProxyDict, producers, check
 
 if TYPE_CHECKING:
+    from numbers import Number
     from typing import Any
 
     from thunder.core.module import ThunderModule
@@ -456,7 +457,7 @@ def process_bound_symbols(src_bound_symbols, target_bound_symbols):
     return output
 
 
-def wrap_return_value_together_with_argments(trace: Trace) -> Trace:
+def wrap_return_value_together_with_arguments(trace: Trace) -> Trace:
     last = trace.bound_symbols[-1]
     assert last.sym.id == prims.PrimIDs.RETURN
     flat_args, _ = tree_flatten((trace.args, trace.kwargs))
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 2bf91372fe..2d9d88cddf 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -75,7 +75,7 @@ from thunder.core.transform_common import (
     dce,
     Transform,
-    wrap_return_value_together_with_argments,
+    wrap_return_value_together_with_arguments,
     unwrap_return_value,
     VJPDual,
 )
@@ -1493,7 +1493,7 @@ def python_callable(*args, **kwargs):
         grad(python_callable), *computation_trc.args, **computation_trc.kwargs
     )
 
-    gradtrc = wrap_return_value_together_with_argments(gradtrc)
+    gradtrc = wrap_return_value_together_with_arguments(gradtrc)
     gradtrc = dce(gradtrc)
     return prologue_trc, gradtrc, epilogue_trc
diff --git a/thunder/examine/__init__.py b/thunder/examine/__init__.py
index ebc197abae..fe1394e142 100644
--- a/thunder/examine/__init__.py
+++ b/thunder/examine/__init__.py
@@ -6,7 +6,7 @@ import thunder
 
 from thunder.core.trace import TraceCtx
 from thunder.core.transforms import bsym_list_to_dag, Node
-from thunder.core.proxies import TensorProxy
+from thunder.core.proxies import TensorProxy, CollectionProxy
 from thunder.core.symbol import BoundSymbol
 from thunder.torch import _torch_to_thunder_function_map
 from thunder.torch.default_torch_ops import torch_auto_registered_ops
@@ -276,17 +276,55 @@ def get_nvfuser_repro(trace: TraceCtx, fusion_name: str, /) -> str:
     return get_repro(fusion.last_inputs)
 
 
-def make_trace_dot(trace: TraceCtx):
+# Copied from `pytorchviz` which has MIT License (Thank you!)
+# See https://github.com/szagoruyko/pytorchviz/blob/0adcd83af8aa7ab36d6afd139cabbd9df598edb7/torchviz/dot.py#L180
+def resize_graph(dot, size_per_element=0.15, min_size=12):
+    """Resize the graph according to how much content it contains.
+
+    Modify the graph in place.
+ """ + # Get the approximate number of nodes and edges + num_rows = len(dot.body) + content_size = num_rows * size_per_element + size = max(min_size, content_size) + size_str = str(size) + "," + str(size) + dot.graph_attr.update(size=size_str) + + +def _repr_proxy(t_proxy, show_metadata=False): + if isinstance(t_proxy, TensorProxy): + # Should we just delegate to TensorProxy.__repr__ ? + extra_meta = f"\n shape:{t_proxy.shape} \n dtype:{t_proxy.dtype}" if show_metadata else "" + return f"name:{t_proxy.name}" + extra_meta + + # For any other proxy, we just print the name. + return f"name:{t_proxy.name}" + + +def make_trace_dot(trace: TraceCtx, show_metadata=False): """ Creates a directed graph of the given trace. This function is intended to be used to use graphviz to visualize the computation graph of a trace. Beware, rendering out a graph for large traces might take a while. + Roots nodes are colored "green", intermediates are colored "lightblue" and leaves are colored "orange" + Requires graphviz to be installed, for more information check out -> https://graphviz.readthedocs.io/en/stable/index.html + .. note:: + To improve the rendering time, one can update `nslimit` and `nslimit1` graph attributes on the returned Digraph before + calling the `render`. + Eg. dot_graph.graph_attr["nslimit"] = 5 + + Refer the following links for more details: + [1] https://graphviz.org/docs/attrs/nslimit/ + [2] https://graphviz.org/docs/attrs/nslimit1/ + + Args: trace (TraceCtx): The Thunder trace to be made into a graph. + show_metadata (bool): Add more meta-data (like shape, dtype) to the nodes representing the Tensor. Defaults to False. Returns: graphviz.Digraph: A graphviz directed graph. @@ -308,25 +346,57 @@ def make_trace_dot(trace: TraceCtx): roots, leaves = bsym_list_to_dag(trace.bound_symbols) leaves_id = {id(leaf) for leaf in leaves} + roots_id = {id(root) for root in roots} + + def _get_color(node_id): + if node_id in roots_id: + return "green" + if node_id in leaves_id: + return "orange" + return "lightblue" + + # All roots will be positioned at same level due to `rank`:`same`. + roots_g = graphviz.Digraph(graph_attr={"rank": "same"}) + for root in roots: + roots_g.node(str(id(root)), root.bsym.python(indent=0, print_depth=1)[0], fillcolor=_get_color(str(id(root)))) + dot.subgraph(roots_g) # Add roots_g to main graph. + stack = [*roots] visited = set() + + # Breadth first while stack: - node: Node = stack.pop() + node: Node = stack.pop(0) node_id = id(node) visited.add(node_id) - dot.node(str(node_id), node.bsym.sym.name, fillcolor="orange" if node_id in leaves_id else "white") - + color = _get_color(node_id) + + # Unpacking collection might be a multi-line. 
+        node_repr = "\n".join(node.bsym.python(indent=0, print_depth=1))
+        node_repr = node_repr.replace("\\", "")
+        dot.node(str(node_id), node_repr, fillcolor=color)
+
+        # Add nodes for the args and connect them to this node
+        for arg in node.bsym.flat_args:
+            # We also have collection proxies in backward traces
+            if isinstance(arg, (TensorProxy, CollectionProxy)):
+                arg_id = arg.name
+                dot.node(arg_id, _repr_proxy(arg, show_metadata))
+                dot.edge(arg_id, str(node_id))
+
+        # Connect outputs
+        for out in node.bsym.flat_outs:
+            # We also have collection proxies in backward traces
+            if isinstance(out, (TensorProxy, CollectionProxy)):
+                out_id = out.name
+                dot.edge(str(node_id), out_id)
+
+        # Add children for exploration
         for child in node.children:
             child_id = id(child)
-            out_proxy_name = node.bsym.output.name if isinstance(node.bsym.output, TensorProxy) else None
-            dot.edge(str(node_id), str(child_id), label=out_proxy_name)
             if child_id not in visited and not str(child.bsym).startswith("#"):
                 stack.append(child)
 
-        for parent in node.parents:
-            parent_id = id(parent)
-            dot.edge(str(parent_id), str(node_id))
-            if parent_id not in visited and not str(parent.bsym).startswith("#"):
-                stack.append(parent)
-
+    # Resize graph based on the number of nodes
+    resize_graph(dot)
     return dot
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index a13691103f..561f838a7d 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -12,6 +12,7 @@ from looseversion import LooseVersion
 
 import torch
+from torch import Tensor
 
 import thunder.core.dtypes as dtypes
 import thunder.torch as ltorch
@@ -39,7 +40,7 @@ import thunder.core.codeutils as codeutils
 from thunder.core.codeutils import Printable
 from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS
-from thunder.core.profile import add_markers
+from thunder.core.profile import annotate_for_profile
 from thunder.core.compile_data import get_compile_option
 
 from thunder.core.transforms import (
@@ -401,14 +402,11 @@ def compute_tensor_descriptor(
     return compute_symbolic_shape(proxy_shape, shape), *compute_contiguity(shape, stride)
 
 
-def get_tensor_descriptor(p: TensorProxy, t: torch.Tensor) -> tuple[tuple[int, ...], tuple[bool, ...], tuple[int, ...]]:
-    return compute_tensor_descriptor(p.shape, t.shape, t.stride())
-
-
-# TODO Inline the get_tensor_descriptor call
 def to_descriptors(proxy_args, args) -> tuple:
     def to_descriptor(proxy_arg, arg):
-        if isinstance(arg, Number):
+        if isinstance(arg, Tensor):
+            return (*compute_tensor_descriptor(proxy_arg._shape, arg.shape, arg.stride()), arg.dtype)
+        elif isinstance(arg, Number):
             return type(arg)
         elif isinstance(arg, tuple):
             if len(arg) != 0:
@@ -419,8 +417,6 @@ def to_descriptor(proxy_arg, arg):
                 exception_type=AssertionError,
             )
             return type(arg)
-        elif isinstance(arg, torch.Tensor):
-            return (*get_tensor_descriptor(proxy_arg, arg), arg.dtype)
 
         raise ValueError(f"unrecognized type in arguments: {type(arg)}")
@@ -452,7 +448,7 @@ def __call__(self, *args):
         # Set device if set in one of the "factory" methods like full, iota, or uniform
         kwargs = {"device": fd._selected_device} if hasattr(fd, "_selected_device") else {}
 
-        with add_markers(self.name):
+        with annotate_for_profile(self.name):
             return fd.execute(args, **kwargs)
 
     def __repr__(self):
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 5ac409473e..89c9b87272 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1148,7 +1148,7 @@ def test_forward_and_backward_from_trace(executor, device, _):
     from thunder.clang import cos, sin
     import thunder.torch as ltorch
     from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad
-    from thunder.core.transform_common import wrap_return_value_together_with_argments
+    from thunder.core.transform_common import wrap_return_value_together_with_arguments
 
     def func(a, b, *, c):
         d = a + b + c
@@ -1159,7 +1159,7 @@ def func(a, b, *, c):
     b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
     c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True)
     initial_trace = trace(inline_trace=False)(func, a, b, c=c)
-    wrapped_trace = wrap_return_value_together_with_argments(initial_trace)
+    wrapped_trace = wrap_return_value_together_with_arguments(initial_trace)
     fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace)
     fw = executor.make_callable(fw_trace)
     bw = executor.make_callable(bw_trace)
diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py
index c42c183b7e..32a3eee7cb 100644
--- a/thunder/tests/test_nvfuser.py
+++ b/thunder/tests/test_nvfuser.py
@@ -53,7 +53,7 @@ def test_rematerialization_with_forward_and_backward_from_trace(executor: TestEx
     from thunder.clang import cos, sin
     import thunder.torch as ltorch
     from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad
-    from thunder.core.transform_common import wrap_return_value_together_with_argments
+    from thunder.core.transform_common import wrap_return_value_together_with_arguments
     from thunder.common import transform_for_execution
     from thunder.core.rematerialization import rematerialize_forward_and_backward
 
@@ -74,7 +74,7 @@ def func(a, b, *, c):
         requires_grad=True,
    )
     trace = trace(inline_trace=False)(func, a, b, c=c)
-    trace = wrap_return_value_together_with_argments(trace)
+    trace = wrap_return_value_together_with_arguments(trace)
     fw_trace, bw_trace = forward_and_backward_from_trace(trace)
     fw_extraces = transform_for_execution(fw_trace, executors_list=executor.executors_list())
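
A minimal usage sketch of the extended `make_trace_dot` (illustrative only, not part of the patch: the toy function, tensor sizes, and output filename are invented, and it assumes the `graphviz` Python package and the Graphviz binaries are installed):

    import torch
    import thunder
    from thunder.examine import make_trace_dot

    def fn(a, b):
        return torch.nn.functional.relu(a @ b)

    jfn = thunder.jit(fn)
    jfn(torch.randn(4, 4), torch.randn(4, 4))

    # Visualize the final computation trace; show_metadata adds shape/dtype to tensor nodes.
    trc = thunder.last_traces(jfn)[-1]
    dot = make_trace_dot(trc, show_metadata=True)

    # Optional: cap Graphviz's network-simplex iterations to speed up layout of
    # large graphs, per the nslimit/nslimit1 note in the docstring.
    dot.graph_attr["nslimit"] = 5
    dot.graph_attr["nslimit1"] = 5

    dot.render("trace_graph", format="svg")  # writes trace_graph.svg

Note that `resize_graph` has already scaled the returned Digraph: following the copied pytorchviz helper, its `size` attribute becomes max(12, 0.15 * len(dot.body)) in each dimension, so denser traces render onto a proportionally larger canvas.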
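
The `add_markers` to `annotate_for_profile` rename keeps the call-site shape: it is still a context manager wrapping the region to be profiled, as in `FusionDefinitionWrapper.__call__` above. A hedged sketch of that pattern (the region name and `run_fusion` are hypothetical stand-ins; only the context-manager usage is confirmed by this diff):

    from thunder.core.profile import annotate_for_profile

    # The label shows up under this name in profiler timelines.
    with annotate_for_profile("nvFusion0"):
        run_fusion()  # hypothetical stand-in for fd.execute(args, **kwargs)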