Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Initial implementation for automatic plugin #3301

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

bowang007
Copy link
Collaborator

Description

This PR implements the automatic plugin feature.

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: build system Issues re: Build system component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Nov 22, 2024
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py	2024-11-22 01:20:58.215888+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py	2024-11-22 01:21:18.909129+00:00
@@ -1,7 +1,8 @@
import triton
import triton.language as tl
+

@triton.jit
def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
    # Program ID determines the block of data each thread will process
    pid = tl.program_id(0)
@@ -25,23 +26,23 @@
@custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."
-    
+
    # Create output tensor
    Z = torch.empty_like(X)
-    
+
    # Define block size
    BLOCK_SIZE = 1024
-    
+
    # Grid of programs
-    grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],)
-    
+    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
+
    # Launch the kernel
    elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
-    
+
    return Z


# Using the module in PyTorch
# X = torch.randn(1024, device='cuda', requires_grad=True)
@@ -72,22 +73,31 @@

        return res


my_model = MyModel().to("cuda")
-m = torch.full((64, 64), 2, device='cuda',)
-n = torch.full((64, 64), 3, device='cuda',)
+m = torch.full(
+    (64, 64),
+    2,
+    device="cuda",
+)
+n = torch.full(
+    (64, 64),
+    3,
+    device="cuda",
+)
# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
# print(my_model.forward(m, n))


@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x

+
import torch_tensorrt as torchtrt


with torchtrt.logging.info():
    model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
    res = model_trt(m, n)
-    print(res)
\ No newline at end of file
+    print(res)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py	2024-11-22 01:21:19.453080+00:00
@@ -1,6 +1,11 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters
+from . import (
+    aten_ops_converters,
+    ops_evaluators,
+    prims_ops_converters,
+    plugin_ops_converters,
+)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import *  # noqa: F403
from ._TRTInterpreter import *  # noqa: F403
from .truncate_double import repair_double_inputs
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py	2024-11-22 01:21:20.202267+00:00
@@ -1 +1 @@
-from .plugin_generator import PluginCreator
\ No newline at end of file
+from .plugin_generator import PluginCreator
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py	2024-11-22 01:21:20.284627+00:00
@@ -17,25 +17,28 @@

logger = logging.getLogger(__name__)

TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()

+
@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
def torchtrt_ex_elementwise_add(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
-): 
+):
    # logger.debug(f"plugin stuff here2")
    # return torch.add(args)
-    
+
    # How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
-    plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={})
-    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")    
-    
+    plugin_creator = PluginCreator(
+        "elementwise_add_plugin", plugin_namespace="", attrs={}
+    )
+    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
+
    plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator(
        type="elementwise_add_plugin", version="1", plugin_namespace=""
    )
    assert plugin_creator, f"Unable to find elementwise_add_plugin creator"

@@ -44,45 +47,47 @@
    # plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)
    # assert plugin, "Unable to create <PLUGIN_NAME>"

    # <GENERATE LINK BETWEEN PLUGIN AND INPUTS>
    #    <GET INPUTS INTO LIST>
-    #    <PASS TO PLUGIN>     
-    
+    #    <PASS TO PLUGIN>
+
    # return layer.get_output(0)
    field_configs = trt.PluginFieldCollection([])
-    
-    plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs)
+
+    plugin = plugin_creator.create_plugin(
+        name="elementwise_add_plugin", field_collection=field_configs
+    )
    assert plugin, "Unable to create CircularPaddingPlugin"
-    
+
    # input_tensor = args[
    #     0
    # ]  # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor
    # if not isinstance(input_tensor, trt.ITensor):
    #     # Freeze input tensor if not TensorRT Tensor already
    #     input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input")
-    
+
    lhs_dtype = None
    rhs_dtype = None
    lhs_val = args[0]
    rhs_val = args[1]
