diff --git a/models/attention.py b/models/attention.py
index 8e537c6f3680..fb5f6f48b324 100644
--- a/models/attention.py
+++ b/models/attention.py
@@ -71,6 +71,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,9 +143,8 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
-        )
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
 
         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
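
The net effect of the second hunk is that the PyTorch 2.0 path can now be switched off via `self._use_2_0_attn` instead of being tied only to `hasattr(F, "scaled_dot_product_attention")` and the xformers flag. The rest of `forward()` is outside this hunk, so the sketch below is only an assumption about how a `use_torch_2_0_attn`-style switch is typically consumed: the fused `F.scaled_dot_product_attention` kernel when the flag is set (heads kept as a separate dimension), and a manual `baddbmm`/softmax path otherwise (heads merged into the batch dimension, matching `merge_head_and_batch=not use_torch_2_0_attn`). The `attend` helper name and the shapes in the comments are illustrative, not taken from this patch.

```python
import math

import torch
import torch.nn.functional as F


def attend(query_proj, key_proj, value_proj, scale, use_torch_2_0_attn):
    if use_torch_2_0_attn:
        # Projections shaped (batch, heads, seq_len, head_dim); SDPA applies
        # the 1/sqrt(head_dim) scaling internally, matching `scale` above.
        return F.scaled_dot_product_attention(
            query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
        )

    # Projections shaped (batch * heads, seq_len, head_dim): compute scores
    # explicitly, then softmax and weight the values.
    attention_scores = torch.baddbmm(
        torch.empty(
            query_proj.shape[0],
            query_proj.shape[1],
            key_proj.shape[1],
            dtype=query_proj.dtype,
            device=query_proj.device,
        ),
        query_proj,
        key_proj.transpose(-1, -2),
        beta=0,       # ignore the uninitialized input tensor
        alpha=scale,  # scale applied during the matmul
    )
    attention_probs = attention_scores.softmax(dim=-1)
    return torch.bmm(attention_probs, value_proj)


# Example: manual path with heads merged into the batch dimension.
q = k = v = torch.randn(2 * 4, 16, 32)
out = attend(q, k, v, scale=1 / math.sqrt(32), use_torch_2_0_attn=False)
print(out.shape)  # torch.Size([8, 16, 32])
```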