Merge branch 'main' into check_vjp_kwargs
beverlylytle committed Nov 14, 2024
2 parents 65ffef8 + 9943778 commit ef0211e
Showing 8 changed files with 102 additions and 36 deletions.
4 changes: 2 additions & 2 deletions thunder/__init__.py
@@ -32,7 +32,7 @@
from thunder.core.transform_common import (
dce,
Transform,
wrap_return_value_together_with_argments,
wrap_return_value_together_with_arguments,
unwrap_return_value,
remove_context_manager_prims_from_trace,
)
@@ -535,7 +535,7 @@ def get_computation_and_inputs(*args, **kwargs):
prologue_traces = [prologue_trc]
computation_traces = [computation_trc]

computation_trc = wrap_return_value_together_with_argments(computation_trc)
computation_trc = wrap_return_value_together_with_arguments(computation_trc)
computation_traces.append(computation_trc)

computation_trc = remove_context_manager_prims_from_trace(computation_trc)
7 changes: 3 additions & 4 deletions thunder/core/functionalization.py
@@ -41,7 +41,6 @@ def bsym_of_to_return_self(bsym: BoundSymbol):

def check_inplace_to_views(computation_trace: Trace) -> dict[VariableInterface, TensorProxy]:
"""Error out if in-place op that outputs of different number of elements from the input and the input has other consumers."""
from thunder.core import utils
import thunder.torch as ltorch

producer_bsyms = producers(computation_trace)
@@ -58,7 +57,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
return bsym.sym.tags and tag in bsym.sym.tags

swap_map: dict[VariableInterface, TensorProxy] = {}
consumers = utils.consumers(computation_trace)
consumer_map = consumers(computation_trace)
bsym: BoundSymbol
for bsym in filter(lambda b: has_tag(b, prims.OpTags.IN_PLACE), computation_trace.bound_symbols):
in_tensor: TensorProxy = list(filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))[0]
@@ -72,7 +71,7 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
# assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones`
continue
orig_tensor = flat_tensor_proxy_args[0]
consumer_of_orig_tensor = consumers[orig_tensor]
consumer_of_orig_tensor = consumer_map[orig_tensor]
# When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.
# Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original
# is an arg or a kwarg.
@@ -604,7 +603,7 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl
if bsym in bsym_to_copy_bsyms:
functionalized_bsyms.extend(bsym_to_copy_bsyms[bsym])
copy_bsym = functionalized_bsyms[-1]
# wrap_return_value_together_with_argments places all the arguments in the return value
# wrap_return_value_together_with_arguments places all the arguments in the return value
# We swap these arguments in the return value with the outputs of copies onto them
# This prevents subsequent transforms from ordering the return statement before those copies
swap_map_for_return[variableify(copy_bsym.flat_proxy_args[0])] = copy_bsym.flat_proxy_outs[0]
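The `consumer_map` rename above is not just cosmetic: binding the result of `consumers(...)` to a differently named variable lets the module import `consumers` directly instead of reaching through `utils.`, and it sidesteps Python's name-shadowing trap. A minimal, self-contained illustration of that pitfall (not Thunder's code; `broken` and `fixed` are hypothetical names):

```python
# Minimal illustration of the shadowing pitfall avoided by the rename.
def consumers(trace):  # stand-in for thunder.core.utils.consumers
    return {"t0": ["add_bsym"]}

def broken(trace):
    # Assigning to `consumers` makes it a local for the whole function body,
    # so this call raises UnboundLocalError before the assignment completes.
    consumers = consumers(trace)
    return consumers

def fixed(trace):
    consumer_map = consumers(trace)  # the imported helper's name stays visible
    return consumer_map

fixed(None)  # works; broken(None) would raise UnboundLocalError
```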
3 changes: 2 additions & 1 deletion thunder/core/transform_common.py
@@ -17,6 +17,7 @@
from thunder.core.utils import ProxyDict, producers, check

if TYPE_CHECKING:
from numbers import Number
from typing import Any
from thunder.core.module import ThunderModule

@@ -456,7 +457,7 @@ def process_bound_symbols(src_bound_symbols, target_bound_symbols):
return output


def wrap_return_value_together_with_argments(trace: Trace) -> Trace:
def wrap_return_value_together_with_arguments(trace: Trace) -> Trace:
last = trace.bound_symbols[-1]
assert last.sym.id == prims.PrimIDs.RETURN
flat_args, _ = tree_flatten((trace.args, trace.kwargs))
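The renamed `wrap_return_value_together_with_arguments` (used in `thunder/__init__.py` above and `thunder/core/transforms.py` below) rewrites a trace's RETURN symbol so the flattened inputs travel with the output, keeping them reachable for later passes such as functionalization. A rough stand-in on a plain function rather than a Thunder trace (`flatten_and_wrap` is a hypothetical name; only the idea is taken from the diff):

```python
# Sketch of the wrapping idea on an ordinary function, not Thunder's transform.
from torch.utils._pytree import tree_flatten

def flatten_and_wrap(fn):
    def wrapped(*args, **kwargs):
        output = fn(*args, **kwargs)
        # Mirrors `flat_args, _ = tree_flatten((trace.args, trace.kwargs))` above.
        flat_args, _ = tree_flatten((args, kwargs))
        # The real transform rewrites the trace's RETURN bound symbol instead
        # of changing the callable's runtime return value.
        return {"output": output, "flat_args": flat_args}
    return wrapped
```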
4 changes: 2 additions & 2 deletions thunder/core/transforms.py
@@ -75,7 +75,7 @@
from thunder.core.transform_common import (
dce,
Transform,
wrap_return_value_together_with_argments,
wrap_return_value_together_with_arguments,
unwrap_return_value,
VJPDual,
)
@@ -1493,7 +1493,7 @@ def python_callable(*args, **kwargs):
grad(python_callable), *computation_trc.args, **computation_trc.kwargs
)

gradtrc = wrap_return_value_together_with_argments(gradtrc)
gradtrc = wrap_return_value_together_with_arguments(gradtrc)
gradtrc = dce(gradtrc)
return prologue_trc, gradtrc, epilogue_trc

96 changes: 83 additions & 13 deletions thunder/examine/__init__.py
@@ -6,7 +6,7 @@
import thunder
from thunder.core.trace import TraceCtx
from thunder.core.transforms import bsym_list_to_dag, Node
from thunder.core.proxies import TensorProxy
from thunder.core.proxies import TensorProxy, CollectionProxy
from thunder.core.symbol import BoundSymbol
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.default_torch_ops import torch_auto_registered_ops
@@ -276,17 +276,55 @@ def get_nvfuser_repro(trace: TraceCtx, fusion_name: str, /) -> str:
return get_repro(fusion.last_inputs)


def make_trace_dot(trace: TraceCtx):
# Copied from `pytorchviz` which has MIT License (Thank you!)
# See https://github.com/szagoruyko/pytorchviz/blob/0adcd83af8aa7ab36d6afd139cabbd9df598edb7/torchviz/dot.py#L180
def resize_graph(dot, size_per_element=0.15, min_size=12):
"""Resize the graph according to how much content it contains.
Modify the graph in place.
"""
# Get the approximate number of nodes and edges
num_rows = len(dot.body)
content_size = num_rows * size_per_element
size = max(min_size, content_size)
size_str = str(size) + "," + str(size)
dot.graph_attr.update(size=size_str)


def _repr_proxy(t_proxy, show_metadata=False):
if isinstance(t_proxy, TensorProxy):
# Should we just delegate to TensorProxy.__repr__ ?
extra_meta = f"\n shape:{t_proxy.shape} \n dtype:{t_proxy.dtype}" if show_metadata else ""
return f"name:{t_proxy.name}" + extra_meta

# For any other proxy, we just print the name.
return f"name:{t_proxy.name}"


def make_trace_dot(trace: TraceCtx, show_metadata=False):
"""
Creates a directed graph of the given trace.
This function uses graphviz to visualize the computation graph of a trace.
Beware, rendering a graph for large traces might take a while.
Root nodes are colored "green", intermediates "lightblue", and leaves "orange".
Requires graphviz to be installed; for more information see https://graphviz.readthedocs.io/en/stable/index.html
.. note::
To improve rendering time, update the `nslimit` and `nslimit1` graph attributes on the returned Digraph before
calling `render`, e.g. dot_graph.graph_attr["nslimit"] = 5.
Refer to the following links for more details:
[1] https://graphviz.org/docs/attrs/nslimit/
[2] https://graphviz.org/docs/attrs/nslimit1/
Args:
trace (TraceCtx): The Thunder trace to be made into a graph.
show_metadata (bool): Add more metadata (like shape and dtype) to the nodes representing tensors. Defaults to False.
Returns:
graphviz.Digraph: A graphviz directed graph.
@@ -308,25 +346,57 @@ def make_trace_dot(trace: TraceCtx):

roots, leaves = bsym_list_to_dag(trace.bound_symbols)
leaves_id = {id(leaf) for leaf in leaves}
roots_id = {id(root) for root in roots}

def _get_color(node_id):
if node_id in roots_id:
return "green"
if node_id in leaves_id:
return "orange"
return "lightblue"

# All roots will be positioned at the same level due to `rank`:`same`.
roots_g = graphviz.Digraph(graph_attr={"rank": "same"})
for root in roots:
roots_g.node(str(id(root)), root.bsym.python(indent=0, print_depth=1)[0], fillcolor=_get_color(str(id(root))))
dot.subgraph(roots_g) # Add roots_g to main graph.

stack = [*roots]
visited = set()

# Breadth first
while stack:
node: Node = stack.pop()
node: Node = stack.pop(0)
node_id = id(node)
visited.add(node_id)
dot.node(str(node_id), node.bsym.sym.name, fillcolor="orange" if node_id in leaves_id else "white")

color = _get_color(node_id)

# Unpacking a collection may span multiple lines.
node_repr = "\n".join(node.bsym.python(indent=0, print_depth=1))
node_repr = node_repr.replace("\\", "")
dot.node(str(node_id), node_repr, fillcolor=color)

# Add node for args and connect args
for arg in node.bsym.flat_args:
# We have collection proxies in backward
if isinstance(arg, (TensorProxy, CollectionProxy)):
arg_id = arg.name
dot.node(arg_id, _repr_proxy(arg, show_metadata))
dot.edge(arg_id, str(node_id))

# Connect outputs
for out in node.bsym.flat_outs:
# We have collection proxies in backward
if isinstance(out, (TensorProxy, CollectionProxy)):
out_id = out.name
dot.edge(str(node_id), out_id)

# Add children for exploration
for child in node.children:
child_id = id(child)
out_proxy_name = node.bsym.output.name if isinstance(node.bsym.output, TensorProxy) else None
dot.edge(str(node_id), str(child_id), label=out_proxy_name)
if child_id not in visited and not str(child.bsym).startswith("#"):
stack.append(child)

for parent in node.parents:
parent_id = id(parent)
dot.edge(str(parent_id), str(node_id))
if parent_id not in visited and not str(parent.bsym).startswith("#"):
stack.append(parent)

# Resize graph based on number of nodes
resize_graph(dot)
return dot
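An end-to-end sketch of exercising the extended `make_trace_dot` (illustrative only: it assumes `thunder.jit` and `thunder.last_traces` as entry points, and graphviz installed). Note also that switching `stack.pop()` to `stack.pop(0)` makes the traversal genuinely breadth-first; a `collections.deque` with `popleft()` would avoid the O(n) list dequeue, though trace graphs are rarely large enough for that to matter.

```python
# Illustrative usage sketch; writes trace_graph.svg to the working directory.
import torch
import thunder
from thunder.examine import make_trace_dot

def f(x, y):
    return torch.nn.functional.relu(x @ y)

jf = thunder.jit(f)
jf(torch.randn(4, 4), torch.randn(4, 4))

trc = thunder.last_traces(jf)[-1]      # final computation trace
dot = make_trace_dot(trc, show_metadata=True)
dot.graph_attr["nslimit"] = "5"        # cap layout passes, per the docstring note
dot.render("trace_graph", format="svg")
```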
16 changes: 6 additions & 10 deletions thunder/executors/nvfuserex_impl.py
@@ -12,6 +12,7 @@

from looseversion import LooseVersion
import torch
from torch import Tensor

import thunder.core.dtypes as dtypes
import thunder.torch as ltorch
@@ -39,7 +40,7 @@
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS
from thunder.core.profile import add_markers
from thunder.core.profile import annotate_for_profile
from thunder.core.compile_data import get_compile_option

from thunder.core.transforms import (
@@ -401,14 +402,11 @@ def compute_tensor_descriptor(
return compute_symbolic_shape(proxy_shape, shape), *compute_contiguity(shape, stride)


def get_tensor_descriptor(p: TensorProxy, t: torch.Tensor) -> tuple[tuple[int, ...], tuple[bool, ...], tuple[int, ...]]:
return compute_tensor_descriptor(p.shape, t.shape, t.stride())


# TODO Inline the get_tensor_descriptor call
def to_descriptors(proxy_args, args) -> tuple:
def to_descriptor(proxy_arg, arg):
if isinstance(arg, Number):
if isinstance(arg, Tensor):
return (*compute_tensor_descriptor(proxy_arg._shape, arg.shape, arg.stride()), arg.dtype)
elif isinstance(arg, Number):
return type(arg)
elif isinstance(arg, tuple):
if len(arg) != 0:
@@ -419,8 +417,6 @@ def to_descriptor(proxy_arg, arg):
exception_type=AssertionError,
)
return type(arg)
elif isinstance(arg, torch.Tensor):
return (*get_tensor_descriptor(proxy_arg, arg), arg.dtype)

raise ValueError(f"unrecognized type in arguments: {type(arg)}")

@@ -452,7 +448,7 @@ def __call__(self, *args):

# Set device if set in one of the "factory" methods like full, iota, or uniform
kwargs = {"device": fd._selected_device} if hasattr(fd, "_selected_device") else {}
with add_markers(self.name):
with annotate_for_profile(self.name):
return fd.execute(args, **kwargs)

def __repr__(self):
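This hunk also picks up the `add_markers` → `annotate_for_profile` rename from `thunder.core.profile`. One plausible shape for such a helper, sketched on `torch.profiler.record_function` (an assumption for illustration; Thunder's actual implementation may differ):

```python
# Hedged sketch of an annotate_for_profile-style context manager.
from contextlib import contextmanager
from torch.profiler import record_function

@contextmanager
def annotate_for_profile(name: str):
    # The named range appears in torch.profiler / Chrome trace output,
    # making each fusion's execution easy to spot.
    with record_function(name):
        yield
```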
4 changes: 2 additions & 2 deletions thunder/tests/test_grad.py
@@ -1148,7 +1148,7 @@ def test_forward_and_backward_from_trace(executor, device, _):
from thunder.clang import cos, sin
import thunder.torch as ltorch
from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad
from thunder.core.transform_common import wrap_return_value_together_with_argments
from thunder.core.transform_common import wrap_return_value_together_with_arguments

def func(a, b, *, c):
d = a + b + c
@@ -1159,7 +1159,7 @@ def func(a, b, *, c):
b = make_tensor((2, 3), device=device, dtype=torch.float64, requires_grad=True)
c = make_tensor((3,), device=device, dtype=torch.float64, requires_grad=True)
initial_trace = trace(inline_trace=False)(func, a, b, c=c)
wrapped_trace = wrap_return_value_together_with_argments(initial_trace)
wrapped_trace = wrap_return_value_together_with_arguments(initial_trace)
fw_trace, bw_trace = forward_and_backward_from_trace(wrapped_trace)
fw = executor.make_callable(fw_trace)
bw = executor.make_callable(bw_trace)
4 changes: 2 additions & 2 deletions thunder/tests/test_nvfuser.py
@@ -53,7 +53,7 @@ def test_rematerialization_with_forward_and_backward_from_trace(executor: TestEx
from thunder.clang import cos, sin
import thunder.torch as ltorch
from thunder.core.transforms import forward_and_backward_from_trace, value_and_grad
from thunder.core.transform_common import wrap_return_value_together_with_argments
from thunder.core.transform_common import wrap_return_value_together_with_arguments
from thunder.common import transform_for_execution
from thunder.core.rematerialization import rematerialize_forward_and_backward

@@ -74,7 +74,7 @@ def func(a, b, *, c):
requires_grad=True,
)
trace = trace(inline_trace=False)(func, a, b, c=c)
trace = wrap_return_value_together_with_argments(trace)
trace = wrap_return_value_together_with_arguments(trace)
fw_trace, bw_trace = forward_and_backward_from_trace(trace)

fw_extraces = transform_for_execution(fw_trace, executors_list=executor.executors_list())