-    
+
    if isinstance(lhs_val, TRTTensor):
        lhs_dtype = lhs_val.dtype
        # is_lhs_trt_tensor = True
    if isinstance(rhs_val, TRTTensor):
        rhs_dtype = rhs_val.dtype
        # is_rhs_trt_tensor = True
-        
+
    print(lhs_dtype)
-    
+
    lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
    rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)

    layer = ctx.net.add_plugin_v3(
        [lhs_val, rhs_val], [], plugin
    )  # Add the plugin to the network being constructed
    # layer.name = f"automatic-{name}"
    return layer.get_output(0)


-# 1. generate plugin for any pytorch op
\ No newline at end of file
+# 1. generate plugin for any pytorch op
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py	2024-11-22 01:20:58.227888+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py	2024-11-22 01:21:20.380983+00:00
@@ -11,64 +11,63 @@


logger = logging.getLogger("CustomPlugin")

_numpy_to_plugin_field_type = {
-    np.dtype('int32'): trt.PluginFieldType.INT32,
-    np.dtype('int16'): trt.PluginFieldType.INT16,
-    np.dtype('int8'): trt.PluginFieldType.INT8,
-    np.dtype('bool'): trt.PluginFieldType.INT8,
-    np.dtype('int64'): trt.PluginFieldType.INT64,
-    np.dtype('float32'): trt.PluginFieldType.FLOAT32,
-    np.dtype('float64'): trt.PluginFieldType.FLOAT64,
-    np.dtype('float16'): trt.PluginFieldType.FLOAT16
+    np.dtype("int32"): trt.PluginFieldType.INT32,
+    np.dtype("int16"): trt.PluginFieldType.INT16,
+    np.dtype("int8"): trt.PluginFieldType.INT8,
+    np.dtype("bool"): trt.PluginFieldType.INT8,
+    np.dtype("int64"): trt.PluginFieldType.INT64,
+    np.dtype("float32"): trt.PluginFieldType.FLOAT32,
+    np.dtype("float64"): trt.PluginFieldType.FLOAT64,
+    np.dtype("float16"): trt.PluginFieldType.FLOAT16,
}

_built_in_to_plugin_field_type = {
    int: trt.PluginFieldType.INT64,
    float: trt.PluginFieldType.FLOAT64,
    bool: trt.PluginFieldType.INT8,
    # str is handled separately, so not needed here
}

+
class Tactic(IntEnum):
    TORCH = 1
    TRITON = 2

+
class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):  # type: ignore[misc]
-    def __init__(
-        self, plugin_name : str, attrs, phase = None
-    ):
+    def __init__(self, plugin_name: str, attrs, phase=None):
        # TODO: needs an additional passed in arguments to specify the needs for each plugin
        # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
        trt.IPluginV3.__init__(self)
        # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime
        trt.IPluginV3OneCore.__init__(self)
        # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
        trt.IPluginV3OneBuild.__init__(self)
        # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable
-        trt.IPluginV3OneRuntime.__init__(self)       
-        
+        trt.IPluginV3OneRuntime.__init__(self)
+
        # <ANY NON TENSOR INPUTS SHOULD BE AN ATTRIBUTE OF THE PLUGIN>
-        # setattr(<name of input>, <default value for that type>) 
+        # setattr(<name of input>, <default value for that type>)
        # self.pads = []
        # self.X_shape: List[int] = []
- 
-        self.num_outputs = 1 # Defined by schema 
+
+        self.num_outputs = 1  # Defined by schema
        self.plugin_namespace = ""
        self.plugin_name = plugin_name
-        self.plugin_version = "1"   
+        self.plugin_version = "1"

        # Set the timing cache ID to prevent unnecessary timing of second plugin instance
        self.timing_cache_id = ""

        self.attrs = attrs
-        
+
        self.tactic = None
-        
-
-        # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR> 
+
+        # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
        # ex.
        # TODO: need to parse the field collection here
        # if fc is not None:
        #     assert fc[0].name == "pads"
        #     self.pads = fc[0].data
