Commit

Cosmetics
HolyWu committed Nov 24, 2024
1 parent a94f1e1 commit c3d770c
Showing 1 changed file with 46 additions and 56 deletions.
102 changes: 46 additions & 56 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -1689,12 +1689,12 @@ def forward(self, query, key, value, attn_mask=None):
         unexpected_ops = {torch.ops.aten.scaled_dot_product_attention.default}
 
         inputs = [
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
         ]
         if attn:
-            inputs += [torch.rand(2, 4, 8, 8, dtype=torch.half, device="cuda")]
+            inputs += [torch.rand(1, 3, 8, 8, dtype=torch.half, device="cuda")]
 
         exported_program = torch.export.export(TestModule(), tuple(inputs))
         fx_graph = exported_program.module()
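For orientation (the diff shows only the harness): the hunk header's enclosing forward takes query, key, value, and an optional attn_mask, and the test asserts that torch.ops.aten.scaled_dot_product_attention.default does not survive lowering. A minimal sketch of such a module — the repository's actual TestModule body may differ — would be:

    import torch
    import torch.nn.functional as F

    class TestModule(torch.nn.Module):
        # Sketch only; the real TestModule lives in test_decompositions.py.
        def forward(self, query, key, value, attn_mask=None):
            # The decomposition under test should lower this call, so the
            # aten.scaled_dot_product_attention op must not appear in the
            # compiled graph (hence the unexpected_ops set above).
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)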
@@ -1714,11 +1714,9 @@ def forward(self, query, key, value, attn_mask=None):
         trt_model = torch_tensorrt.dynamo.compile(
             exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
         )
-        trt_output = trt_model(*inputs).detach().cpu()
-        torch_output = fx_graph(*inputs).detach().cpu()
         torch.testing.assert_close(
-            trt_output,
-            torch_output,
+            trt_model(*inputs),
+            fx_graph(*inputs),
             rtol=RTOL,
             atol=ATOL,
             msg="Scaled_dot_product_attention TRT outputs don't match with the original model.",
@@ -1747,19 +1745,19 @@ def forward(self, query, key, value, attn_mask=None):
         )
 
         example_inputs = [
-            torch.zeros(2, 2, 16, 16, dtype=torch.half, device="cuda"),
-            torch.zeros(2, 2, 16, 16, dtype=torch.half, device="cuda"),
-            torch.zeros(2, 2, 16, 16, dtype=torch.half, device="cuda"),
+            torch.zeros(2, 2, 16, 32, dtype=torch.half, device="cuda"),
+            torch.zeros(2, 2, 16, 32, dtype=torch.half, device="cuda"),
+            torch.zeros(2, 2, 16, 32, dtype=torch.half, device="cuda"),
         ]
         if attn:
             example_inputs += [
                 torch.zeros(2, 2, 16, 16, dtype=torch.half, device="cuda")
             ]
 
-        dim0 = torch.export.Dim("dim0", min=2, max=8)
+        dim0 = torch.export.Dim("dim0", min=2, max=4)
         dim1 = torch.export.Dim("dim1", min=2, max=8)
         _dim2 = torch.export.Dim("dim2", min=16 // 8, max=64 // 8)
-        _dim3 = torch.export.Dim("dim3", min=16 // 8, max=64 // 8)
+        _dim3 = torch.export.Dim("dim3", min=32 // 8, max=128 // 8)
         dim2 = _dim2 * 8
         dim3 = _dim3 * 8
 
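The derived dims are the subtle part of this hunk: torch.export.Dim ranges over consecutive integers, so to express "any multiple of 8 between 32 and 128" the test declares a base dim over 32 // 8 .. 128 // 8 and multiplies it by 8. A standalone sketch of the same pattern, with a hypothetical toy module and illustrative shapes:

    import torch

    _dim3 = torch.export.Dim("dim3", min=32 // 8, max=128 // 8)  # base dim: 4..16
    dim3 = _dim3 * 8  # derived dim: 32, 40, ..., 128

    class Toy(torch.nn.Module):  # hypothetical module for illustration
        def forward(self, x):
            return x @ x.transpose(-2, -1)

    x = torch.rand(2, 2, 16, 32)
    ep = torch.export.export(Toy(), (x,), dynamic_shapes={"x": {3: dim3}})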
@@ -1769,7 +1767,7 @@ def forward(self, query, key, value, attn_mask=None):
             "value": {0: dim0, 1: dim1, 2: dim2, 3: dim3},
         }
         if attn:
-            dynamic_shapes["attn_mask"] = {0: dim0, 1: dim1, 2: dim2, 3: dim3}
+            dynamic_shapes["attn_mask"] = {0: dim0, 1: dim1, 2: dim2, 3: dim2}
 
         exported_program = torch.export.export(
             TestModule(), tuple(example_inputs), dynamic_shapes=dynamic_shapes
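The attn_mask line is the substantive fix in this hunk: scaled_dot_product_attention broadcasts the mask against attention weights of shape (batch, heads, L, S), i.e. query length by key length. Query, key, and value all use dim2 as their sequence axis in this test, so the mask's last axis must be dim2 too, not the head dim dim3. A quick shape check with illustrative static values:

    import torch
    import torch.nn.functional as F

    q = k = v = torch.zeros(2, 2, 16, 32)  # (batch, heads, seq, head_dim)
    mask = torch.zeros(2, 2, 16, 16)       # (batch, heads, L, S): last axis is seq, not head_dim
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([2, 2, 16, 32])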
@@ -1778,30 +1776,30 @@
 
         inputs = [
             torch_tensorrt.Input(
-                min_shape=(2, 2, 16, 16),
-                opt_shape=(4, 4, 32, 32),
-                max_shape=(8, 8, 64, 64),
+                min_shape=(2, 2, 16, 32),
+                opt_shape=(3, 4, 32, 64),
+                max_shape=(4, 8, 64, 128),
                 dtype=torch.half,
             ),
             torch_tensorrt.Input(
-                min_shape=(2, 2, 16, 16),
-                opt_shape=(4, 4, 32, 32),
-                max_shape=(8, 8, 64, 64),
+                min_shape=(2, 2, 16, 32),
+                opt_shape=(3, 4, 32, 64),
+                max_shape=(4, 8, 64, 128),
                 dtype=torch.half,
             ),
             torch_tensorrt.Input(
-                min_shape=(2, 2, 16, 16),
-                opt_shape=(4, 4, 32, 32),
-                max_shape=(8, 8, 64, 64),
+                min_shape=(2, 2, 16, 32),
+                opt_shape=(3, 4, 32, 64),
+                max_shape=(4, 8, 64, 128),
                 dtype=torch.half,
             ),
         ]
         if attn:
             inputs += [
                 torch_tensorrt.Input(
                     min_shape=(2, 2, 16, 16),
-                    opt_shape=(4, 4, 32, 32),
-                    max_shape=(8, 8, 64, 64),
+                    opt_shape=(3, 4, 32, 32),
+                    max_shape=(4, 8, 64, 64),
                     dtype=torch.half,
                 )
             ]
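Each torch_tensorrt.Input above is a TensorRT optimization profile; its min/opt/max shapes must stay inside what the export-time Dims permit (here batch 2..4, heads 2..8, sequence 16..64, head dim 32..128). A reduced sketch of the same flow for a single-input, hypothetical module, reusing only the calls that appear in this diff:

    import torch
    import torch_tensorrt

    class Toy(torch.nn.Module):  # hypothetical module for illustration
        def forward(self, x):
            return x * 2

    dim0 = torch.export.Dim("dim0", min=2, max=4)
    example = torch.rand(2, 2, 16, 32, dtype=torch.half, device="cuda")
    ep = torch.export.export(Toy(), (example,), dynamic_shapes={"x": {0: dim0}})

    trt_model = torch_tensorrt.dynamo.compile(
        ep,
        [
            torch_tensorrt.Input(
                min_shape=(2, 2, 16, 32),
                opt_shape=(3, 2, 16, 32),  # the shape TensorRT tunes tactics for
                max_shape=(4, 2, 16, 32),
                dtype=torch.half,
            )
        ],
        enabled_precisions={torch.half},
        min_block_size=1,
    )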
@@ -1812,18 +1810,16 @@
 
         # Validate that the results between Torch and Torch-TRT are similar
         inputs = [
-            torch.rand(8, 8, 64, 64, dtype=torch.half, device="cuda"),
-            torch.rand(8, 8, 64, 64, dtype=torch.half, device="cuda"),
-            torch.rand(8, 8, 64, 64, dtype=torch.half, device="cuda"),
+            torch.rand(4, 8, 64, 128, dtype=torch.half, device="cuda"),
+            torch.rand(4, 8, 64, 128, dtype=torch.half, device="cuda"),
+            torch.rand(4, 8, 64, 128, dtype=torch.half, device="cuda"),
         ]
         if attn:
-            inputs += [torch.rand(8, 8, 64, 64, dtype=torch.half, device="cuda")]
+            inputs += [torch.rand(4, 8, 64, 64, dtype=torch.half, device="cuda")]
 
-        trt_output = trt_model(*inputs).detach().cpu()
-        torch_output = fx_graph(*inputs).detach().cpu()
         torch.testing.assert_close(
-            trt_output,
-            torch_output,
+            trt_model(*inputs),
+            fx_graph(*inputs),
             rtol=RTOL,
             atol=ATOL,
             msg="Scaled_dot_product_attention_with_dynamic_shape TRT outputs don't match with the original model.",
@@ -1855,9 +1851,9 @@ def forward(self, query, key, value):
         unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default}
 
         inputs = [
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
         ]
 
         exported_program = torch.export.export(TestModule(), tuple(inputs))
@@ -1878,11 +1874,9 @@ def forward(self, query, key, value):
         trt_model = torch_tensorrt.dynamo.compile(
             exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
         )
-        trt_output = trt_model(*inputs).detach().cpu()
-        torch_output = fx_graph(*inputs).detach().cpu()
         torch.testing.assert_close(
-            trt_output,
-            torch_output,
+            trt_model(*inputs),
+            fx_graph(*inputs),
             rtol=RTOL,
             atol=ATOL,
             msg="Scaled_dot_product_flash_attention TRT outputs don't match with the original model.",
@@ -1916,12 +1910,12 @@ def forward(self, query, key, value, attn_bias=None):
         }
 
         inputs = [
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
         ]
         if attn:
-            inputs += [torch.rand(2, 4, 8, 8, dtype=torch.half, device="cuda")]
+            inputs += [torch.rand(1, 3, 8, 8, dtype=torch.half, device="cuda")]
 
         exported_program = torch.export.export(TestModule(), tuple(inputs))
         fx_graph = exported_program.module()
@@ -1941,11 +1935,9 @@ def forward(self, query, key, value, attn_bias=None):
         trt_model = torch_tensorrt.dynamo.compile(
             exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
         )
-        trt_output = trt_model(*inputs).detach().cpu()
-        torch_output = fx_graph(*inputs).detach().cpu()
         torch.testing.assert_close(
-            trt_output,
-            torch_output,
+            trt_model(*inputs),
+            fx_graph(*inputs),
             rtol=RTOL,
             atol=ATOL,
             msg="Scaled_dot_product_efficient_attention TRT outputs don't match with the original model.",
@@ -1979,12 +1971,12 @@ def forward(self, query, key, value, attn_bias=None):
         unexpected_ops = {torch.ops.aten._scaled_dot_product_cudnn_attention.default}
 
         inputs = [
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
-            torch.rand(2, 4, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
+            torch.rand(1, 3, 8, 16, dtype=torch.half, device="cuda"),
         ]
         if attn:
-            inputs += [torch.rand(2, 4, 8, 8, dtype=torch.half, device="cuda")]
+            inputs += [torch.rand(1, 3, 8, 8, dtype=torch.half, device="cuda")]
 
         exported_program = torch.export.export(TestModule(), tuple(inputs))
         fx_graph = exported_program.module()
@@ -2004,11 +1996,9 @@ def forward(self, query, key, value, attn_bias=None):
         trt_model = torch_tensorrt.dynamo.compile(
             exported_program, inputs, enabled_precisions={torch.half}, min_block_size=1
         )
-        trt_output = trt_model(*inputs).detach().cpu()
-        torch_output = fx_graph(*inputs).detach().cpu()
         torch.testing.assert_close(
-            trt_output,
-            torch_output,
+            trt_model(*inputs),
+            fx_graph(*inputs),
             rtol=RTOL,
             atol=ATOL,
             msg="Scaled_dot_product_cudnn_attention TRT outputs don't match with the original model.",