diff --git a/src/IndexSpaces.jl b/src/IndexSpaces.jl index cd4ff76..68b33af 100644 --- a/src/IndexSpaces.jl +++ b/src/IndexSpaces.jl @@ -2593,7 +2593,6 @@ function mma_sp_row_col_m16n8k16_f16!( 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1 ] - @show A_pattern A_pattern_manual @assert A_pattern == A_pattern_manual function decode_row_pattern(row::Integer, col::Integer) @assert col % 4 == 0 @@ -2615,33 +2614,42 @@ function mma_sp_row_col_m16n8k16_f16!( return (c1, c2) end a_row_patterns = NTuple{2,Int}[decode_row_pattern(row, col) for row in 0:15, col in 0:4:15] - @show a_row_patterns - e = [Int2x16(reinterpret(Int, vec(a_row_patterns[thread+1:8:16,:]))...) for thread in 0:4:31] - @show e + e = [Int2x16(reinterpret(Int, vec(a_row_patterns[(threadgroup + 1):8:16, :]))...) for threadgroup in 0:7] push!( emitter.statements, quote - # thread = IndexSpaces.cuda_threadidx() - # row = 1i32 * (thread ÷ 4i32 % 8i32) # + 8i32 * reg01 - # col = 2i32 * (thread ÷ 1i32 % 4i32) # + 1i32 * simd01 - # col0 = col + 0i32 - # col1 = col + 1i32 - # col2 = col + 2i32 - # col3 = col + 3i32 - # row_spectator = (row >> 2i32) & 1i32 - # col_spectator0 = (col0 >> 1i32) & 1i32 - # col_spectator1 = (col1 >> 1i32) & 1i32 - # col_spectator2 = (col2 >> 1i32) & 1i32 - # col_spectator3 = (col3 >> 1i32) & 1i32 - # delta = row_spectator == col_spectator - # - # - # col4 = col ÷ 4 - # - # e = Int2x16() - ($D0_name, $D1_name) = IndexSpaces.mma_ap_m16n8k16( - ($A0_name, $A1_name), ($B0_name, $B1_name), ($C0_name, $C1_name), $e, 0i32 - ) + ($D0_name, $D1_name) = let + e0 = $(e[1]) + e1 = $(e[2]) + e2 = $(e[3]) + e3 = $(e[4]) + e4 = $(e[5]) + e5 = $(e[6]) + e6 = $(e[7]) + e7 = $(e[8]) + thread = IndexSpaces.cuda_threadidx() + threadgroup = thread ÷ 4 + e = if threadgroup == 0i32 + e0 + elseif threadgroup == 1i32 + e1 + elseif threadgroup == 2i32 + e2 + elseif threadgroup == 3i32 + e3 + elseif threadgroup == 4i32 + e4 + elseif threadgroup == 5i32 + e5 + elseif threadgroup == 6i32 + e6 + elseif threadgroup == 7i32 + e7 + else + @assert false + end + IndexSpaces.mma_ap_m16n8k16(($A0_name, $A1_name), ($B0_name, $B1_name), ($C0_name, $C1_name), e, 0i32) + end end, ) end