Skip to content

Commit

Permalink
Correct mma_sp_m16n8k16 implementation
Browse files: browse the repository at this point in the history
  • Loading branch information
eschnett committed Jan 16, 2024
1 parent 15eed5a commit 119f40b
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/IndexSpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1988,7 +1988,7 @@ CUDA.@device_override function mma_m16n8k16(A::NTuple{4,Float16x2}, B::NTuple{2,
return (Float16x2(D[1] % UInt32), Float16x2(D[2] % UInt32))
end

"""
    mma_sp_m16n8k16(A::NTuple{4,Float16x2}, B::NTuple{2,Float16x2}, C::NTuple{2,Float16x2})

Fallback method: ignore `A` and `B` and hand back the accumulator `C` untouched.

NOTE(review): presumably a placeholder for execution outside a GPU kernel — confirm
against the callers.
"""
function mma_sp_m16n8k16(
    A::NTuple{4,Float16x2}, B::NTuple{2,Float16x2}, C::NTuple{2,Float16x2}
)
    return C
end
"""
    mma_sp_m16n8k16(A::NTuple{2,Float16x2}, B::NTuple{2,Float16x2},
                    C::NTuple{2,Float16x2}, e::Int2x16, f::Integer)

Fallback method: ignore `A`, `B`, `e` and `f` and hand back the accumulator `C`
untouched.

A `CUDA.@device_override` method with this same signature is defined right after
this one; this definition is the non-overridden counterpart.
NOTE(review): presumably only reached outside device code — confirm.
"""
function mma_sp_m16n8k16(
    A::NTuple{2,Float16x2},
    B::NTuple{2,Float16x2},
    C::NTuple{2,Float16x2},
    e::Int2x16,
    f::Integer,
)
    return C
end
CUDA.@device_override function mma_sp_m16n8k16(
A::NTuple{2,Float16x2}, B::NTuple{2,Float16x2}, C::NTuple{2,Float16x2}, e::Int2x16, f::Integer
)
Expand All @@ -2005,8 +2005,6 @@ CUDA.@device_override function mma_sp_m16n8k16(
NTuple{8,Int32},
A[1].val % Int32,
A[2].val % Int32,
A[3].val % Int32,
A[4].val % Int32,
B[1].val % Int32,
B[2].val % Int32,
C[1].val % Int32,
Expand Down Expand Up @@ -2698,8 +2696,8 @@ function mma_sp_row_col_m16n8k16_f16!(
e6
elseif threadgroup == 7i32
e7
else
@assert false
# else
# @assert false
end
IndexSpaces.mma_sp_m16n8k16(($A0_name, $A1_name), ($B0_name, $B1_name), ($C0_name, $C1_name), e, 0i32)
end
Expand Down

0 comments on commit 119f40b

Please sign in to comment.