diff --git a/src/IndexSpaces.jl b/src/IndexSpaces.jl index b41359d..5bea071 100644 --- a/src/IndexSpaces.jl +++ b/src/IndexSpaces.jl @@ -1988,7 +1988,7 @@ CUDA.@device_override function mma_m16n8k16(A::NTuple{4,Float16x2}, B::NTuple{2, return (Float16x2(D[1] % UInt32), Float16x2(D[2] % UInt32)) end -mma_sp_m16n8k16(A::NTuple{4,Float16x2}, B::NTuple{2,Float16x2}, C::NTuple{2,Float16x2}) = C +mma_sp_m16n8k16(A::NTuple{2,Float16x2}, B::NTuple{2,Float16x2}, C::NTuple{2,Float16x2}, e::Int2x16, f::Integer) = C CUDA.@device_override function mma_sp_m16n8k16( A::NTuple{2,Float16x2}, B::NTuple{2,Float16x2}, C::NTuple{2,Float16x2}, e::Int2x16, f::Integer ) @@ -2005,8 +2005,6 @@ CUDA.@device_override function mma_sp_m16n8k16( NTuple{8,Int32}, A[1].val % Int32, A[2].val % Int32, - A[3].val % Int32, - A[4].val % Int32, B[1].val % Int32, B[2].val % Int32, C[1].val % Int32, @@ -2698,8 +2696,8 @@ function mma_sp_row_col_m16n8k16_f16!( e6 elseif threadgroup == 7i32 e7 - else - @assert false + # else + # @assert false end IndexSpaces.mma_sp_m16n8k16(($A0_name, $A1_name), ($B0_name, $B1_name), ($C0_name, $C1_name), e, 0i32) end