Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace scaled_dot_product_attention lowering pass with decomposition #3296

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2777,36 +2777,6 @@ def aten_ops_max_pool(
)


def attention_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
supports_dynamic_shapes=True,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx,
target,
SourceIR.TORCHTRT_LOWERED,
name,
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
activation,
addmm,
arange,
attention,
cast,
cat,
condition,
Expand Down
165 changes: 0 additions & 165 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py

This file was deleted.

128 changes: 127 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch._decomp import register_decomposition
Expand Down Expand Up @@ -420,6 +420,132 @@ def instance_norm_decomposition(
)


@register_torch_trt_decomposition(
aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value


@register_torch_trt_decomposition(
aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_flash_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, None, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_efficient_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None


@register_torch_trt_decomposition(
aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
)
def scaled_dot_product_cudnn_attention_decomposition(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor],
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.SymInt,
torch.SymInt,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
attn = scaled_dot_product_attention_decomposition(
query, key, value, attn_bias, dropout_p, is_causal, scale=scale
)
return attn, None, None, None, 0, 0, None, None, None


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
Expand All @@ -23,7 +22,6 @@
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
lower_scaled_dot_product_attention,
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
Expand Down
Loading
Loading