@@ -77,14 +76,12 @@
            self.phase = phase

    def get_capability_interface(self, type):
        return self

-    def get_output_data_types(
-        self, input_types: List[trt.DataType]
-    ) -> trt.DataType:
-        # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE 
+    def get_output_data_types(self, input_types: List[trt.DataType]) -> trt.DataType:
+        # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
        # with torch.fake_tensor():
        #      <GENERATE FAKE INPUTS OF TYPE INPUT_TYPES>
        #      fake_outputs = torch.ops.<custom_ns>.<custom_op>(*fake_inputs)

        # return fake_outputs[index]
@@ -96,20 +93,20 @@
        self,
        inputs: List[trt.DimsExprs],
        shape_inputs,
        exprBuilder: trt.IExprBuilder,
    ) -> trt.DimsExprs:
-        
+
        print(inputs)

-    #    WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR 
-    #    THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE 
-    #    SHAPE MAP. 
+        #    WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
+        #    THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
+        #    SHAPE MAP.
        output_dims = trt.DimsExprs(inputs[0])

        return [output_dims]
-    
+
    def get_fields_to_serialize(self):
        # should be passed in as another argument
        field_names = []

        for key, value in self.attrs.items():
@@ -149,11 +146,11 @@
        self.X_shape = np.zeros((len(X_dims),))
        for i in range(len(X_dims)):
            self.X_shape[i] = X_dims[i]

    def supports_format_combination(self, pos, in_out, num_inputs):
-        return 
+        return
        assert num_inputs == 1
        assert pos < len(in_out)

        desc = in_out[pos].desc
        if desc.format != trt.TensorFormat.LINEAR:
@@ -166,11 +163,10 @@
        # output should have the same type as the input
        if pos == 1:
            return in_out[0].desc.type == desc.type

        assert False
-

    def enqueue(
        self,
        input_desc: List[trt.PluginTensorDesc],
        output_desc: List[trt.PluginTensorDesc],
@@ -180,40 +176,56 @@
        stream: int,
    ) -> None:
        # input and output memory handling
        input_mems = [None] * (len(inputs))

-        for i in range(len(inputs)): 
-            input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self)
+        for i in range(len(inputs)):
+            input_mems[i] = cp.cuda.UnownedMemory(
+                inputs[i],
+                np.prod(input_desc[i].dims)
+                * cp.dtype(trt.nptype(input_desc[i].type)).itemsize,
+                self,
+            )

        output_mems = [None] * (len(outputs))

        for i in range(len(outputs)):
-            output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self)
-    
+            output_mems[i] = cp.cuda.UnownedMemory(
+                outputs[i],
+                np.prod(output_desc[i].dims)
+                * cp.dtype(trt.nptype(output_desc[i].type)).itemsize,
+                self,
+            )

        input_data = [None] * ((len(inputs)))
        for i in range(len(inputs)):
-            input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0))
+            input_data[i] = cp.ndarray(
+                tuple(input_desc[i].dims),
+                dtype=input_desc[i].type,
+                memptr=cp.cuda.MemoryPointer(input_mems[i], 0),
+            )

        output_data = [None] * ((len(outputs)))
        for i in range(len(outputs)):
-            output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0))
-
-        #TODO: This is just for a simple case for elementwise operations
+            output_data[i] = cp.ndarray(
+                (np.prod(output_desc[i].dims)),
+                dtype=output_desc[i].type,
+                memptr=cp.cuda.MemoryPointer(output_mems[i], 0),
+            )
+
+        # TODO: This is just for a simple case for elementwise operations
        # using Torch implementation for now
-        input_torch_0 = torch.as_tensor(input_data[0], device='cuda')
-        input_torch_1 = torch.as_tensor(input_data[1], device='cuda')
+        input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
+        input_torch_1 = torch.as_tensor(input_data[1], device="cuda")

        output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)

        cp.copyto(output_data, output)
