Skip to content

Commit a2f37f8

Browse files
authored
Use a constant with clearly-defined type for log2e in fwd_kernel_splitK (#1181)
Summary: Triton 3.2 changed how it interprets constants (triton-lang/triton#4613). These changes make Triton more consistent with PyTorch/NumPy, but they cause some surprising issues with this kernel. Specifically, it appears that log2e is interpreted as float32 in one instance and as float64 in another, which leads to reduced prediction accuracy in some cases. To prevent this, let's make log2e a constant and explicitly define it as float32.
1 parent 9a59df2 commit a2f37f8

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

xformers/ops/fmha/_triton/splitk_kernels.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,13 @@ def _fwd_kernel_splitK(
341341
# scale sm_scale by log_2(e) and use
342342
# 2^x instead of exp in the loop because CSE and LICM
343343
# don't work as expected with `exp` in the loop
344-
qk_scale = sm_scale * 1.44269504
344+
#
345+
# We declare log2e as a constant with a precisely-specified type to guarantee that
346+
# triton will use the exact same value in all instances below, rather than sometimes
347+
# using float32 and sometimes using float64. For more discussion see:
348+
# https://github.com/triton-lang/triton/issues/5466
349+
log2e = tl.full((), 1.44269504, tl.float32)
350+
qk_scale = sm_scale * log2e
345351
# load q: it will stay in SRAM throughout
346352
q: "VAR_ARGS_ARRAY" # noqa: F821
347353
for i in range(len(acc)): # noqa: F821
@@ -468,7 +474,7 @@ def _fwd_kernel_splitK(
468474
additive_bias_block_ptr,
469475
boundary_check=(0, 1) if BOUNDS_CHECKS_N else (0,),
470476
)
471-
qk += loaded_bias.to(tl.float32) * 1.44269504
477+
qk += loaded_bias.to(tl.float32) * log2e
472478
additive_bias_block_ptr = tl.advance(additive_bias_block_ptr, (0, BLOCK_N))
473479

474480
# TODO: This is slow, and only needed at the last iteration.
@@ -548,7 +554,7 @@ def _fwd_kernel_splitK(
548554
lse_dtype = LSE_splitk.dtype.element_ty
549555
tl.store(
550556
LSE_splitk_ptr,
551-
(tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / 1.44269504,
557+
(tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / log2e,
552558
mask=mask,
553559
)
554560

0 commit comments

Comments
 (0)