Skip to content

Latest commit

 

History

History

softmax

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

Softmax

0x00 说明

包含以下内容:

  • softmax_f32_per_token_kernel(per token)
  • softmax_f32x4_per_token_kernel(per token)
  • safe_softmax_f32_per_token_kernel(per token)
  • safe_softmax_f32x4_per_token_kernel(per token)
  • safe_softmax_f16_f32_per_token_kernel(per token)
  • safe_softmax_f16x2_f32_per_token_kernel(per token)
  • safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
  • online_safe_softmax_f32_per_token_kernel(per token, online softmax)
  • online_safe_softmax_f32x4_pack_per_token_kernel(per token, online softmax)
  • PyTorch bindings

测试

# 只测试 Ada 架构;若不指定,则默认编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada 
python3 softmax.py

输出:

----------------------------------------------------------------------------------------------------
                                             S=4096, H=256
----------------------------------------------------------------------------------------------------
            out_f32(per): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00634432ms
          out_f32x4(per): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00403881ms
           out_f32(safe): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00937700ms
    out_f32(safe+online): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00752211ms
  out_f32x4(safe+online): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00413895ms
         out_f32x4(safe): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00422478ms
         out_f32_th(per): ['0.00916498  ', '0.00728124  ', '0.00437148  '], time:0.00657797ms
----------------------------------------------------------------------------------------------------
        out_f16f32(safe): ['0.0091629   ', '0.00728226  ', '0.00437164  '], time:0.00908375ms
      out_f16x2f32(safe): ['0.0091629   ', '0.00728226  ', '0.00437164  '], time:0.00526905ms
  out_f16x8packf32(safe): ['0.0091629   ', '0.00728226  ', '0.00437164  '], time:0.00419140ms
         out_f16_th(per): ['0.0091629   ', '0.00728226  ', '0.00437164  '], time:0.00652790ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=512
----------------------------------------------------------------------------------------------------
            out_f32(per): ['0.00044174  ', '0.00643609  ', '0.0022167   '], time:0.01138210ms
          out_f32x4(per): ['0.00044174  ', '0.00643608  ', '0.0022167   '], time:0.00517607ms
           out_f32(safe): ['0.00044174  ', '0.00643608  ', '0.0022167   '], time:0.01829147ms
    out_f32(safe+online): ['0.00044174  ', '0.00643608  ', '0.0022167   '], time:0.01358271ms
  out_f32x4(safe+online): ['0.00044174  ', '0.00643609  ', '0.0022167   '], time:0.00545502ms
         out_f32x4(safe): ['0.00044174  ', '0.00643608  ', '0.0022167   '], time:0.00577927ms
         out_f32_th(per): ['0.00044174  ', '0.00643608  ', '0.0022167   '], time:0.00664234ms
----------------------------------------------------------------------------------------------------
        out_f16f32(safe): ['0.00044155  ', '0.00643921  ', '0.00221634  '], time:0.01775265ms
      out_f16x2f32(safe): ['0.00044155  ', '0.00643921  ', '0.00221634  '], time:0.00919342ms
  out_f16x8packf32(safe): ['0.00044155  ', '0.00643921  ', '0.00221634  '], time:0.00421047ms
         out_f16_th(per): ['0.00044155  ', '0.00643921  ', '0.00221634  '], time:0.00655174ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=1024
----------------------------------------------------------------------------------------------------
            out_f32(per): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.03188610ms
          out_f32x4(per): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.00862360ms
           out_f32(safe): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.04860401ms
    out_f32(safe+online): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.03682852ms
  out_f32x4(safe+online): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.00926733ms
         out_f32x4(safe): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.01023054ms
         out_f32_th(per): ['0.00067776  ', '0.00096954  ', '0.00083744  '], time:0.01179218ms
----------------------------------------------------------------------------------------------------
        out_f16f32(safe): ['0.00067759  ', '0.00096989  ', '0.00083733  '], time:0.04744530ms
      out_f16x2f32(safe): ['0.00067759  ', '0.00096989  ', '0.00083733  '], time:0.01802206ms
  out_f16x8packf32(safe): ['0.00067759  ', '0.00096989  ', '0.00083733  '], time:0.00607967ms
         out_f16_th(per): ['0.00067759  ', '0.00096989  ', '0.00083733  '], time:0.01050234ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=2048
----------------------------------------------------------------------------------------------------
          out_f32x4(per): ['0.00034175  ', '0.00010398  ', '0.00071711  '], time:0.01597404ms
         out_f32x4(safe): ['0.00034175  ', '0.00010398  ', '0.00071711  '], time:0.02072334ms
  out_f32x4(safe+online): ['0.00034175  ', '0.00010398  ', '0.00071711  '], time:0.01784563ms
         out_f32_th(per): ['0.00034175  ', '0.00010398  ', '0.00071711  '], time:0.06695986ms
----------------------------------------------------------------------------------------------------
      out_f16x2f32(safe): ['0.00034165  ', '0.00010395  ', '0.00071716  '], time:0.04815578ms
  out_f16x8packf32(safe): ['0.00034165  ', '0.00010395  ', '0.00071716  '], time:0.01078129ms
         out_f16_th(per): ['0.00034165  ', '0.00010395  ', '0.00071716  '], time:0.07112741ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=4096
----------------------------------------------------------------------------------------------------
          out_f32x4(per): ['7.791e-05   ', '0.00012107  ', '0.00016793  '], time:0.18604755ms
         out_f32x4(safe): ['7.791e-05   ', '0.00012107  ', '0.00016793  '], time:0.18649578ms
  out_f32x4(safe+online): ['7.791e-05   ', '0.00012107  ', '0.00016793  '], time:0.18569946ms
         out_f32_th(per): ['7.791e-05   ', '0.00012107  ', '0.00016793  '], time:0.18718004ms
----------------------------------------------------------------------------------------------------
  out_f16x8packf32(safe): ['7.79e-05    ', '0.00012106  ', '0.00016797  '], time:0.02230883ms
         out_f16_th(per): ['7.79e-05    ', '0.00012106  ', '0.00016797  '], time:0.08208990ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=8192
----------------------------------------------------------------------------------------------------
  out_f16x8packf32(safe): ['6.109e-05   ', '0.00015938  ', '0.00022686  '], time:0.18989086ms
         out_f16_th(per): ['6.109e-05   ', '0.00015938  ', '0.00022686  '], time:0.19092798ms
----------------------------------------------------------------------------------------------------
                                             S=8192, H=8192
----------------------------------------------------------------------------------------------------
  out_f16x8packf32(safe): ['0.00012851  ', '0.00010681  ', '6.098e-05   '], time:0.40277004ms
         out_f16_th(per): ['0.00012851  ', '0.00010681  ', '6.098e-05   '], time:0.40700197ms
----------------------------------------------------------------------------------------------------