-

    def attach_to_context(self, context):
        return self.clone()
-    
+
    def get_valid_tactics(self):
        return [int(Tactic.TORCH), int(Tactic.TRITON)]

    def set_tactic(self, tactic):
        self.tactic = Tactic(tactic)
@@ -226,17 +238,17 @@
        cloned_plugin.__dict__.update(self.__dict__)
        return cloned_plugin


class PluginCreator(trt.IPluginCreatorV3One):  # type: ignore[misc]
-    def __init__(self, plugin_name : str, plugin_namespace : str, attrs):
-        trt.IPluginCreatorV3One.__init__(self)  
+    def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
+        trt.IPluginCreatorV3One.__init__(self)

        self.name = plugin_name
        self.plugin_namespace = plugin_namespace
        self.plugin_version = "1"
-        
+
        field_names = []
        for name, (builtin, type_) in attrs.items():
            if builtin:
                if type_ is str:
                    field_names.append(
@@ -259,15 +271,12 @@
                    )
                )

        self.field_names = trt.PluginFieldCollection(field_names)

-    def create_plugin(
-        self, name: str, field_collection, phase=None
-    ) -> CustomPlugin:
-
-        
+    def create_plugin(self, name: str, field_collection, phase=None) -> CustomPlugin:
+
        attrs = {}
        # for f in fc:
        #     if f.name not in desc.input_attrs:
        #         raise AssertionError(
        #             f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
@@ -275,10 +284,9 @@

        #     if _is_numpy_array(desc.input_attrs[f.name]):
        #         attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name]))
        #     else:
        #         attrs[f.name] = desc.input_attrs[f.name](f.data)
-                
+
        custom_plugin = CustomPlugin(name, attrs)
-        
+
        return custom_plugin
-

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py	2024-11-26 20:16:28.712186+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/automatic_plugin/custom_op.py	2024-11-26 20:16:48.244419+00:00
@@ -1,7 +1,8 @@
import triton
import triton.language as tl
+

@triton.jit
def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
    # Program ID determines the block of data each thread will process
    pid = tl.program_id(0)
@@ -25,23 +26,23 @@
@custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."
-    
+
    # Create output tensor
    Z = torch.empty_like(X)
-    
+
    # Define block size
    BLOCK_SIZE = 1024
-    
+
    # Grid of programs
-    grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],)
-    
+    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
+
    # Launch the kernel
    elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE)
-    
+
    return Z


# Using the module in PyTorch
# X = torch.randn(1024, device='cuda', requires_grad=True)
@@ -72,22 +73,31 @@

        return res


my_model = MyModel().to("cuda")
-m = torch.full((64, 64), 2, device='cuda',)
-n = torch.full((64, 64), 3, device='cuda',)
+m = torch.full(
+    (64, 64),
+    2,
+    device="cuda",
+)
+n = torch.full(
+    (64, 64),
+    3,
+    device="cuda",
+)
# print(torch.ops.torchtrt_ex.elementwise_add(m, n))
# print(my_model.forward(m, n))


@torch.library.register_fake("torchtrt_ex::elementwise_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x

+
import torch_tensorrt as torchtrt


with torchtrt.logging.info():
    model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1)
    res = model_trt(m, n)
-    print(res)
\ No newline at end of file
+    print(res)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py	2024-11-26 20:16:28.728186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/__init__.py	2024-11-26 20:16:48.833342+00:00
@@ -1,6 +1,11 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters
+from . import (
+    aten_ops_converters,
+    ops_evaluators,
+    prims_ops_converters,
+    plugin_ops_converters,
+)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import *  # noqa: F403
from ._TRTInterpreter import *  # noqa: F403
from .truncate_double import repair_double_inputs
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py	2024-11-26 20:16:28.728186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py	2024-11-26 20:16:49.583518+00:00
@@ -1 +1 @@
-from .plugin_generator import PluginCreator
\ No newline at end of file
+from .plugin_generator import PluginCreator
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py	2024-11-26 20:16:28.732186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py	2024-11-26 20:16:49.650545+00:00
@@ -17,25 +17,28 @@

