Skip to content

Commit

Permalink
[thunder.dynamo] type hint and docs update (#1265)
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar authored Oct 7, 2024
1 parent 3250b60 commit fceb64e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
6 changes: 5 additions & 1 deletion thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Callable
from __future__ import annotations
from typing import TYPE_CHECKING

import torch
from torch.fx.passes.split_module import split_module
Expand All @@ -15,6 +16,9 @@
recompile_graph,
)

if TYPE_CHECKING:
from collections.abc import Callable


def _splitter(
gm: torch.fx.GraphModule,
Expand Down
48 changes: 29 additions & 19 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations
from collections.abc import Callable
from enum import Enum, auto
from typing import TYPE_CHECKING
import dataclasses
from collections.abc import Callable
import itertools
import inspect
import itertools

import torch

from thunder.torch.default_torch_ops import torch_auto_registered_ops
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.langctx import torchctx

if TYPE_CHECKING:
from thunder.core.symbol import Symbol

auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))


Expand Down Expand Up @@ -49,36 +54,38 @@ class SplitReasonType(Enum):

@dataclasses.dataclass(frozen=True)
class SplitReason:
"""
A dataclass containing information about a split.
"""A dataclass containing information about a split.
Attributes:
type (SplitReasonType): Reason for the split.
info (str): String with details of what caused the split.
exception (Exception | None): Exception if there was any.
reason_type: Reason for the split.
info: String with details of what caused the split.
exception: Exception if there was any.
"""

type: SplitReasonType
reason_type: SplitReasonType
info: str | None
exception: Exception | None = None


@dataclasses.dataclass(frozen=True)
class SubgraphInfo:
"""
A dataclass containing information about a subgraph.
"""A dataclass containing information about a subgraph.
Attributes:
original_graph_module (torch.fx.GraphModule): The original graph module.
split_graph_module (torch.fx.GraphModule): Optional. The graph module for the split subgraph.
thunder_compiled_fns (list[Callable]): List of thunder optimized callables. This could be None if there the graph module was not supported by thunder. Look at the `split_reasons` for further information.
compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split.
split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if there are was a split.
original_graph_module: The original graph module.
split_graph_module: The graph module for the split subgraph.
thunder_compiled_fns: List of thunder optimized callables.
This could be :obj:`None` if there the graph module was not supported by thunder.
Look at the :attr:`split_reasons` for further information.
submodule_to_compiled_functions: Dict from subgraph to compiled function.
This will be a dict with one pair in case the graph was not split.
split_reasons: List of reasons explaining why the subgraph was split.
Present only if there are was a split.
"""

original_graph_module: torch.fx.GraphModule
split_graph_module: torch.fx.GraphModule
thunder_compiled_fns: list[Callable]
split_graph_module: torch.fx.GraphModule | None
thunder_compiled_fns: list[Callable] | None
submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction]
split_reasons: list | None = None

Expand Down Expand Up @@ -143,7 +150,7 @@ def make_tensor_proxy(arg_node):
return proxy_args, proxy_kwargs


def try_execute_thunder_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
"""
Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.
Expand Down Expand Up @@ -303,7 +310,10 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason


def update_node_and_submodule(
graph_module: torch.fx.GraphModule, node: torch.fx.Node, new_name: str, new_callable: Callable
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
new_name: str,
new_callable: Callable,
):
"""
Updates the graph module and the node in place with a new name and a new callable as the target.
Expand Down

0 comments on commit fceb64e

Please sign in to comment.