包含以下内容:
- softmax_f32_per_token_kernel(per token)
- softmax_f32x4_per_token_kernel(per token)
- safe_softmax_f32_per_token_kernel(per token)
- safe_softmax_f32x4_per_token_kernel(per token)
- safe_softmax_f16_f32_per_token_kernel(per token)
- safe_softmax_f16x2_f32_per_token_kernel(per token)
- safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
- online_safe_softmax_f32_per_token_kernel(per token, online softmax)
- online_safe_softmax_f32x4_pack_per_token_kernel(per token, online softmax)
- PyTorch bindings
# 只测试 Ada 架构;若不指定 TORCH_CUDA_ARCH_LIST,则默认编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 softmax.py
输出:
----------------------------------------------------------------------------------------------------
S=4096, H=256
----------------------------------------------------------------------------------------------------
out_f32(per): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00634432ms
out_f32x4(per): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00403881ms
out_f32(safe): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00937700ms
out_f32(safe+online): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00752211ms
out_f32x4(safe+online): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00413895ms
out_f32x4(safe): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00422478ms
out_f32_th(per): ['0.00916498 ', '0.00728124 ', '0.00437148 '], time:0.00657797ms
----------------------------------------------------------------------------------------------------
out_f16f32(safe): ['0.0091629 ', '0.00728226 ', '0.00437164 '], time:0.00908375ms
out_f16x2f32(safe): ['0.0091629 ', '0.00728226 ', '0.00437164 '], time:0.00526905ms
out_f16x8packf32(safe): ['0.0091629 ', '0.00728226 ', '0.00437164 '], time:0.00419140ms
out_f16_th(per): ['0.0091629 ', '0.00728226 ', '0.00437164 '], time:0.00652790ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=512
----------------------------------------------------------------------------------------------------
out_f32(per): ['0.00044174 ', '0.00643609 ', '0.0022167 '], time:0.01138210ms
out_f32x4(per): ['0.00044174 ', '0.00643608 ', '0.0022167 '], time:0.00517607ms
out_f32(safe): ['0.00044174 ', '0.00643608 ', '0.0022167 '], time:0.01829147ms
out_f32(safe+online): ['0.00044174 ', '0.00643608 ', '0.0022167 '], time:0.01358271ms
out_f32x4(safe+online): ['0.00044174 ', '0.00643609 ', '0.0022167 '], time:0.00545502ms
out_f32x4(safe): ['0.00044174 ', '0.00643608 ', '0.0022167 '], time:0.00577927ms
out_f32_th(per): ['0.00044174 ', '0.00643608 ', '0.0022167 '], time:0.00664234ms
----------------------------------------------------------------------------------------------------
out_f16f32(safe): ['0.00044155 ', '0.00643921 ', '0.00221634 '], time:0.01775265ms
out_f16x2f32(safe): ['0.00044155 ', '0.00643921 ', '0.00221634 '], time:0.00919342ms
out_f16x8packf32(safe): ['0.00044155 ', '0.00643921 ', '0.00221634 '], time:0.00421047ms
out_f16_th(per): ['0.00044155 ', '0.00643921 ', '0.00221634 '], time:0.00655174ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=1024
----------------------------------------------------------------------------------------------------
out_f32(per): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.03188610ms
out_f32x4(per): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.00862360ms
out_f32(safe): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.04860401ms
out_f32(safe+online): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.03682852ms
out_f32x4(safe+online): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.00926733ms
out_f32x4(safe): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.01023054ms
out_f32_th(per): ['0.00067776 ', '0.00096954 ', '0.00083744 '], time:0.01179218ms
----------------------------------------------------------------------------------------------------
out_f16f32(safe): ['0.00067759 ', '0.00096989 ', '0.00083733 '], time:0.04744530ms
out_f16x2f32(safe): ['0.00067759 ', '0.00096989 ', '0.00083733 '], time:0.01802206ms
out_f16x8packf32(safe): ['0.00067759 ', '0.00096989 ', '0.00083733 '], time:0.00607967ms
out_f16_th(per): ['0.00067759 ', '0.00096989 ', '0.00083733 '], time:0.01050234ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=2048
----------------------------------------------------------------------------------------------------
out_f32x4(per): ['0.00034175 ', '0.00010398 ', '0.00071711 '], time:0.01597404ms
out_f32x4(safe): ['0.00034175 ', '0.00010398 ', '0.00071711 '], time:0.02072334ms
out_f32x4(safe+online): ['0.00034175 ', '0.00010398 ', '0.00071711 '], time:0.01784563ms
out_f32_th(per): ['0.00034175 ', '0.00010398 ', '0.00071711 '], time:0.06695986ms
----------------------------------------------------------------------------------------------------
out_f16x2f32(safe): ['0.00034165 ', '0.00010395 ', '0.00071716 '], time:0.04815578ms
out_f16x8packf32(safe): ['0.00034165 ', '0.00010395 ', '0.00071716 '], time:0.01078129ms
out_f16_th(per): ['0.00034165 ', '0.00010395 ', '0.00071716 '], time:0.07112741ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=4096
----------------------------------------------------------------------------------------------------
out_f32x4(per): ['7.791e-05 ', '0.00012107 ', '0.00016793 '], time:0.18604755ms
out_f32x4(safe): ['7.791e-05 ', '0.00012107 ', '0.00016793 '], time:0.18649578ms
out_f32x4(safe+online): ['7.791e-05 ', '0.00012107 ', '0.00016793 '], time:0.18569946ms
out_f32_th(per): ['7.791e-05 ', '0.00012107 ', '0.00016793 '], time:0.18718004ms
----------------------------------------------------------------------------------------------------
out_f16x8packf32(safe): ['7.79e-05 ', '0.00012106 ', '0.00016797 '], time:0.02230883ms
out_f16_th(per): ['7.79e-05 ', '0.00012106 ', '0.00016797 '], time:0.08208990ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=8192
----------------------------------------------------------------------------------------------------
out_f16x8packf32(safe): ['6.109e-05 ', '0.00015938 ', '0.00022686 '], time:0.18989086ms
out_f16_th(per): ['6.109e-05 ', '0.00015938 ', '0.00022686 '], time:0.19092798ms
----------------------------------------------------------------------------------------------------
S=8192, H=8192
----------------------------------------------------------------------------------------------------
out_f16x8packf32(safe): ['0.00012851 ', '0.00010681 ', '6.098e-05 '], time:0.40277004ms
out_f16_th(per): ['0.00012851 ', '0.00010681 ', '6.098e-05 '], time:0.40700197ms
----------------------------------------------------------------------------------------------------