logger = logging.getLogger(__name__)

TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()

+
@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
def torchtrt_ex_elementwise_add(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
-): 
+):
    # logger.debug(f"plugin stuff here2")
    # return torch.add(args)
-    
+
    # How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
-    plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={})
-    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")    
-    
+    plugin_creator = PluginCreator(
+        "elementwise_add_plugin", plugin_namespace="", attrs={}
+    )
+    TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "")
+
    plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator(
        type="elementwise_add_plugin", version="1", plugin_namespace=""
    )
    assert plugin_creator, f"Unable to find elementwise_add_plugin creator"

@@ -44,45 +47,47 @@
    # plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)
    # assert plugin, "Unable to create <PLUGIN_NAME>"

    # <GENERATE LINK BETWEEN PLUGIN AND INPUTS>
    #    <GET INPUTS INTO LIST>
-    #    <PASS TO PLUGIN>     
-    
+    #    <PASS TO PLUGIN>
+
    # return layer.get_output(0)
    field_configs = trt.PluginFieldCollection([])
-    
-    plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs)
+
+    plugin = plugin_creator.create_plugin(
+        name="elementwise_add_plugin", field_collection=field_configs
+    )
    assert plugin, "Unable to create CircularPaddingPlugin"
-    
+
    # input_tensor = args[
    #     0
    # ]  # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor
    # if not isinstance(input_tensor, trt.ITensor):
    #     # Freeze input tensor if not TensorRT Tensor already
    #     input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input")
-    
+
    lhs_dtype = None
    rhs_dtype = None
    lhs_val = args[0]
    rhs_val = args[1]
-    
+
    if isinstance(lhs_val, TRTTensor):
        lhs_dtype = lhs_val.dtype
        # is_lhs_trt_tensor = True
    if isinstance(rhs_val, TRTTensor):
        rhs_dtype = rhs_val.dtype
        # is_rhs_trt_tensor = True
-        
+
    print(lhs_dtype)
-    
+
    lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
    rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)

    layer = ctx.net.add_plugin_v3(
        [lhs_val, rhs_val], [], plugin
    )  # Add the plugin to the network being constructed
    # layer.name = f"automatic-{name}"
    return layer.get_output(0)


-# 1. generate plugin for any pytorch op
\ No newline at end of file
+# 1. generate plugin for any pytorch op
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py	2024-11-26 20:16:28.732186+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py	2024-11-26 20:16:49.769861+00:00
@@ -11,64 +11,63 @@


logger = logging.getLogger("CustomPlugin")

_numpy_to_plugin_field_type = {
-    np.dtype('int32'): trt.PluginFieldType.INT32,
-    np.dtype('int16'): trt.PluginFieldType.INT16,
-    np.dtype('int8'): trt.PluginFieldType.INT8,
-    np.dtype('bool'): trt.PluginFieldType.INT8,
-    np.dtype('int64'): trt.PluginFieldType.INT64,
-    np.dtype('float32'): trt.PluginFieldType.FLOAT32,
-    np.dtype('float64'): trt.PluginFieldType.FLOAT64,
-    np.dtype('float16'): trt.PluginFieldType.FLOAT16
+    np.dtype("int32"): trt.PluginFieldType.INT32,
+    np.dtype("int16"): trt.PluginFieldType.INT16,
+    np.dtype("int8"): trt.PluginFieldType.INT8,
+    np.dtype("bool"): trt.PluginFieldType.INT8,
+    np.dtype("int64"): trt.PluginFieldType.INT64,
+    np.dtype("float32"): trt.PluginFieldType.FLOAT32,
+    np.dtype("float64"): trt.PluginFieldType.FLOAT64,
+    np.dtype("float16"): trt.PluginFieldType.FLOAT16,
}

