diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 4ae14adca2..9f5da60be5 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -80,7 +80,7 @@ def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
+        self.fused_attn = use_fused_attn(experimental=True)
 
         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -94,7 +94,7 @@ def forward(self, x):
 
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
+                q.contiguous(), k.contiguous(), v.contiguous(),
                 dropout_p=self.attn_drop.p,
             )
         else:
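
Note for reviewers: `use_fused_attn(experimental=True)` keeps fused SDPA opt-in for this model; in timm the helper only returns True at the experimental fused-attn level (set via the TIMM_FUSED_ATTN environment variable, if I'm reading timm's config helper correctly). The `.contiguous()` calls matter because Visformer builds q/k/v from a 1x1 Conv2d output via reshape + permute, leaving strided views that some fused SDPA kernels reject or mishandle. Below is a minimal standalone sketch reproducing that layout; the shapes are illustrative assumptions, not values from the diff.

import torch
import torch.nn.functional as F

# Reproduce the q/k/v layout from Visformer's Attention.forward (shapes are
# illustrative assumptions): qkv comes from a 1x1 Conv2d, then is reshaped
# and permuted, so the unbound tensors are non-contiguous strided views.
B, num_heads, head_dim, N = 2, 6, 64, 196
qkv = torch.randn(B, 3, num_heads, head_dim, N)
q, k, v = qkv.permute(1, 0, 2, 4, 3).unbind(0)  # each (B, heads, N, head_dim)
assert not q.is_contiguous()  # the permute leaves a strided view

# .contiguous() materializes dense copies, which avoids the fused-kernel
# failures some PyTorch builds hit with non-contiguous SDPA inputs.
out = F.scaled_dot_product_attention(
    q.contiguous(), k.contiguous(), v.contiguous())
print(out.shape)  # torch.Size([2, 6, 196, 64])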