chore: address comments
peri044 committed Nov 13, 2024
1 parent dd94194 commit ea226d6
Showing 5 changed files with 56 additions and 60 deletions.
23 changes: 23 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -8,6 +8,7 @@

#include "core/runtime/runtime.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
@@ -253,6 +254,28 @@ std::string TRTEngine::get_engine_layer_info() {
  return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
}

+std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
+  std::vector<at::Tensor> outputs;
+  TORCHTRT_CHECK(
+      (in_binding_names.size() == input_shapes.size()),
+      "The number of input shapes provided doesn't match with the number of input names registered.");
+  // Set all input shapes
+  for (size_t i = 0; i < input_shapes.size(); i++) {
+    exec_ctx->setInputShape(in_binding_names[i].c_str(), core::util::toDims(input_shapes[i]));
+  }
+  for (size_t i = 0; i < out_binding_names.size(); i++) {
+    auto output_shape = core::util::toVec(exec_ctx->getTensorShape(out_binding_names[i].c_str()));
+    auto output_dtype =
+        core::util::TRTDataTypeToScalarType(cuda_engine->getTensorDataType(out_binding_names[i].c_str()));
+    auto output_tensor = torch::empty(output_shape, torch::dtype(output_dtype));
+    outputs.push_back(output_tensor);
+  }
+  TORCHTRT_CHECK(
+      (out_binding_names.size() == outputs.size()),
+      "The number of output shapes inferred doesn't match with the number of output names registered.");
+  return outputs;
+}

void TRTEngine::set_profiling_paths() {
  device_profile_path =
      std::filesystem::path{profile_path_prefix + "/" + name + "_device_config_profile.trace"}.string();
7 changes: 4 additions & 3 deletions core/runtime/TRTEngine.h
@@ -24,9 +24,9 @@ using FlattenedState = std::tuple<
    std::tuple<std::string, std::string>, // name
    std::tuple<std::string, std::string>, // device
    std::tuple<std::string, std::string>, // engine
-    std::tuple<std::string, std::vector<std::string>>, // input binding names
-    std::tuple<std::string, std::vector<std::string>>, // output binding names
-    std::tuple<std::string, bool>, // HW compatibility
+    std::tuple<std::string, std::string>, // input binding names
+    std::tuple<std::string, std::string>, // output binding names
+    std::tuple<std::string, std::string>, // HW compatibility
    std::tuple<std::string, std::string>, // serialized metadata
    std::tuple<std::string, std::string>>; // Platform

@@ -87,6 +87,7 @@ struct TRTEngine : torch::CustomClassHolder {
  bool set_device_memory_budget(int64_t budget);
  int64_t get_streamable_device_memory_budget();
  int64_t get_automatic_device_memory_budget();
+  std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
  friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
  static const char BINDING_DELIM = '%';

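The FlattenedState change above drops std::vector<std::string> in favor of plain strings, which suggests the binding names are now flattened into a single delimited string; BINDING_DELIM ('%') is declared in this same header. A minimal Python sketch of that assumed encoding (join_bindings is a hypothetical helper, not the actual serialization code):

# Hedged sketch, assuming names are joined with BINDING_DELIM ('%') from
# TRTEngine.h; join_bindings is illustrative only.
def join_bindings(names: list[str], delim: str = "%") -> str:
    return delim.join(names)

print(join_bindings(["input_0", "input_1"]))  # -> "input_0%input_1"
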
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -87,6 +87,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def_property(
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
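
With the binding above registered, infer_outputs becomes reachable from Python through the TorchScript custom class, which is how the meta kernel in register_fake_class.py uses it further down. A hedged usage sketch, assuming an already-constructed torch.classes.tensorrt.Engine instance (construction details elided; they depend on the serialized engine state):

# Hedged sketch, not actual test code: probe output shapes/dtypes via the
# new binding. `engine` is assumed to be a torch.classes.tensorrt.Engine.
def probe_output_specs(engine, input_shapes):
    # infer_outputs returns empty tensors whose shapes/dtypes TensorRT inferred
    proxy_outputs = engine.infer_outputs(input_shapes)
    return [(tuple(t.size()), t.dtype) for t in proxy_outputs]

# e.g. probe_output_specs(engine, [[1, 3, 224, 224]])  # one shape per input binding
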
8 changes: 5 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -3,8 +3,10 @@
from typing import Any, Callable, Dict, List, Optional

import torch
-from torch._decomp import register_decomposition
-from torch._export.utils import _decomp_table_to_post_autograd_aten
+from torch._decomp import (
+    _core_aten_decompositions_post_autograd,
+    register_decomposition,
+)
from torch._ops import OpOverload
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
@@ -412,7 +414,7 @@ def get_decompositions(
        return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
    else:
        # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080
-        decomp_table = _decomp_table_to_post_autograd_aten()
+        decomp_table = _core_aten_decompositions_post_autograd()
        DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
            decomp: decomp_table[decomp]
            for decomp in decomp_table
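
Both helpers are private torch APIs that moved around the 2.6 release (see pytorch/pytorch#135080 linked in the comment above), so a version-guarded import is one defensive option. A sketch, assuming the two names from this diff are interchangeable for building this table:

# Hedged sketch: fall back to the pre-2.6 helper when the newer one is
# absent. Both are internal torch APIs and may change without notice.
try:
    from torch._decomp import _core_aten_decompositions_post_autograd  # torch >= 2.6
except ImportError:
    from torch._export.utils import (
        _decomp_table_to_post_autograd_aten as _core_aten_decompositions_post_autograd,
    )
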
77 changes: 23 additions & 54 deletions py/torch_tensorrt/dynamo/runtime/register_fake_class.py
@@ -1,11 +1,9 @@
import base64
from collections import defaultdict
-from typing import Any, Dict, List
+from typing import Any, List

-import tensorrt as trt
import torch
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape
-from torch_tensorrt.logging import TRT_LOGGER


@torch.library.register_fake("tensorrt::execute_engine") # type: ignore
@@ -15,31 +13,6 @@ def fake_tensorrt_execute_engine(
"""
We infer outputs using the TRT engine and inputs and return fake tensors in this meta kernel.
"""

# Get the TRT engine from the fake TRTEngine object
serialized_state = fake_trt_engine.wrapped_obj.state_dict

serialized_engine = base64.b64decode(serialized_state["serialized_engine"])

# Store input/output names for shape inference
input_names = serialized_state["in_binding_names"]
output_names = serialized_state["out_binding_names"]
assert len(input_names) == len(
inputs
), f"Number of inputs serialized in TRTEngine {len(input_names)} doesn't match with the number of inputs found during meta kernel execution {len(inputs)} for execute_engine op"

# Deserialize the TRT engine
# TODO: Probably unsafe deserialization. Should we expose infer shape mechanism through TRTEngine class ?
try:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)
except Exception as e:
raise AssertionError(
"TRT engine deserialization failed during meta kernel execution. Please verify if the environment in which you are exporting is same as the one in which you compiled"
)

context = engine.create_execution_context()

    # Here's what we are doing
    # 1) Check if inputs are dynamic (they have sym ints in their shapes)
    # 2) For dynamic inputs, we gather min_input_shape and max_input_shape for all inputs
@@ -52,28 +25,21 @@
    else:
        modes = ["opt"]

+    # Get the TRTEngine class and infer output shapes based on input shapes
+    trt_engine = fake_trt_engine.wrapped_obj.engine
    outputs_mode_dict = defaultdict(list)
    for mode in modes:
-        for input_idx, input in enumerate(inputs):
-            # Using TensorRT's infer shape mechanism to infer output shapes
-            input_shape = unwrap_tensor_shape(input, mode=mode)
-
-            context.set_input_shape(input_names[input_idx], input_shape)
-
-        for output_name in output_names:
-            output_shape = context.get_tensor_shape(output_name)
-            outputs_mode_dict[mode].append(output_shape)
+        input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
+        proxy_outputs = trt_engine.infer_outputs(input_shapes)
+        outputs_mode_dict[mode].extend(proxy_outputs)

    # Store the number of outputs
    if {"min", "max"}.issubset(outputs_mode_dict):
        assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"])
        num_outputs = len(outputs_mode_dict["min"])
    elif "opt" in outputs_mode_dict:
        num_outputs = len(outputs_mode_dict["opt"])

-    assert (
-        len(output_names) == num_outputs
-    ), f"Number of outputs serialized in TRTEngine {len(output_names)} doesn't match with the number of outputs found during meta kernel execution {num_outputs} for execute_engine op"
-
    fake_outputs = []
    for out_idx in range(num_outputs):
        output_shape = []
@@ -82,12 +48,11 @@
            # Note: We can't establish a relationship b/w incoming input symbolic shape (eg: s0)
            # and TensorRT's output shape (represented as unbacked u0). This situation doesn't seem
            # to affect compilation results / serialization during our testing.
-            output_min_shape = outputs_mode_dict["min"][out_idx]
-            output_opt_shape = outputs_mode_dict["opt"][out_idx]
-            output_max_shape = outputs_mode_dict["max"][out_idx]
+            output_min_shape = outputs_mode_dict["min"][out_idx].size()
+            output_opt_shape = outputs_mode_dict["opt"][out_idx].size()
+            output_max_shape = outputs_mode_dict["max"][out_idx].size()

            ctx = torch._custom_ops.get_ctx()
-            output_shape = []
            for min_val, opt_val, max_val in zip(
                output_min_shape, output_opt_shape, output_max_shape
            ):
@@ -102,26 +67,27 @@
                else:
                    output_shape.append(min_val)
        else:
-            output_shape.extend(outputs_mode_dict["opt"][out_idx])
+            output_shape.extend(outputs_mode_dict["opt"][out_idx].size())

-        fake_outputs.append(input.new_empty(output_shape))
+        fake_outputs.append(
+            torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype)
+        )

    return fake_outputs


@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
-    def __init__(self, state_dict: Dict[str, Any]) -> None:
-        self.state_dict = state_dict
+    def __init__(self, engine_info: List[str]) -> None:
+        self.engine = torch.classes.tensorrt.Engine(engine_info)

    @classmethod
    def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
-        breakpoint()
-        state_dict = {}
-        for key, val in flattened_tq:
-            state_dict[key] = val
+        engine_idx = torch.ops.tensorrt.ENGINE_IDX()
+        engine_info = [info[1] for info in flattened_tq]
+        engine_info[engine_idx] = base64.b64decode(engine_info[engine_idx])

-        return cls(state_dict)
+        return cls(engine_info)

    def enable_profiling(self) -> Any:
        pass
@@ -156,5 +122,8 @@ def streamable_device_memory_budget_getter(self) -> Any:
    def automatic_device_memory_budget_getter(self) -> Any:
        pass

+    def infer_outputs(self, input_shapes: List[Any]) -> Any:
+        pass

    def __setstate__(self, serialized_state: List[str]) -> Any:
        pass

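For reference, a minimal sketch of the per-dimension logic the meta kernel above applies when inputs are dynamic: a dimension that differs between the min and max profiles becomes an unbacked symbolic size from the fake-kernel context, while a constant dimension stays concrete. The helper name _resolve_dim is illustrative, and get_ctx() only succeeds when called inside a registered fake kernel:

import torch

# Hedged sketch, not the library's implementation: resolve one output dim
# from its min/max profile values inside a fake kernel.
def _resolve_dim(min_val: int, max_val: int):
    ctx = torch._custom_ops.get_ctx()  # fake-kernel context, as used above
    if min_val != max_val:
        return ctx.new_dynamic_size()  # unbacked SymInt (e.g. u0)
    return min_val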