_built_in_to_plugin_field_type = {
    int: trt.PluginFieldType.INT64,
    float: trt.PluginFieldType.FLOAT64,
    bool: trt.PluginFieldType.INT8,
    # str is handled separately, so not needed here
}

+
class Tactic(IntEnum):
    TORCH = 1
    TRITON = 2

+
class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):  # type: ignore[misc]
-    def __init__(
-        self, plugin_name : str, attrs, phase = None
-    ):
+    def __init__(self, plugin_name: str, attrs, phase=None):
        # TODO: needs an additional passed in arguments to specify the needs for each plugin
        # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83
        trt.IPluginV3.__init__(self)
        # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime
        trt.IPluginV3OneCore.__init__(self)
        # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder.
        trt.IPluginV3OneBuild.__init__(self)
        # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable
-        trt.IPluginV3OneRuntime.__init__(self)       
-        
+        trt.IPluginV3OneRuntime.__init__(self)
+
        # <ANY NON TENSOR INPUTS SHOULD BE AN ATTRIBUTE OF THE PLUGIN>
-        # setattr(<name of input>, <default value for that type>) 
+        # setattr(<name of input>, <default value for that type>)
        # self.pads = []
        # self.X_shape: List[int] = []
- 
-        self.num_outputs = 1 # Defined by schema 
+
+        self.num_outputs = 1  # Defined by schema
        self.plugin_namespace = ""
        self.plugin_name = plugin_name
-        self.plugin_version = "1"   
+        self.plugin_version = "1"

        # Set the timing cache ID to prevent unnecessary timing of second plugin instance
        self.timing_cache_id = ""

        self.attrs = attrs
-        
+
        self.tactic = None
-        
-
-        # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR> 
+
+        # <GENERATE CODE FOR TAKING A FIELD COLLECTION CONTAINING THE NON TENSOR INPUTS AND SETTING AN ATTR>
        # ex.
        # TODO: need to parse the field collection here
        # if fc is not None:
        #     assert fc[0].name == "pads"
        #     self.pads = fc[0].data
@@ -77,14 +76,12 @@
            self.phase = phase

    def get_capability_interface(self, type):
        return self

-    def get_output_data_types(
-        self, input_types: List[trt.DataType]
-    ) -> trt.DataType:
-        # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE 
+    def get_output_data_types(self, input_types: List[trt.DataType]) -> trt.DataType:
+        # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE
        # with torch.fake_tensor():
        #      <GENERATE FAKE INPUTS OF TYPE INPUT_TYPES>
        #      fake_outputs = torch.ops.<custom_ns>.<custom_op>(*fake_inputs)

        # return fake_outputs[index]
@@ -96,20 +93,20 @@
        self,
        inputs: List[trt.DimsExprs],
        shape_inputs,
        exprBuilder: trt.IExprBuilder,
    ) -> trt.DimsExprs:
-        
+
        print(inputs)

-    #    WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR 
-    #    THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE 
-    #    SHAPE MAP. 
+        #    WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR
+        #    THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE
+        #    SHAPE MAP.
        output_dims = trt.DimsExprs(inputs[0])

        return [output_dims]
-    
+
    def get_fields_to_serialize(self):
        # should be passed in as another argument
        field_names = []

        for key, value in self.attrs.items():
@@ -149,11 +146,11 @@
        self.X_shape = np.zeros((len(X_dims),))
        for i in range(len(X_dims)):
            self.X_shape[i] = X_dims[i]

    def supports_format_combination(self, pos, in_out, num_inputs):
-        return 
+        return
        assert num_inputs == 1
        assert pos < len(in_out)

        desc = in_out[pos].desc
        if desc.format != trt.TensorFormat.LINEAR:
@@ -166,11 +163,10 @@
        # output should have the same type as the input
        if pos == 1:
            return in_out[0].desc.type == desc.type

        assert False
