From 3665dfdfc0a509827ce8340347d0388b17c2117a Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 4 Apr 2024 09:37:46 +0000 Subject: [PATCH] update ternary mm --- kernels/ternary_mm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kernels/ternary_mm.py b/kernels/ternary_mm.py index 4d5f311..cf6b7e7 100644 --- a/kernels/ternary_mm.py +++ b/kernels/ternary_mm.py @@ -136,9 +136,9 @@ def _ternary_mm_kernel( # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers # See above `Pointer Arithmetics` section for details - offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + #offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M) + # offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -162,13 +162,14 @@ def _ternary_mm_kernel( for k in range(0, total_blocks_k): # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. + tl.multiple_of(b_ptrs, [16, 16]) if EVEN_K: # a = tl.load(a_ptrs) a = tl.load(a_block_ptr) b = tl.load(b_ptrs) else: a = tl.load(a_block_ptr, boundary_check=(0,1)) - #a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # Convert B from int to a.dtype, for each bit in B, 0 becomes -1.0, 1 becomes 1.0