Skip to content

Commit

Permalink
mma_sp: Correct sparsity selector
Browse files: browse the repository at this point in the history
  • Loading branch information
eschnett committed Dec 20, 2023
1 parent da9a280 commit e20fbfb
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions src/IndexSpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2593,7 +2593,6 @@ function mma_sp_row_col_m16n8k16_f16!(
0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1
0 0 1 1 0 0 1 1 0 0 1 1 0 0 1 1
]
@show A_pattern A_pattern_manual
@assert A_pattern == A_pattern_manual
function decode_row_pattern(row::Integer, col::Integer)
@assert col % 4 == 0
Expand All @@ -2615,33 +2614,42 @@ function mma_sp_row_col_m16n8k16_f16!(
return (c1, c2)
end
a_row_patterns = NTuple{2,Int}[decode_row_pattern(row, col) for row in 0:15, col in 0:4:15]
@show a_row_patterns
e = [Int2x16(reinterpret(Int, vec(a_row_patterns[thread+1:8:16,:]))...) for thread in 0:4:31]
@show e
e = [Int2x16(reinterpret(Int, vec(a_row_patterns[(threadgroup + 1):8:16, :]))...) for threadgroup in 0:7]
push!(
emitter.statements,
quote
# thread = IndexSpaces.cuda_threadidx()
# row = 1i32 * (thread ÷ 4i32 % 8i32) # + 8i32 * reg01
# col = 2i32 * (thread ÷ 1i32 % 4i32) # + 1i32 * simd01
# col0 = col + 0i32
# col1 = col + 1i32
# col2 = col + 2i32
# col3 = col + 3i32
# row_spectator = (row >> 2i32) & 1i32
# col_spectator0 = (col0 >> 1i32) & 1i32
# col_spectator1 = (col1 >> 1i32) & 1i32
# col_spectator2 = (col2 >> 1i32) & 1i32
# col_spectator3 = (col3 >> 1i32) & 1i32
# delta = row_spectator == col_spectator
#
#
# col4 = col ÷ 4
#
# e = Int2x16()
($D0_name, $D1_name) = IndexSpaces.mma_ap_m16n8k16(
($A0_name, $A1_name), ($B0_name, $B1_name), ($C0_name, $C1_name), $e, 0i32
)
($D0_name, $D1_name) = let
e0 = $(e[1])
e1 = $(e[2])
e2 = $(e[3])
e3 = $(e[4])
e4 = $(e[5])
e5 = $(e[6])
e6 = $(e[7])
e7 = $(e[8])
thread = IndexSpaces.cuda_threadidx()
threadgroup = thread ÷ 4
e = if threadgroup == 0i32
e0
elseif threadgroup == 1i32
e1
elseif threadgroup == 2i32
e2
elseif threadgroup == 3i32
e3
elseif threadgroup == 4i32
e4
elseif threadgroup == 5i32
e5
elseif threadgroup == 6i32
e6
elseif threadgroup == 7i32
e7
else
@assert false
end
IndexSpaces.mma_ap_m16n8k16(($A0_name, $A1_name), ($B0_name, $B1_name), ($C0_name, $C1_name), e, 0i32)
end
end,
)
end
Expand Down

0 comments on commit e20fbfb

Please sign in to comment.