-

    def enqueue(
        self,
        input_desc: List[trt.PluginTensorDesc],
        output_desc: List[trt.PluginTensorDesc],
@@ -180,40 +176,56 @@
        stream: int,
    ) -> None:
        # input and output memory handling
        input_mems = [None] * (len(inputs))

-        for i in range(len(inputs)): 
-            input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self)
+        for i in range(len(inputs)):
+            input_mems[i] = cp.cuda.UnownedMemory(
+                inputs[i],
+                np.prod(input_desc[i].dims)
+                * cp.dtype(trt.nptype(input_desc[i].type)).itemsize,
+                self,
+            )

        output_mems = [None] * (len(outputs))

        for i in range(len(outputs)):
-            output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self)
-    
+            output_mems[i] = cp.cuda.UnownedMemory(
+                outputs[i],
+                np.prod(output_desc[i].dims)
+                * cp.dtype(trt.nptype(output_desc[i].type)).itemsize,
+                self,
+            )

        input_data = [None] * ((len(inputs)))
        for i in range(len(inputs)):
-            input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0))
+            input_data[i] = cp.ndarray(
+                tuple(input_desc[i].dims),
+                dtype=input_desc[i].type,
+                memptr=cp.cuda.MemoryPointer(input_mems[i], 0),
+            )

        output_data = [None] * ((len(outputs)))
        for i in range(len(outputs)):
-            output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0))
-
-        #TODO: This is just for a simple case for elementwise operations
+            output_data[i] = cp.ndarray(
+                (np.prod(output_desc[i].dims)),
+                dtype=output_desc[i].type,
+                memptr=cp.cuda.MemoryPointer(output_mems[i], 0),
+            )
+
+        # TODO: This is just for a simple case for elementwise operations
        # using Torch implementation for now
-        input_torch_0 = torch.as_tensor(input_data[0], device='cuda')
-        input_torch_1 = torch.as_tensor(input_data[1], device='cuda')
+        input_torch_0 = torch.as_tensor(input_data[0], device="cuda")
+        input_torch_1 = torch.as_tensor(input_data[1], device="cuda")

        output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1)

        cp.copyto(output_data, output)
-

    def attach_to_context(self, context):
        return self.clone()
-    
+
    def get_valid_tactics(self):
        return [int(Tactic.TORCH), int(Tactic.TRITON)]

    def set_tactic(self, tactic):
        self.tactic = Tactic(tactic)
@@ -226,17 +238,17 @@
        cloned_plugin.__dict__.update(self.__dict__)
        return cloned_plugin


class PluginCreator(trt.IPluginCreatorV3One):  # type: ignore[misc]
-    def __init__(self, plugin_name : str, plugin_namespace : str, attrs):
-        trt.IPluginCreatorV3One.__init__(self)  
+    def __init__(self, plugin_name: str, plugin_namespace: str, attrs):
+        trt.IPluginCreatorV3One.__init__(self)

        self.name = plugin_name
        self.plugin_namespace = plugin_namespace
        self.plugin_version = "1"
-        
+
        field_names = []
        for name, (builtin, type_) in attrs.items():
            if builtin:
                if type_ is str:
                    field_names.append(
@@ -259,15 +271,12 @@
                    )
                )

        self.field_names = trt.PluginFieldCollection(field_names)

-    def create_plugin(
-        self, name: str, field_collection, phase=None
-    ) -> CustomPlugin:
-
-        
+    def create_plugin(self, name: str, field_collection, phase=None) -> CustomPlugin:
+
        attrs = {}
        # for f in fc:
        #     if f.name not in desc.input_attrs:
        #         raise AssertionError(
        #             f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
@@ -275,10 +284,9 @@

        #     if _is_numpy_array(desc.input_attrs[f.name]):
        #         attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name]))
        #     else:
        #         attrs[f.name] = desc.input_attrs[f.name](f.data)
-                
+
        custom_plugin = CustomPlugin(name, attrs)
-        
+
        return custom_plugin
-

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: build system Issues re: Build system component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants