Skip to content

Commit

Permalink
update ternary mm
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Apr 4, 2024
1 parent c077ba6 commit 3665dfd
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions kernels/ternary_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def _ternary_mm_kernel(
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetics` section for details
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
#offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
# offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)

Expand All @@ -162,13 +162,14 @@ def _ternary_mm_kernel(
for k in range(0, total_blocks_k):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
tl.multiple_of(b_ptrs, [16, 16])
if EVEN_K:
# a = tl.load(a_ptrs)
a = tl.load(a_block_ptr)
b = tl.load(b_ptrs)
else:
a = tl.load(a_block_ptr, boundary_check=(0,1))
#a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

# Convert B from int to a.dtype: for each bit in B, 0 becomes -1.0 and 1 becomes 1.0
Expand Down

0 comments on commit 3665dfd

Please sign in to comment.