
Commit f4a68f9
import neuron_attention earlier (#38)
apoorvtintin authored Dec 19, 2024
1 parent 68c6ee9 commit f4a68f9
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions axlearn/common/flash_attention/utils.py
@@ -11,6 +11,7 @@
 from axlearn.common.attention import NEG_INF, MaskFn, causal_mask, softmax_with_biases
 from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention
 from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention
+from axlearn.common.flash_attention.neuron_attention import flash_attention as neuron_flash_attention
 from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention
 from axlearn.common.utils import Tensor

@@ -160,10 +161,6 @@ def jit_attn(query, key, value, bias, segment_ids):
         return jit_attn

     elif backend == "neuron":
-        from axlearn.common.flash_attention.neuron_attention import (
-            flash_attention as neuron_flash_attention,
-        )
-
         # shard_map-decorated function needs to be jitted.
         @jax.jit
        def jit_attn(query, key, value, bias, segment_ids):
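
For context, here is a rough sketch of how the neuron branch of the backend dispatch reads after this change, with the kernel imported at module scope rather than inside the branch. The enclosing function name and the keyword arguments passed to the kernel are illustrative assumptions; only the neuron-branch lines mirror the hunks above.

# Minimal sketch, assuming the shape of the dispatch function; not the actual file contents.
import jax

from axlearn.common.flash_attention.neuron_attention import (
    flash_attention as neuron_flash_attention,  # now a module-level import
)


def flash_attention_implementation(backend: str, *, causal: bool = True, softmax_scale: float = 1.0):
    """Returns a jitted attention callable for the requested backend (sketch only)."""
    if backend == "neuron":
        # shard_map-decorated function needs to be jitted.
        @jax.jit
        def jit_attn(query, key, value, bias, segment_ids):
            del segment_ids  # unused in this sketch
            # Hypothetical call; the real kernel's keyword names may differ.
            return neuron_flash_attention(
                query, key, value, bias, causal=causal, softmax_scale=softmax_scale
            )

        return jit_attn

    raise NotImplementedError(f"Backend not covered in this sketch: {backend}")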
