Questions For MVP:
MVP is done, correct? @gs-olive
Autocast
TL;DR
Allow a broader range of valid input types to Torch-TRT-compiled engines, enabling support for Torch operations which require Int64-type inputs (generally index-related, like aten::scatter).
Goal(s)
Currently, if a Torch-TRT graph is given Int64-type tensors as input, compilation fails. Additionally, if Int64-type tensors are provided at runtime, inference can crash unexpectedly. The same holds true for Double inputs, among others. Torch-TensorRT should allow input types that TensorRT does not natively support (e.g. torch.long) and optionally augment the provided user graph with automatic data-type casting to ensure input tensors of reasonable type flow seamlessly through the graph. Torch-TRT would mimic Torch execution in a sense, as inputs would only change type for TensorRT engines and would be cast back to their original type afterward. Ultimately, the goal is to ensure inserted type-casting is used minimally and undone/reverted where applicable, to avoid modifying user-provided tensor inputs.
See Issues #1121, #1346, #1546, #1543
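To make the current limitation concrete, here is a minimal sketch of a graph that hits it; the module, shapes, and error behavior are illustrative assumptions rather than reproductions of the linked issues:

```python
import torch
import torch_tensorrt

class ScatterModule(torch.nn.Module):
    def forward(self, src, index):
        # aten::scatter requires Int64 (torch.long) indices
        return torch.zeros_like(src).scatter(0, index, src)

model = torch.jit.script(ScatterModule().eval().cuda())

# Today, declaring an Int64 input is rejected at compile time, and feeding an
# Int64 tensor to an engine built for Int32 inputs can crash at runtime.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(shape=[8], dtype=torch.float32),
        torch_tensorrt.Input(shape=[8], dtype=torch.long),  # currently unsupported
    ],
)
```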
Use Cases
There are two key use cases to be aware of in this situation:
1. The user specifies a TensorRT-unsupported data type (e.g. torch.long), and provides that data type at runtime.
In this case, the autocast feature is implicit, as compilation cannot proceed without casting inputs to a type compatible with TensorRT engines. For example, consider a graph with Int64 inputs %x and %y, and a single Int64 output, %z. Additionally, assume the entire graph can be run in TensorRT, and is thus marked to run in TensorRT. Then, the tensors %x and %y would be downcast to Int32 by an auxiliary Torch engine, run through the TensorRT engine, and the outputs would be upcast to Int64 by another auxiliary Torch engine.
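In eager terms, that cast sequence would look roughly like the following sketch, where trt_engine stands in for the fully TensorRT-compiled section (an assumed placeholder, not an actual API):

```python
import torch

def run_with_autocast(x, y, trt_engine):
    # Auxiliary Torch block: downcast the Int64 inputs to a TensorRT-supported type
    x_i32 = x.to(torch.int32)
    y_i32 = y.to(torch.int32)

    # The fully TensorRT-compiled section runs in Int32
    z_i32 = trt_engine(x_i32, y_i32)

    # Auxiliary Torch block: upcast the result back to the user-visible Int64 type
    return z_i32.to(torch.int64)
```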
2. The user specifies a TensorRT-supported data type at compilation time, but provides another at runtime.
This case is more challenging, as it is difficult for Torch-TRT to mimic Torch behavior in such a scenario. The reason is that during partitioning, we cannot reasonably determine what the output dtype of an engine should be for an arbitrary input dtype; we can only determine what the output dtype should be for the specified (or inferred) dtype.
If the user specifies Int32 at compilation time, for example, but provides an Int64 or Float input at runtime, the autocast feature could do one of two things:
If a user specifies an Int64 input with autocast enabled, inputs will only be cast for TensorRT engines and not for Torch engines.
Edge Case: A user specifies an Int64 input at compile time but also require_full_compilation=True, and provides an Int64 input at runtime.
Potentially Problematic Edge Case: A user specifies an Int64 input at compile time, provides an Int64 input at runtime, but a TensorRT engine uses in-place operations to modify inputs which it does not later return (see the sketch below).
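A minimal sketch of that problematic pattern, using a hypothetical module (the names and the specific op are assumptions):

```python
import torch

class InPlaceUpdate(torch.nn.Module):
    def forward(self, counts, x):
        # counts is an Int64 tensor that is updated in place but never returned
        counts.add_(1)
        return x * 2

# Under autocast, counts would be cast (copied) to Int32 before the TensorRT engine
# runs, so the in-place add_ would modify the copy rather than the user's original
# Int64 tensor, diverging from eager Torch behavior.
```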
Proposed APIs / UX
This feature would be enabled via a flag in the compilation arguments, autocast=True, and would allow users to provide a larger set of data types as inputs, with type-casting ops added automatically.
Example Workflow
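A minimal sketch of the intended usage, assuming the proposed autocast flag is accepted by torch_tensorrt.compile; the module, shapes, and input declarations are illustrative assumptions:

```python
import torch
import torch_tensorrt

class ScatterModule(torch.nn.Module):
    def forward(self, src, index):
        return torch.zeros_like(src).scatter(0, index, src)

model = torch.jit.script(ScatterModule().eval().cuda())

trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(shape=[8], dtype=torch.float32),
        torch_tensorrt.Input(shape=[8], dtype=torch.long),  # Int64 input declared directly
    ],
    autocast=True,  # proposed flag: casting ops are inserted automatically
)

src = torch.randn(8).cuda()
index = torch.randperm(8).cuda()  # torch.long by default
out = trt_model(src, index)       # Int64 casts happen only around TensorRT engines
```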
Limitations
This feature does not make Torch-TRT perfectly mimic Torch's handling of data types, and it cannot reasonably do so, as mentioned above. It also does not make TensorRT "compatible" with Int64, Double, or other currently-unsupported data types. It simply abstracts away the casting necessary to switch between data types for compatibility with TensorRT.
Note: Additionally, there are some challenges arising from the use of in-place operations. If, for example, an input to an engine is modified within said engine, but not returned as an output, autocast will not be able to detect this change and correctly cast the input back to its original type.
Internal Implementation
Design
The structure is as follows. Assume %x and %y are inputs, and %z is the output of a segmented block determined to run in TensorRT by the partitioning module. We currently record only the shapes of the inputs and outputs, but we are also interested in the types of the inputs and outputs after completion of the computation. See the diagram below for an example tensor-trace diagram.
Extensions Required to Core API implementations
Key input-checking required for the combination of autocast=True and require_full_compilation=True
Updates to Partitioning needed to insert aten::to operations to properly cast inputs to all engines
Data Structures
The SegmentedBlock data structure is already sufficient for determining the quantity of inputs and outputs across segmented blocks, but an additional data structure would be helpful in determining type constraints for those inputs/outputs. Specifically, during the course of the forward-pass dry run in partitioning, it would be helpful to store the output types of Torch blocks and ensure these are properly cast as inputs to subsequent TensorRT blocks.
Each TensorRT block will need 2 auxiliary Torch blocks surrounding it. The first Torch block casts the tensors to valid types for the TensorRT block, potentially making a copy so as not to modify user-provided inputs. The second Torch block takes the output of the TensorRT block and casts the outputs to the necessary type for the next block.
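A Python-flavored sketch of that extra bookkeeping, purely illustrative of the idea (the actual structure would live alongside SegmentedBlock in the C++ partitioning code; the names and cast policy here are assumptions):

```python
from dataclasses import dataclass, field
from typing import List
import torch

@dataclass
class BlockTypeInfo:
    """Per-SegmentedBlock dtype record collected during the partitioning dry run."""
    target: str                                     # "torch" or "tensorrt"
    input_dtypes: List[torch.dtype] = field(default_factory=list)
    output_dtypes: List[torch.dtype] = field(default_factory=list)

def required_input_casts(incoming_dtypes: List[torch.dtype], block: BlockTypeInfo):
    """Decide what each incoming tensor should be cast to before entering `block`."""
    trt_safe = {torch.int64: torch.int32, torch.float64: torch.float32}
    casts = []
    for dtype in incoming_dtypes:
        if block.target == "tensorrt" and dtype in trt_safe:
            casts.append(trt_safe[dtype])   # downcast for the TensorRT engine
        else:
            casts.append(dtype)             # Torch blocks receive the recorded dtype
    return casts
```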
A straightforward way to complete all of the above is simply to track the data types into, and out of, each block when run in Torch, then perform casting for all blocks to ensure inputs have the correct type. For TensorRT-executed blocks, we can augment the cast to ensure the data types fed in are valid TRT types. Each Torch block will then begin with aten::to casts of all of its inputs, and each TensorRT block will be prepended with a Torch block (or a section of a previous Torch block) casting its inputs to a compatible type.
We may need to add up to 2 auxiliary Torch blocks in the code (one at the beginning of the graph, one at the end), to ensure casting is performed before the first TensorRT block and after the last.
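Schematically, for a Torch → TensorRT → Torch partition, the inserted casts order execution roughly as in the following sketch (the block callables are assumed placeholders, not actual APIs):

```python
import torch

def run_partitioned(x, torch_block_1, trt_block, torch_block_2):
    h = torch_block_1(x)                                        # Torch block: begins with its own aten::to casts
    h_in = h.to(torch.int32) if h.dtype == torch.int64 else h   # prepended cast for the TensorRT block
    h_out = trt_block(h_in)                                     # TensorRT engine
    return torch_block_2(h_out.to(h.dtype))                     # restore the recorded dtype for the next block
```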
Details specific for TorchScript Support
See above for the TorchScript details
Details specific for FX support
Since fx2trt does not employ partitioning in the same way TorchScript does, this feature could potentially be extended to perform Python-level casting of inputs to TRT operators and fusions automatically.
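As an illustration of that Python-level approach, a hypothetical wrapper around an fx2trt-lowered submodule could cast unsupported input dtypes before the call; the class name and cast policy are assumptions, not an existing fx2trt API:

```python
import torch

class AutocastTRTWrapper(torch.nn.Module):
    """Casts TensorRT-unsupported input dtypes before calling a lowered submodule."""

    CAST_MAP = {torch.int64: torch.int32, torch.float64: torch.float32}

    def __init__(self, trt_submodule):
        super().__init__()
        self.trt_submodule = trt_submodule

    def forward(self, *args):
        cast_args = [
            a.to(self.CAST_MAP[a.dtype])
            if torch.is_tensor(a) and a.dtype in self.CAST_MAP
            else a
            for a in args
        ]
        # Restoring output dtypes would rely on the recorded per-block type
        # information described in the Data Structures section above.
        return self.trt_submodule(*cast_args)
```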
Implementation Phases
Prototype - L (complete)
MVP (1.4.0/1.5.0) - M (complete)
Support dtype=torch.long for Python TorchScript and C++ APIs.
Extension Phase 1 - S
Extension Phase 2 - M
Add autocast optional argument to compile in Python TorchScript and C++ APIs.