Add comprehensive tests exercising the kernels across available dtypes.
Added softmax and gemm kernels to test across the available float and int dtypes.
Prashant Kumar committed Feb 27, 2024
1 parent 971231c commit c1acd99
Showing 7 changed files with 208 additions and 0 deletions.
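In short, each new test is parameterized over the TKL dtypes and uses the new DataType.to_torch_type() helper (added in dtype.py below) to allocate torch tensors of the matching element type before launching the kernel. A minimal sketch of that pattern, using names from this commit:

import torch
import shark_turbine.kernel.lang as tkl

# Pick a TKL dtype and map it to the corresponding torch dtype.
dtype = tkl.f32
torch_dtype = dtype.to_torch_type()  # torch.float

# Test tensors are allocated with the matching torch dtype.
input = torch.randn(128, 64).to(torch_dtype)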
21 changes: 21 additions & 0 deletions core/shark_turbine/kernel/_support/dtype.py
@@ -16,6 +16,19 @@
_FLOAT_TYPES = ["f16", "f32", "f64"]
_INDEX_TYPES = ["index"]

import torch

_TKL_TO_TORCH_DTYPE = {
"f16": torch.half,
"f32": torch.float,
"f64": torch.double,
"i1": torch.bool,
"i8": torch.int8,
"i16": torch.int16,
"i32": torch.int32,
"i64": torch.int64,
}


# TODO: this should really be a type.
class DataType:
@@ -44,6 +57,14 @@ def is_float_asm(self):
def is_index_asm(self):
return self._name in _INDEX_TYPES

def to_torch_type(self):
try:
return _TKL_TO_TORCH_DTYPE[self._name]
except KeyError:
print(
f"The support for '{self._name}' dtype to torch type isn't implemented."
)


bool = DataType("bool", "i1")
i4 = DataType("i4")
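As a quick sanity check of the new mapping (assuming the dtypes are re-exported through shark_turbine.kernel.lang as the tests below do), mapped dtypes resolve to the corresponding torch dtype, while unmapped entries fall through the KeyError handler, print a message, and return None:

import torch
import shark_turbine.kernel.lang as tkl

# Values taken from _TKL_TO_TORCH_DTYPE above.
assert tkl.f16.to_torch_type() is torch.half
assert tkl.i32.to_torch_type() is torch.int32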
8 changes: 8 additions & 0 deletions core/shark_turbine/kernel/_support/tracing.py
@@ -286,6 +286,14 @@ def handle_exp2(self, op, val):
kwargs={},
)

def handle_rsqrt(self, op, val):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(val,),
kwargs={},
)

def handle_vector_constant(
self, op, shape: Tuple[int, ...], dtype, value: int | float
):
3 changes: 3 additions & 0 deletions core/shark_turbine/kernel/compiler/builder.py
@@ -306,5 +306,8 @@ def binary_truediv_float(
def unary_exp2_float(self, val: IRProxyValue) -> IRProxyValue:
return IRProxyValue(math_d.exp2(val.ir_value))

def unary_rsqrt_float(self, val: IRProxyValue) -> IRProxyValue:
return IRProxyValue(math_d.rsqrt(val.ir_value))


ScalarBuilder = _ScalarBuilder()
1 change: 1 addition & 0 deletions core/shark_turbine/kernel/compiler/vector_codegen.py
@@ -238,6 +238,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):

UNARY_ARITHMETIC_OPS = [
(tkl.exp2, "exp2"),
(tkl.rsqrt, "rsqrt"),
]


2 changes: 2 additions & 0 deletions core/shark_turbine/kernel/lang/prims.py
@@ -22,6 +22,7 @@
"broadcast_in_dim",
"transpose",
"to_dtype",
"rsqrt",
]


@@ -37,6 +38,7 @@ def is_debug() -> bool:
# Math Operations
exp2 = ops.exp2
constant = ops.vector_constant
rsqrt = ops.rsqrt

# Reduction Operations
max = ops.vector_max
6 changes: 6 additions & 0 deletions core/shark_turbine/kernel/ops/math.py
@@ -10,6 +10,7 @@

__all__ = [
"exp2",
"rsqrt",
"vector_constant",
]

@@ -22,3 +23,8 @@ def exp2(val):
@define_op
def vector_constant(shape: Tuple[int, ...], dtype, value: int | float) -> "Vector":
...


@define_op
def rsqrt(val):
...
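With rsqrt defined as an op here, exported from prims, handled in tracing.py, and lowered through the scalar builder to math.rsqrt, a kernel can call tkl.rsqrt directly on floating-point values (the rms_norm test below relies on this). A minimal sketch of the intended use, assuming an f32 input of shape 128x64:

import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl

M = tkl.sym.M
K = tkl.sym.K

@tk.gen.thread(M)
def rsqrt_kernel(
    input: tkl.InputBuffer[M, K, tkl.f32],
    output: tkl.OutputBuffer[M, K, tkl.f32],
):
    row = tkl.program_id(0)
    # rsqrt(x) == 1 / sqrt(x); lowered to math.rsqrt by the builder above.
    output[row, :] = tkl.rsqrt(input[row, :])

input = torch.randn(128, 64)
output = torch.zeros(128, 64)
with tk.gen.TestLaunchContext():
    rsqrt_kernel(input, output)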
167 changes: 167 additions & 0 deletions core/tests/kernel/coverage_test.py
@@ -0,0 +1,167 @@
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import pytest


FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
INT_DTYPES = [
tkl.bool,
tkl.i8,
tkl.i16,
tkl.i32,
tkl.i64,
]


def rms_norm_krnl(dtype, input, weight, output):
M = tkl.sym.M
K = tkl.sym.K

@tk.gen.thread(M)
def rms_norm_kernel(
input: tkl.OutputBuffer[M, K, dtype],
weight: tk.lang.InputBuffer[M, K, dtype],
output: tk.lang.OutputBuffer[M, K, dtype],
):
row_index = tk.lang.program_id(0)
eps = tkl.constant((1,), dtype, 0.00001)
zero = tkl.constant((1,), dtype, 0.0)
input_row = input[row_index, :]
sq_inp = input_row * input_row
sq_inp_red = tkl.sum(sq_inp)
# TODO: The input_row * zero is just dummy computation to pass in the right shapes,
# otherwise it leads to 'error: unknown: 'math.exp2' op operand #0 must be floating-point-like, but got 'vector<f16>'
denom = tkl.rsqrt(input_row * zero + sq_inp_red)
denom_eta = denom + eps
output[row_index, :] = denom_eta * input_row * weight[row_index, :]

with tk.gen.TestLaunchContext():
rms_norm_kernel(input, weight, output)


def iota_krnl(dtype, input):
M = tkl.sym.M

@tk.gen.thread(M)
def iota_kernel(out: tkl.OutputBuffer[M, dtype]):
a = (
tkl.constant((17, 37, 19), dtype, 5)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 5.0)
)
b = (
tkl.constant((17, 37, 19), dtype, 10)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 10.0)
)
c = (
tkl.constant((17, 37, 19), dtype, 2)
if dtype in INT_DTYPES
else tkl.constant((17, 37, 19), dtype, 2.0)
)
if dtype in INT_DTYPES:
c = (a * b) // c
else:
c = (a * b) / c
c = c + a - b

with tk.gen.TestLaunchContext():
iota_kernel(input)


def softmax_krnl(dtype, input, output):
M = tkl.sym.M
K = tkl.sym.K

@tk.gen.thread(M)
def softmax_kernel(
input: tk.lang.InputBuffer[M, K, dtype],
output: tk.lang.OutputBuffer[M, K, dtype],
):
row_index = tk.lang.program_id(0)
input_row = input[row_index, :]
numerator = tkl.exp2(input_row - tkl.max(input_row))
if dtype in INT_DTYPES:
output_row = numerator // tkl.sum(numerator)
else:
output_row = numerator / tkl.sum(numerator)
output[row_index, :] = output_row

with tk.gen.TestLaunchContext():
softmax_kernel(input, output)


def gemm_fx_kernel(dtype, A, B, output):
N = tkl.sym.N
M = tkl.sym.M
K = tkl.sym.K
BLOCK_SIZE = tkl.sym.BLOCK_SIZE

@tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
def gemm_kernel(
A: tkl.InputBuffer[N, K, dtype],
B: tkl.InputBuffer[K, M, dtype],
output: tkl.OutputBuffer[N, M, dtype],
):
grid_n = tkl.program_id(0)
grid_m = tkl.program_id(1)

acc = None
# TODO: Only considering the float and integer cases.
if dtype in INT_DTYPES:
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
else:
acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)

@tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
def body(i, c):
a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
return (tkl.dot(a, b, c),)

tkl.store(output, (grid_n, grid_m), body[0])

with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
gemm_kernel(A, B, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_iota_krnl(dtype):
input = torch.zeros(17)
iota_krnl(dtype, input)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES],
)
def test_rms_norm_krnl(dtype):
input = torch.randn(128, 64).to(dtype.to_torch_type())
weight = torch.randn(128, 64).to(dtype.to_torch_type())
output = torch.randn(128, 64).to(dtype.to_torch_type())
rms_norm_krnl(dtype, input, weight, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES],
)
def test_softmax_krnl(dtype):
input = torch.randn(128, 64).to(dtype.to_torch_type())
output = torch.randn(128, 64).to(dtype.to_torch_type())
softmax_krnl(dtype, input, output)


@pytest.mark.parametrize(
("dtype",),
[(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_gemm_krnl(dtype):
A = torch.randn(512, 1024).to(dtype.to_torch_type())
B = torch.randn(1024, 2048).to(dtype.to_torch_type())
output = torch.zeros(512, 2048).to(dtype.to_torch_type())
gemm_fx_kernel(dtype, A, B, output)

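The new coverage tests use standard pytest parameterization, so every (kernel, dtype) combination can be run from the file added above; a minimal programmatic invocation, equivalent to running pytest on that path from the command line:

import pytest

# Runs every parameterized test case defined in coverage_test.py.
raise SystemExit(pytest.main(["core/tests/kernel/coverage_test.py", "-v"]))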