F.sdpa for visformer fails w/o contiguous on qkv, make experimental
rwightman committed May 11, 2023
commit 3eaf729 (1 parent: cf1884b)
Showing 1 changed file with 2 additions and 2 deletions.
timm/models/visformer.py: 2 additions & 2 deletions
@@ -80,7 +80,7 @@ def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
+        self.fused_attn = use_fused_attn(experimental=True)
 
         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -94,7 +94,7 @@ def forward(self, x):
 
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
+                q.contiguous(), k.contiguous(), v.contiguous(),
                 dropout_p=self.attn_drop.p,
             )
         else:
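Note: Visformer builds its q/k/v from the 1x1 qkv conv above; splitting that output into heads (via a reshape/permute in the model's forward) leaves q, k, and v as non-contiguous views, which some fused SDPA kernels fail on. Below is a minimal sketch of that failure mode, with illustrative shapes and a reshape/permute split that mirrors (but is not copied from) the timm source. The use_fused_attn(experimental=True) change separately gates this fused path behind timm's experimental fused-attention setting, so it is opt-in rather than on by default.

# Minimal sketch (illustrative, not the exact timm source) of why q/k/v
# reach F.scaled_dot_product_attention as non-contiguous tensors.
import torch
import torch.nn.functional as F

B, C, H, W = 2, 192, 14, 14            # assumed example sizes
num_heads, head_dim = 6, 32

# 1x1 conv projects to 3 * num_heads * head_dim channels in NCHW layout,
# mirroring the self.qkv conv in the diff above.
qkv = torch.nn.Conv2d(C, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
x = qkv(torch.randn(B, C, H, W))

# Splitting into q, k, v with reshape + permute yields non-contiguous views.
x = x.reshape(B, 3, num_heads, head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x.unbind(0)                  # each: (B, num_heads, H*W, head_dim)
print(q.is_contiguous())               # False

# Without .contiguous(), the fused kernel can fail on this layout; calling
# it on each tensor, as the commit does, produces a supported layout.
out = F.scaled_dot_product_attention(
    q.contiguous(), k.contiguous(), v.contiguous(),
)
print(out.shape)                       # torch.Size([2, 6, 196, 32])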
