diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..f08cca14
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+termcolor
diff --git a/scripts/build_contrib.sh b/scripts/build_contrib.sh
index 62462176..29f76541 100755
--- a/scripts/build_contrib.sh
+++ b/scripts/build_contrib.sh
@@ -13,11 +13,11 @@ pushd /tmp/TensorRT
 git sparse-checkout set /tools/pytorch-quantization/
 git apply --reject --whitespace=fix pytorch_nvidia_quantization.patch
 cd tools/pytorch-quantization/
-python setup.py install
+sudo python3 setup.py install
 popd
 
 pushd $parentdir
-python3 setup.py install --plugins --contrib
+sudo python3 setup.py install --plugins --contrib
 popd
diff --git a/setup.py b/setup.py
index 3df5cc97..86c671dd 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,10 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
 from packaging import version
 
+REQUIREMENTS_PATH = 'requirements.txt'
+
+with open(REQUIREMENTS_PATH, 'r') as file:
+    required_libraries = file.read().splitlines()
 
 def trt_inc_dir():
     return "/usr/include/aarch64-linux-gnu"
@@ -55,5 +59,6 @@ def trt_lib_dir():
     packages=find_packages(exclude=exclude_dir),
     ext_package='torch2trt',
     ext_modules=ext_modules,
+    install_requires=required_libraries,
     cmdclass={'build_ext': BuildExtension}
 )
diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py
index 5ffe38b4..a9fa49df 100644
--- a/torch2trt/converters/__init__.py
+++ b/torch2trt/converters/__init__.py
@@ -69,3 +69,4 @@ from .transpose import *
 from .unary import *
 from .view import *
+from .zeros import *
diff --git a/torch2trt/converters/avg_pool.py b/torch2trt/converters/avg_pool.py
index 185af508..02eaea1a 100644
--- a/torch2trt/converters/avg_pool.py
+++ b/torch2trt/converters/avg_pool.py
@@ -2,6 +2,55 @@ from torch2trt.module_test import add_module_test
 
 
+@tensorrt_converter('torch.nn.functional.avg_pool1d')
+def convert_avg_pool1d(ctx):
+    # At the time of this implementation, TensorRT 8.x does not yet support avg pooling in 1D using `add_pooling_nd(...)`.
+    # As such, we use a workaround here, by unsqueezing another dimension into the input (thus transforming it from
+    # (N, C, L) to (N, C, L, 1)) so that we can use 2D average pooling across the last two dimensions.
+
+    input = get_arg(ctx, 'input', pos=0, default=None)
+    input_trt = trt_(ctx.network, input)
+    output = ctx.method_return
+
+    kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None)
+    stride = get_arg(ctx, 'stride', pos=2, default=None)
+    padding = get_arg(ctx, 'padding', pos=3, default=0)
+    ceil_mode = get_arg(ctx, 'ceil_mode', pos=4, default=False)
+    count_include_pad = get_arg(ctx, 'count_include_pad', pos=5, default=True)
+
+    # Convert the 1D arguments to their 2D equivalents, since the input is always 1D.
+    kernel_size = (kernel_size, 1)
+    stride = kernel_size if not stride else (stride, 1)
+    padding = (padding, 0)
+
+    # Shuffle layer to unsqueeze another dimension for 2D average pooling.
+    unsqueeze_layer = ctx.network.add_shuffle(input_trt)
+    set_layer_precision(ctx, unsqueeze_layer)
+    unsqueeze_layer.reshape_dims = tuple([*input_trt.shape, 1])
+    unsqueeze_trt = unsqueeze_layer.get_output(0)
+
+    # Use 2D average pooling here to emulate 1D average pooling.
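+    # The window, stride, and padding all have a trailing component of 1/1/0, so the
+    # pooling only slides along L and leaves the dummy trailing dimension untouched.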
+    layer = ctx.network.add_pooling_nd(
+        input=unsqueeze_trt,
+        type=trt.PoolingType.AVERAGE,
+        window_size=kernel_size,
+    )
+    set_layer_precision(ctx, layer)
+    layer.stride_nd = stride
+    layer.padding_nd = padding
+    layer.average_count_excludes_padding = not count_include_pad
+
+    if ceil_mode:
+        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
+
+    pooling_trt = layer.get_output(0)
+
+    # Shuffle layer to squeeze out the dimension that was added for 2D average pooling, so the result is 1D again.
+    squeeze_layer = ctx.network.add_shuffle(pooling_trt)
+    set_layer_precision(ctx, squeeze_layer)
+    squeeze_layer.reshape_dims = tuple(pooling_trt.shape[:-1])
+    output._trt = squeeze_layer.get_output(0)
+
 @tensorrt_converter("torch.nn.functional.avg_pool2d", enabled=trt_version() < '7.0')
 def convert_avg_pool2d(ctx):
     # parse args
@@ -83,12 +132,14 @@ def convert_avg_pool_trt7(ctx):
         layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
 
     output._trt = layer.get_output(0)
-    
-    
+
+
 @add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6)])
 @add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7)])
 def test_avg_pool2d_without_ceil_mode():
-    return torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
+    return torch.nn.AvgPool2d(
+        kernel_size=3, stride=2, padding=1, ceil_mode=False
+    )
 
 
 @add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6)])
@@ -102,10 +153,14 @@ def test_avg_pool2d_with_ceil_mode():
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 4, 6)], enabled=trt_version() >= '7.0')
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 5, 7)], enabled=trt_version() >= '7.0')
 def test_avg_pool3d_without_ceil_mode_trt7():
-    return torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
+    return torch.nn.AvgPool3d(
+        kernel_size=3, stride=2, padding=1, ceil_mode=False
+    )
 
 
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 4, 6)], enabled=trt_version() >= '7.0')
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 5, 7)], enabled=trt_version() >= '7.0')
 def test_avg_pool3d_with_ceil_mode_trt7():
-    return torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False) # TRT does not support ceil_mode=True && count_include_pad=True
+    return torch.nn.AvgPool3d(
+        kernel_size=3, stride=2, padding=1, ceil_mode=True, count_include_pad=False
+    )  # TRT does not support ceil_mode=True && count_include_pad=True
diff --git a/torch2trt/converters/expand.py b/torch2trt/converters/expand.py
index e0d07540..2a22a55d 100644
--- a/torch2trt/converters/expand.py
+++ b/torch2trt/converters/expand.py
@@ -2,10 +2,10 @@ from torch2trt.module_test import add_module_test
 
 
+@tensorrt_converter('torch.Tensor.expand_as')
 @tensorrt_converter('torch.Tensor.expand')
 def convert_expand(ctx):
     input = ctx.method_args[0]
-    sizes = ctx.method_args[1:]
     output = ctx.method_return
 
     inshape = tuple(input.shape)[1:]  # exclude batch
diff --git a/torch2trt/converters/zeros.py b/torch2trt/converters/zeros.py
new file mode 100644
index 00000000..b8209576
--- /dev/null
+++ b/torch2trt/converters/zeros.py
@@ -0,0 +1,66 @@
+from torch2trt.torch2trt import *
+from torch2trt.module_test import add_module_test
+
+
+def _set_layer_precision(ctx, layer):
+    # Supported TRT precisions as given by torch2trt_kwargs.
+    INT8_MODE = "int8_mode"
+    FP16_MODE = "fp16_mode"
+
+    # Check that args exist as expected in torch2trt_kwargs.
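+    # The asserts below fail loudly if the conversion kwargs ever stop exposing these
+    # flags (e.g. after an upstream rename), instead of silently skipping the precision overrides.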
+ trt_kwargs = ctx.torch2trt_kwargs + assert INT8_MODE in trt_kwargs + assert FP16_MODE in trt_kwargs + + is_int8 = trt_kwargs.get(INT8_MODE, False) + is_fp16 = trt_kwargs.get(FP16_MODE, False) + + if is_int8: + layer.precision = trt.int8 + layer.set_output_type(0, trt.int8) + elif is_fp16: + layer.precision = trt.float16 + layer.set_output_type(0, trt.float16) + + +@tensorrt_converter('torch.zeros') +def convert_zeros(ctx): + tensor = ctx.method_return + + # Implementation copied from add_trt_constant. + shape = tuple(tensor.shape[1:]) + array = tensor[0].detach().cpu().numpy() + layer = ctx.network.add_constant(shape, array) + + _set_layer_precision(ctx, layer) + + tensor._trt = layer.get_output(0) + + +class Zeros(torch.nn.Module): + def __init__(self, *size): + super().__init__() + self.size = size + + def forward(self, x): + return x + torch.zeros(*self.size, device=torch.device('cuda')) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)]) +def test_zeros(): + return Zeros((1, 2, 3, 4)) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)]) +def test_zeros_var_args(): + return Zeros(1, 2, 3, 4) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)], fp16_mode=True) +def test_zeros_fp16_mode(): + return Zeros(1, 2, 3, 4) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 2, 3, 4)], int8_mode=True) +def test_zeros_int8_mode(): + return Zeros(1, 2, 3, 4) diff --git a/torch2trt/module_test.py b/torch2trt/module_test.py index fb158fe8..3a316407 100644 --- a/torch2trt/module_test.py +++ b/torch2trt/module_test.py @@ -1,7 +1,3 @@ -import torch -import torchvision - - class ModuleTest(object): def __init__(self, module_fn, dtype, device, input_shapes, **torch2trt_kwargs): self.module_fn = module_fn
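
For reviewers, here is a minimal, untested sketch of how the converters added in this diff could be exercised end to end. It assumes a CUDA device with TensorRT and this branch of torch2trt installed; the `ZerosThenPool` module, shapes, and tolerances are illustrative only and are not part of the change:

```python
import torch
import torch.nn.functional as F
from torch2trt import torch2trt


class ZerosThenPool(torch.nn.Module):
    # Touches both new converters: torch.zeros (constant layer with the
    # precision flags) and F.avg_pool1d (the unsqueeze/pool/squeeze workaround).
    def forward(self, x):
        x = x + torch.zeros(x.shape, device=x.device)
        return F.avg_pool1d(x, kernel_size=3, stride=2, padding=1)


model = ZerosThenPool().cuda().eval()
data = torch.randn(1, 3, 32).cuda()

# fp16_mode exercises the branch in _set_layer_precision that overrides layer precision.
model_trt = torch2trt(model, [data], fp16_mode=True)

# Compare against the eager PyTorch result; the difference should stay within
# normal FP16 tolerance if the converters behave as intended.
print(torch.max(torch.abs(model(data) - model_trt(data))))
```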