feat: Support exporting Torch-TRT compiled GraphModules #3262

Merged Nov 14, 2024 (47 commits; changes shown from 44 commits)

Commits
458a4d1  skip run_shape_analysis (lanluo-nvidia, Oct 6, 2024)
2f408f9  test (lanluo-nvidia, Oct 6, 2024)
1c5e86c  test (lanluo-nvidia, Oct 6, 2024)
ba487dc  test (lanluo-nvidia, Oct 6, 2024)
99d2274  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 6, 2024)
2b43480  test (lanluo-nvidia, Oct 6, 2024)
17b57a6  feat: Add re-export functionality for Torch-TRT modules (peri044, Oct 10, 2024)
b4e02e1  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 11, 2024)
3d94f8b  test (lanluo-nvidia, Oct 13, 2024)
cb03ca1  feat: add support for re-exporting graph modules (peri044, Oct 14, 2024)
28ba6cc  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 15, 2024)
b89cbe0  resolve comments (lanluo-nvidia, Oct 15, 2024)
2843d37  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 16, 2024)
3eb48d7  test (lanluo-nvidia, Oct 16, 2024)
839c72e  chore: updates (peri044, Oct 16, 2024)
50eb0d8  replace dummy inference (lanluo-nvidia, Oct 20, 2024)
95ed602  test (lanluo-nvidia, Oct 20, 2024)
120f30d  test (lanluo-nvidia, Oct 21, 2024)
424cbf7  add run_test_with_dynamic_shape change (lanluo-nvidia, Oct 21, 2024)
2fc9cef  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 21, 2024)
ef54cfc  split the PR, add dummy inference for converter test (lanluo-nvidia, Oct 21, 2024)
14f5d61  test (lanluo-nvidia, Oct 22, 2024)
7563959  test (lanluo-nvidia, Oct 22, 2024)
77355f0  test (lanluo-nvidia, Oct 22, 2024)
13361fd  add linear lowering meta val (lanluo-nvidia, Oct 22, 2024)
fca16a5  chore: updates (peri044, Oct 23, 2024)
f0a9fef  add linear_lowering change (lanluo-nvidia, Oct 23, 2024)
cff64a4  test (lanluo-nvidia, Oct 23, 2024)
933abac  test (lanluo-nvidia, Oct 23, 2024)
8417684  resolve comments (lanluo-nvidia, Oct 25, 2024)
8676f88  test (lanluo-nvidia, Oct 25, 2024)
df13856  chore: updates (peri044, Oct 28, 2024)
d406366  chore: updates (peri044, Oct 28, 2024)
595ea6e  chore: updates (peri044, Oct 28, 2024)
076f47a  resolve comments (lanluo-nvidia, Oct 29, 2024)
8250179  Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia, Oct 29, 2024)
96e93e4  resolve comments (lanluo-nvidia, Oct 29, 2024)
675667b  chore: updates (peri044, Oct 29, 2024)
4e1a538  chore: updates (peri044, Oct 31, 2024)
fb12021  chore: updates (peri044, Oct 31, 2024)
6b3f94c  chore: updates (peri044, Nov 1, 2024)
1983c60  chore: add tests (peri044, Nov 1, 2024)
dd94194  chore: updates (peri044, Nov 4, 2024)
ea226d6  chore: address comments (peri044, Nov 13, 2024)
0d04111  chore: rebase with main (peri044, Nov 13, 2024)
772e5d1  chore: updates (peri044, Nov 13, 2024)
f739f57  chore: fix tests (peri044, Nov 14, 2024)
1 change: 1 addition & 0 deletions .github/workflows/build-test-linux.yml
@@ -175,6 +175,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
popd

tests-py-torch-compile-be:
1 change: 1 addition & 0 deletions .github/workflows/build-test-windows.yml
@@ -172,6 +172,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
popd

tests-py-torch-compile-be:
62 changes: 62 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -8,6 +8,7 @@

#include "core/runtime/runtime.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
@@ -253,6 +254,28 @@ std::string TRTEngine::get_engine_layer_info() {
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
}

std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
std::vector<at::Tensor> outputs;
TORCHTRT_CHECK(
(in_binding_names.size() == input_shapes.size()),
"The number of input shapes provided doesn't match with the number of input names registered.");
// Set all input shapes
for (size_t i = 0; i < input_shapes.size(); i++) {
exec_ctx->setInputShape(in_binding_names[i].c_str(), core::util::toDims(input_shapes[i]));
}
for (size_t i = 0; i < out_binding_names.size(); i++) {
auto output_shape = core::util::toVec(exec_ctx->getTensorShape(out_binding_names[i].c_str()));
auto output_dtype =
core::util::TRTDataTypeToScalarType(cuda_engine->getTensorDataType(out_binding_names[i].c_str()));
auto output_tensor = torch::empty(output_shape, torch::dtype(output_dtype));
outputs.push_back(output_tensor);
}
TORCHTRT_CHECK(
(out_binding_names.size() == outputs.size()),
"The number of output shapes inferred doesn't match with the number of output names registered.");
return outputs;
}

void TRTEngine::set_profiling_paths() {
device_profile_path =
std::filesystem::path{profile_path_prefix + "/" + name + "_device_config_profile.trace"}.string();
@@ -354,6 +377,45 @@ void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& seriali
<< ")");
}

FlattenedState TRTEngine::__obj_flatten__() {
// This method would be called by meta kernel of this custom class and it only needs to return a tuple
std::vector<std::string> serialized_info = this->serialize();

return std::tuple(
std::tuple("version", serialized_info[ABI_TARGET_IDX]),
std::tuple("name", serialized_info[NAME_IDX]),
std::tuple("device_info", serialized_info[DEVICE_IDX]),
std::tuple("serialized_engine", serialized_info[ENGINE_IDX]),
std::tuple("in_binding_names", serialized_info[INPUT_BINDING_NAMES_IDX]),
std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]),
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
}

std::vector<std::string> TRTEngine::serialize() {
// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(this->cuda_engine->serialize());

// Adding device info related meta data to the serialized file
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialized_info;
serialized_info.resize(SERIALIZATION_LEN);

serialized_info[ABI_TARGET_IDX] = ABI_VERSION;
serialized_info[NAME_IDX] = this->name;
serialized_info[DEVICE_IDX] = this->device_info.serialize();
serialized_info[ENGINE_IDX] = base64_encode(trt_engine);
serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names);
serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names);
serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();

return serialized_info;
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
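
The new infer_outputs method lets the runtime report output shapes and dtypes for a given set of input shapes by querying the TensorRT execution context, without ever launching the engine. A rough sketch of how this could be exercised from Python through the TorchBind class is below; treat the names (in particular the trt_engine handle) as illustrative, since the exact Python-side call path in this PR may differ.

# Hypothetical sketch: ask a deserialized TRT engine for its output tensors
# (empty, correctly shaped/typed) given only input shapes. Assumes a
# torch.classes.tensorrt.Engine instance named `trt_engine` already exists.
import torch
import torch_tensorrt  # loads the `tensorrt` TorchBind/custom-op library

input_shapes = [[1, 3, 224, 224]]             # one shape per input binding
outputs = trt_engine.infer_outputs(input_shapes)
for out in outputs:
    # Each entry is an uninitialized torch.Tensor whose shape/dtype metadata
    # is enough for shape propagation during re-export; no inference runs.
    print(out.shape, out.dtype)
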
17 changes: 17 additions & 0 deletions core/runtime/TRTEngine.h
@@ -19,6 +19,17 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // ABI_VERSION
std::tuple<std::string, std::string>, // name
std::tuple<std::string, std::string>, // device
std::tuple<std::string, std::string>, // engine
std::tuple<std::string, std::string>, // input binding names
std::tuple<std::string, std::string>, // output binding names
std::tuple<std::string, std::string>, // HW compatibility
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs it's own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -69,15 +80,21 @@ struct TRTEngine : torch::CustomClassHolder {
void enable_profiling();
void disable_profiling();
std::string get_engine_layer_info();

void dump_engine_layer_info_to_file(const std::string& path);
void dump_engine_layer_info();
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

// Serde re-export functionality
FlattenedState __obj_flatten__();
std::vector<std::string> serialize();

// CUDAGraph-Related Functionality
at::cuda::CUDAGraph cudagraph = {};
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
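
FlattenedState and __obj_flatten__ follow the TorchBind flattening protocol: when torch.export traces a custom-class object, it calls __obj_flatten__ to obtain a plain tuple of (name, value) pairs it can reason about instead of an opaque C++ object. A minimal sketch of what the flattened engine state looks like from Python is below; the field ordering mirrors the tuple above, and the access pattern is illustrative rather than the exact code used by the exporter.

# Illustrative only: flatten a TorchBind TRT engine into plain (key, value) pairs.
# Assumes `trt_engine` is a torch.classes.tensorrt.Engine instance.
flattened = trt_engine.__obj_flatten__()
state = dict(flattened)  # e.g. {"version": ..., "name": ..., "serialized_engine": ..., ...}
print(state["name"], state["hardware_compatible"])
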
35 changes: 9 additions & 26 deletions core/runtime/register_jit_hooks.cpp
@@ -7,7 +7,6 @@
namespace torch_tensorrt {
namespace core {
namespace runtime {
namespace {

std::string serialize_bindings(const std::vector<std::string>& bindings) {
std::stringstream ss;
@@ -66,6 +65,7 @@ std::string base64_decode(const std::string& in) {
return out;
}

namespace {
// TODO: Implement a call method
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
// auto input_vec = inputs.vec();
@@ -80,51 +80,30 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
// TODO: .def("run", &TRTEngine::Run)
.def("__str__", &TRTEngine::to_str)
.def("__repr__", &TRTEngine::to_str)
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
.def("enable_profiling", &TRTEngine::enable_profiling)
.def("disable_profiling", &TRTEngine::disable_profiling)
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def_property(
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(self->cuda_engine->serialize());

// Adding device info related meta data to the serialized file
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

std::vector<std::string> serialize_info;
serialize_info.resize(SERIALIZATION_LEN);

serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
serialize_info[NAME_IDX] = self->name;
serialize_info[DEVICE_IDX] = self->device_info.serialize();
serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
LOG_DEBUG("Serialized Target Platform: " << self->target_platform);

return serialize_info;
},
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
TRTEngine::verify_serialization_fmt(serialized_info);
return c10::make_intrusive<TRTEngine>(serialized_info);
});

TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]");
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
@@ -171,6 +150,10 @@ TORCH_LIBRARY(tensorrt, m) {
});
}

TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) {
m.impl("execute_engine", execute_engine);
}

} // namespace
} // namespace runtime
} // namespace core
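
Splitting execute_engine into a schema-only m.def(...) plus a TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, ...) implementation means other dispatch keys can supply their own kernels. In particular, a fake/meta kernel can be registered from Python so torch.export can trace through the op without a GPU or a real engine. The snippet below is a hedged sketch of that pattern using torch.library.register_fake; the known_output_specs list is a made-up placeholder, and the registration actually shipped in this PR may differ in detail.

# Sketch of a fake-tensor kernel for tensorrt::execute_engine (illustrative).
import torch

# Hypothetical: output shapes/dtypes obtained earlier, e.g. via TRTEngine.infer_outputs.
known_output_specs = [((1, 1000), torch.float32)]

@torch.library.register_fake("tensorrt::execute_engine")
def _(input_tensors, trt_engine):
    # Return empty tensors with the right metadata; no TensorRT call happens here.
    return [torch.empty(shape, dtype=dtype) for shape, dtype in known_output_specs]
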
4 changes: 4 additions & 0 deletions core/runtime/runtime.h
@@ -33,6 +33,10 @@ typedef enum {
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

std::string base64_encode(const std::string& in);
std::string base64_decode(const std::string& in);
std::string serialize_bindings(const std::vector<std::string>& bindings);

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device = RTDevice(),
34 changes: 24 additions & 10 deletions py/torch_tensorrt/_compile.py
@@ -502,19 +502,24 @@ def save(
"Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
)
else:
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
)
torch.jit.save(module, file_path)
elif module_type == _ModuleType.ep:
if output_format == "torchscript":
raise ValueError(
"Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
)
else:
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
torch.export.save(module, file_path)
elif module_type == _ModuleType.fx:
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
)

# The module type is torch.fx.GraphModule
if output_format == "torchscript":
module_ts = torch.jit.trace(
@@ -525,13 +530,22 @@ def save(
if not retrace:
from torch_tensorrt.dynamo._exporter import export

exp_program = export(module, arg_inputs, kwarg_inputs)
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
)
exp_program = export(module)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
)
torch.export.save(exp_program, file_path)
exp_program = torch.export.export(
module,
tuple(arg_inputs),
kwargs=kwarg_inputs,
strict=False,
)
torch.export.save(exp_program, file_path)
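
With these changes, saving a Torch-TRT compiled GraphModule in the exported_program format no longer requires example inputs when retrace=False: the re-export path (torch_tensorrt.dynamo._exporter.export) rebuilds an ExportedProgram directly from the compiled graph. A small usage sketch follows; the model, shapes, and file names are placeholders, not code from this PR.

import torch
import torch_tensorrt

model = MyModel().eval().cuda()                 # placeholder model
inputs = [torch.randn(1, 3, 224, 224).cuda()]

trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Re-export path: no inputs needed when retrace=False.
torch_tensorrt.save(trt_gm, "trt_model.ep", retrace=False)

# Retrace path still requires example inputs.
torch_tensorrt.save(trt_gm, "trt_model_retraced.ep", arg_inputs=inputs, retrace=True)

# Reload later with torch.export.
reloaded = torch.export.load("trt_model.ep").module()
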
28 changes: 27 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -36,6 +36,7 @@
)
from torch_tensorrt.dynamo.utils import (
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
prepare_inputs,
set_log_level,
@@ -302,7 +303,6 @@ def compile(

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

exported_program = pre_export_lowering(exported_program, settings)
exported_program = exported_program.run_decompositions(
get_decompositions(enable_experimental_decompositions)
@@ -433,6 +433,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
if not settings.use_fast_partitioner:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

submodule_node_dict = {}
for node in partitioned_module.graph.nodes:
if "_run_on_acc" not in node.name:
continue
submodule_node_dict[node.name] = node

# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
@@ -452,6 +458,26 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
)
continue

if name not in submodule_node_dict:
raise ValueError(
f"node_name: {name} does not exist in the submodule node dictionary"
)

# set the submodule metadata back to the parent trt_module_node
metadata_list = get_output_metadata(submodule)
assert len(metadata_list) > 0
metadata_keys = ["val", "tensor_meta"]
for key in metadata_keys:
if key not in submodule_node_dict[name].meta:
meta_val_list = [
metadata[key] for metadata in metadata_list if key in metadata
]
submodule_node_dict[name].meta[key] = meta_val_list
logger.debug(
f"Updated metadata for node: {name} with its corresponding submodule outputs"
)
break

subgraph_data = PerSubgraphData()
subgraph_data.subgraph_name = name
subgraph_data.subgraph_op_count = len(
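
The hunk above copies output metadata ("val" / "tensor_meta") from each accelerated submodule back onto the corresponding _run_on_acc_* node in the parent graph, so the partitioned graph still carries the shape/dtype information that torch.export needs later. A rough sketch of what a helper like get_output_metadata might do is shown below; the actual implementation lives in torch_tensorrt.dynamo.utils and may differ.

from typing import Any, Dict, List
import torch

def get_output_metadata_sketch(gm: torch.fx.GraphModule) -> List[Dict[str, Any]]:
    # Collect the .meta dicts of the nodes feeding the submodule's output node,
    # so the parent graph node can inherit their "val" / "tensor_meta" entries.
    output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
    assert len(output_nodes) == 1
    return [node.meta for node in output_nodes[0].all_input_nodes]
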