Allow disabling torch 2_0 attention (huggingface#3273)
* Allow disabling torch 2_0 attention

* make style

* Update src/diffusers/models/attention.py
patrickvonplaten authored Apr 28, 2023
1 parent 83c4ce7 commit b952672
Showing 1 changed file with 3 additions and 3 deletions.
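
The change makes the PyTorch 2.0 fast path opt-out rather than unconditional. A condensed sketch of the resulting gate, equivalent to the two lines added in forward() in the diff below; the helper name is ours, not part of the commit:

    import torch.nn.functional as F

    def _wants_torch_2_0_attn(attn_block) -> bool:
        # Hypothetical helper restating the gate added in this commit:
        # the fast path runs only when the new instance flag is True,
        # xformers memory-efficient attention is off, and the installed
        # torch (>= 2.0) exposes scaled_dot_product_attention.
        return (
            attn_block._use_2_0_attn
            and not attn_block._use_memory_efficient_attention_xformers
            and hasattr(F, "scaled_dot_product_attention")
        )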
6 changes: 3 additions & 3 deletions src/diffusers/models/attention.py
@@ -71,6 +71,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ def forward(self, hidden_states):

         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
 
         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
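
Since the commit wires the new flag into forward() but adds no public setter, here is a minimal sketch of how it could be toggled from user code, assuming a model whose attention blocks carry the _use_2_0_attn attribute added above; the helper name is hypothetical:

    import torch.nn as nn

    def disable_torch_2_0_attention(model: nn.Module) -> None:
        # Hypothetical helper: walk all submodules and clear the private
        # flag added in this commit. With the flag False, forward() computes
        # use_torch_2_0_attn = False and skips F.scaled_dot_product_attention.
        for module in model.modules():
            if hasattr(module, "_use_2_0_attn"):
                module._use_2_0_attn = False

With the flag cleared, reshape_heads_to_batch_dim is called with merge_head_and_batch=True, so the block falls back to the pre-2.0 attention path even on PyTorch 2.x.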
