Commit e32f7e9

cherrypick #3387
1 parent f50c45d commit e32f7e9

13 files changed: +443 -682 lines

examples/distributed_inference/llama3_model.py  (-538 lines)

This file was deleted.

examples/distributed_inference/tensor_parallel_llama3.py  (-70 lines)

This file was deleted.

examples/distributed_inference/tensor_parallel_simple_example.py  (+23 -19)

@@ -2,6 +2,7 @@
 
 import tensorrt as trt
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from tensor_parallel_initialize_dist import initialize_distributed_env
@@ -15,7 +16,6 @@
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_simple_example"
 )
-import tensorrt_llm
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -65,7 +65,6 @@ def forward(self, x):
 inp = torch.rand(20, 10, device="cuda")
 python_result = tp_model(inp)
 
-
 backend = "torch_tensorrt"
 tp_model = torch.compile(
     tp_model,
@@ -75,23 +74,28 @@ def forward(self, x):
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
         "min_block_size": 1,
-        "use_aot_joint_export": False,
+        "use_distributed_mode_trace": True,
     },
-    dynamic=False,
+    dynamic=None,
 )
 
-for i in range(10):
-    # For TP, input needs to be same across all TP ranks.
-    # Setting the random seed is to mimic the behavior of dataloader.
-    torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    start = time.time()
-    output = tp_model(inp)
-    end = time.time()
-    if i == 0:
-        logger.info(f"Compilation time is {end-start}")
-        assert (
-            python_result - output
-        ).std() < 0.01, "Compilation result is not correct."
-    elif _rank == 0:
-        logger.info(f"Inference time is {end-start}")
+try:
+    for i in range(10):
+        # For TP, input needs to be same across all TP ranks.
+        # Setting the random seed is to mimic the behavior of dataloader.
+        torch.manual_seed(i)
+        inp = torch.rand(20, 10, device="cuda")
+        start = time.time()
+        output = tp_model(inp)
+        end = time.time()
+        if i == 0:
+            logger.info(f"Compilation time is {end-start}")
+            assert (
+                python_result - output
+            ).std() < 0.01, "Compilation result is not correct."
+        elif _rank == 0:
+            logger.info(f"Inference time is {end-start}")
+finally:
+    # This cleans up the distributed process group
+    if dist.is_initialized():
+        dist.destroy_process_group()
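
Note on the example change above: the compile option moves from "use_aot_joint_export": False to "use_distributed_mode_trace": True, and the benchmark loop is wrapped in try/finally so the process group is always destroyed. A minimal sketch of the resulting compile call follows; the toy single-GPU model stands in for the sharded tp_model, and only the backend options mirror the diff, the rest is illustrative.

# Sketch only (assumes a CUDA device with torch and torch_tensorrt installed;
# the model is a placeholder, not part of the commit).
import torch
import torch.nn as nn
import torch_tensorrt  # registers the "torch_tensorrt" dynamo backend

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5)).cuda()

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "enabled_precisions": {torch.float32, torch.float16},
        "use_python_runtime": True,
        "min_block_size": 1,
        # New flag from this commit: trace through aot_autograd so graphs that
        # carry distributed tensors (DTensor) are handled correctly.
        "use_distributed_mode_trace": True,
    },
    dynamic=None,
)

out = compiled(torch.rand(20, 10, device="cuda"))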

py/torch_tensorrt/dynamo/_defaults.py  (+1 -1)

@@ -46,9 +46,9 @@
 IMMUTABLE_WEIGHTS = True
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
-USE_AOT_JOINT_EXPORT = True
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
+USE_DISTRIBUTED_MODE_TRACE = False
 
 
 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py  (+3 -3)

@@ -35,7 +35,7 @@
     TILING_OPTIMIZATION_LEVEL,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
-    USE_AOT_JOINT_EXPORT,
+    USE_DISTRIBUTED_MODE_TRACE,
     USE_EXPLICIT_TYPING,
     USE_FAST_PARTITIONER,
     USE_FP32_ACC,
@@ -94,9 +94,9 @@ class CompilationSettings:
         enable_weight_streaming (bool): Enable weight streaming.
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
-        use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -137,9 +137,9 @@ class CompilationSettings:
     immutable_weights: bool = IMMUTABLE_WEIGHTS
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
-    use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
+    use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
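
The two files above rename the setting: USE_AOT_JOINT_EXPORT / use_aot_joint_export is dropped and USE_DISTRIBUTED_MODE_TRACE / use_distributed_mode_trace (default False) takes its place on CompilationSettings. A small sketch of the field as it exists after this commit; direct construction is shown only to make the default visible, normal callers pass the flag through the torch.compile backend options as in the example file.

# Sketch: the new field and its default. CompilationSettings is imported the
# same way backends.py imports it in the diff below.
from torch_tensorrt.dynamo import CompilationSettings

settings = CompilationSettings()
assert settings.use_distributed_mode_trace is False  # USE_DISTRIBUTED_MODE_TRACE

dist_settings = CompilationSettings(use_distributed_mode_trace=True)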

py/torch_tensorrt/dynamo/backend/backends.py  (+35 -37)

@@ -10,11 +10,11 @@
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch.distributed.tensor import DTensor
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
-    modify_reshape_complex_nodes,
     post_lowering,
     remove_detach,
     remove_sym_nodes,
@@ -52,25 +52,39 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    if settings.use_aot_joint_export:
-        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
-    logger.debug("Wrapping the backend with aot_autograd\n")
-    _pretraced_backend_autograd = functools.partial(
-        _pretraced_backend, settings=settings, engine_cache=engine_cache
-    )
-    settings_aot_autograd = {}
-    settings_aot_autograd["decompostions"] = get_decompositions(
-        settings.enable_experimental_decompositions
-    )
-    # This is added since detach lowering leads to alias nodes
-    # Error - View operation returned a tensor that is the same as the input base tensor
-    # torch nop_decompositions in torch/_decomp/decompositions.py
-    if aten.detach in settings_aot_autograd["decompositions"]:
-        del settings_aot_autograd["decompositions"][aten.detach]
-    return aot_autograd(
-        fw_compiler=_pretraced_backend_autograd,
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )(gm, sample_inputs)
+
+    if settings.use_distributed_mode_trace:
+        logger.debug(
+            "Wrapping the backend with aot_autograd for Distributed examples\n"
+        )
+        _pretraced_backend_autograd = functools.partial(
+            _pretraced_backend, settings=settings, engine_cache=engine_cache
+        )
+        settings_aot_autograd = {}
+        settings_aot_autograd["decompositions"] = get_decompositions(
+            settings.enable_experimental_decompositions
+        )
+        # This is added since detach lowering leads to alias nodes
+        # Error - View operation returned a tensor that is the same as the input base tensor
+        # torch nop_decompositions in torch/_decomp/decompositions.py
+        # transpose key deleted since not desirable to lower it to permute
+        to_delete = {
+            key
+            for key in settings_aot_autograd["decompositions"]
+            if "detach" in key._name
+        }
+        for key in to_delete:
+            del settings_aot_autograd["decompositions"][key]
+
+        return aot_autograd(
+            fw_compiler=_pretraced_backend_autograd,
+            decompositions=settings_aot_autograd["decompositions"],
+        )(gm, sample_inputs)
+    if any(isinstance(tensor, DTensor) for tensor in sample_inputs):
+        logger.warning(
+            "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
+        )
+    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
 
 
 def _pretraced_backend(
@@ -110,18 +124,8 @@ def _pretraced_backend(
             # Remove detach nodes
             remove_detach(gm, settings)
 
-            complexInputIndices = []
-            for i, torch_input in enumerate(torch_inputs):
-                if torch_inputs[i].dtype == torch.complex64:
-                    complexInputIndices.append(i)
-                    torch_input_real = torch_inputs[i].real
-                    torch_input_imaginary = torch_inputs[i].imag
-                    torch_inputs[i] = torch.stack(
-                        (torch_input_real, torch_input_imaginary), dim=-1
-                    )
-
             # Invoke AOTAutograd to translate operators to aten
-            if settings.use_aot_joint_export:
+            if not settings.use_distributed_mode_trace:
                 gm = aot_export_joint_simple(
                     gm,
                     sample_inputs,
@@ -137,12 +141,6 @@ def _pretraced_backend(
 
             logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
-            if complexInputIndices:
-                modify_reshape_complex_nodes(gm, complexInputIndices)
-                logger.debug(
-                    "Input graph after modifying complex nodes:\n " + str(gm.graph)
-                )
-
             torchtrt_inputs = prepare_inputs(
                 torch_inputs, disable_memory_format_check=True
            )
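
The backend now branches on the renamed flag: with use_distributed_mode_trace=True it wraps _pretraced_backend in aot_autograd with the detach decompositions filtered out, otherwise it warns when DTensor inputs are seen and falls back to aot_export_joint_simple. Below is a self-contained sketch of just the decomposition-filtering step, using the same get_decompositions helper the diff imports; the surrounding variable names are illustrative.

# Sketch of the detach filtering above: drop every decomposition whose op name
# contains "detach", since lowering detach creates alias/view nodes that break
# the trace. Assumes torch_tensorrt is installed.
from torch_tensorrt.dynamo.lowering import get_decompositions

decompositions = get_decompositions(False)  # experimental decompositions off
to_delete = {key for key in decompositions if "detach" in key._name}
for key in to_delete:
    del decompositions[key]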

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py  (+3 -5)

@@ -3,6 +3,7 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union
 
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_gather(
+        return impl.nccl_ops.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_reduce_scatter(
+        return impl.nccl_ops.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
             name,
             [args[0]],
         )
 
-        breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"

py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py  (+2 -3)

@@ -3,12 +3,11 @@
 from typing import Optional, Tuple, Union
 
 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
 
-import tensorrt as trt
-
 
 # class for AllReduce
 class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
 
-    p_dtype = trt.float16
+    p_dtype = trt.float32
     pf_dtype = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
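
Besides moving the tensorrt import to the top of the file, the reduce-scatter converter now advertises float32 instead of float16 in the plugin's type_id field. For context, the two plugin fields touched by the diff are built roughly as in the sketch below; the real converter assembles additional fields and then creates the plugin from the collection.

# Sketch of the plugin fields shown in the diff above; values are illustrative.
import numpy as np
import tensorrt as trt

group = [0, 1]          # ranks participating in the collective (illustrative)
p_dtype = trt.float32   # was trt.float16 before this commit

pf_group = trt.PluginField(
    "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
pf_dtype = trt.PluginField(
    "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
fields = trt.PluginFieldCollection([pf_group, pf_dtype])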

py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py  (+6 -6)

@@ -49,7 +49,6 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
-            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@ def fuse_distributed_ops(
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                fused_node = gm.graph.create_node(
-                    op="call_function",
-                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                )
+                with gm.graph.inserting_after(wait_tensor_node):
+                    fused_node = gm.graph.create_node(
+                        op="call_function",
+                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                    )
 
             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
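
The reduce-scatter branch now creates its fused node inside gm.graph.inserting_after(wait_tensor_node), matching the all-gather branch, so the new node is placed immediately after the wait_tensor node rather than at the graph's default insertion point. A toy torch.fx example of that pattern follows; it is unrelated to TensorRT, and the add/mul ops merely stand in for the collective and its wait_tensor.

# Toy illustration of the insertion-point pattern used by the pass above.
import operator

import torch
import torch.fx as fx


def f(x):
    y = x + 1      # stands in for the collective op
    return y * 2   # stands in for wait_tensor


gm = fx.symbolic_trace(f)
coll_like = next(
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target is operator.add
)
wait_like = list(coll_like.users)[0]  # the mul node

# Nodes created in this context land right after wait_like, mirroring
# "with gm.graph.inserting_after(wait_tensor_node): ..." in the pass.
with gm.graph.inserting_after(wait_like):
    fused = gm.graph.call_function(torch.relu, args=(coll_like,))

wait_like.replace_all_uses_with(fused)
fused.meta.update(coll_like.meta)
gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)  # the output now flows through the "fused" node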
