Commit edda3d9

examples: Add example usage script for module-acc
- Add detailed tutorial for excluding modules or functions from tracing in Dynamo and writing custom converters for those excluded modules
1 parent 2684dd9 commit edda3d9

3 files changed: +164 −0

docsrc/index.rst (+1)

@@ -73,6 +73,7 @@ Tutorials
 tutorials/_rendered_examples/dynamo/dynamo_compile_resnet_example
 tutorials/_rendered_examples/dynamo/dynamo_compile_transformers_example
 tutorials/_rendered_examples/dynamo/dynamo_compile_advanced_usage
+tutorials/_rendered_examples/dynamo/dynamo_module_level_acceleration

 Python API Documentation
 ------------------------

examples/dynamo/README.rst (+1)

@@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`dynamo_compile_resnet`: Compiling a ResNet model using the Dynamo Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`dynamo_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
+* :ref:`dynamo_module_level_acceleration`: Accelerating a specific ``torch.nn.Module`` or function by excluding it from decomposition
New file (+162 lines):
"""
.. _dynamo_module_level_acceleration:

Dynamo Module Level Acceleration Tutorial
=========================================

This interactive script is intended as an overview of the process by which module-level
acceleration for `torch_tensorrt.dynamo.compile` works, and how it can be used to accelerate
built-in or custom `torch.nn` modules by excluding them from AOT tracing. This script shows
the process for `torch.nn.MaxPool1d`.
"""

# %%
# 1. The Placeholder
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and output
# types. The namespace, such as tensorrt, will cause the op to be registered as
# `torch.ops.tensorrt.your_op`. Then, create a placeholder function with no operations, but
# having the same schema and naming as that used in the decorator.

# %%

from torch._custom_op.impl import custom_op


@custom_op(
    qualname="tensorrt::maxpool1d",
    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
)
def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
    # Defines operator schema, name, namespace, and function header
    ...
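

# %%
# As a quick sanity check, the registration above should make the operator resolvable
# through the ``torch.ops`` namespace, even though it has no implementation yet:

import torch

print(torch.ops.tensorrt.maxpool1d)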


# %%
# 2. The Generic Implementation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define the default implementation of the operator in torch syntax. This is used for autograd
# and other tracing functionality. Generally, the `torch.nn.functional` analog of the operator to
# replace is desirable. If the operator to replace is a custom module you've written, then add its
# Torch implementation here. Note that the function header of the generic function can have
# specific arguments, as in the placeholder above.

# %%
import torch


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
@maxpool1d.impl_abstract()
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )
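

# %%
# As an illustrative check, assuming the eager implementations registered above dispatch
# for the custom op, it can now be called directly with arguments in the order declared
# by the schema from step 1:

x = torch.randn(1, 3, 16)
out = torch.ops.tensorrt.maxpool1d(x, [2], [2], [0], [1], False)
# With a length-16 input, kernel size 2, and stride 2, the expected output shape is (1, 3, 8)
print(out.shape)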


# %%
# 3. The Module Substitution Function
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example `torch.ops.tensorrt.maxpool1d`).
# It should refactor the args and kwargs as needed by the accelerated implementation.

# %%

from torch_tensorrt.dynamo.backend.lowering import register_substitution


@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    submodule: torch.nn.Module,
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


# %%
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write::
#
#     weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
#
#     ...
#
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#             }

# %%
# 4. The Accelerated Implementation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Define an accelerated implementation of the operator, and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.

# %%

from typing import Dict, Tuple
from torch.fx.node import Argument, Target
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def tensorrt_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines converter replacing the default operator for this function
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": False if len(args) < 6 else args[5],
    }

    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, kwargs_new, name
    )
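

# %%
# As an optional check, assuming the FX converter registry exposes its ``CONVERTERS``
# mapping as below, one can confirm that the converter decorated above is now registered
# for the custom operator:

from torch_tensorrt.fx.converter_registry import CONVERTERS

print(torch.ops.tensorrt.maxpool1d.default in CONVERTERS)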


# %%
# 5. Add Imports
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Add your accelerated module file to the `__init__.py` in the
# `py/torch_tensorrt/dynamo/backend/lowering/substitutions` directory, to ensure
# all registrations are run. For instance, if the new module file is called `new_mod.py`,
# one should add `from .new_mod import *` to the `__init__.py`.
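
# %%
# As a rough end-to-end sketch, once the registrations above are importable, a model
# containing ``torch.nn.MaxPool1d`` could be compiled through the Dynamo path so that the
# custom converter handles that submodule. The exact ``torch_tensorrt.dynamo.compile``
# arguments below are assumptions and may differ between versions, so treat this as
# illustrative only::
#
#     import torch
#     import torch_tensorrt
#
#     class SampleModel(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.pool = torch.nn.MaxPool1d(kernel_size=2, stride=2)
#
#         def forward(self, x):
#             return self.pool(torch.relu(x))
#
#     model = SampleModel().eval().cuda()
#     inputs = [torch.randn(1, 3, 16).cuda()]
#
#     trt_model = torch_tensorrt.dynamo.compile(model, inputs=inputs)
#     print(trt_model(*inputs).shape)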
