Enhance MM Kernel Performance and Coverage for Specific Input Scenarios #405

Open · wants to merge 26 commits into base: master
36 changes: 20 additions & 16 deletions benchmark/attri_util.py
@@ -28,26 +28,30 @@
]


# This function is adapted from: https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/utils/triton_op.py
def llama_shapes():
def model_shapes():
# batch sizes * seq lengths
BS = [2**i for i in range(0, 17)]
BS = [2**i for i in range(0, 9, 2)]
# attn: wqkv, wo; ffn: w13, w2
KN = [
(4096, 12288),
NK = [
# extract from llama3-8b
(1024, 4096),
(128256, 4096),
(14336, 4096),
(4096, 14336),
(4096, 4096),
(4096, 22016),
(11008, 4096),
(8192, 1280),
(1024, 8192),
(8192, 7168),
(3584, 8192),
(16384, 2304),
(2048, 16384),
(16384, 13312),
(6656, 16384),
(6144, 4096),
(28672, 4096),
# extract from qwen2.5-7b
(3584, 3584),
(18944, 3584),
(3584, 18944),
(152064, 3584),
(37888, 3584),
(512, 3584),
(4608, 3584),
]
return [(bs, n, k, None) for bs, (k, n) in itertools.product(BS, KN)]

return [(4, bs, n, k) for bs, (n, k) in itertools.product(BS, NK)]


@dataclass
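For context, a minimal sketch (not part of the diff) of what the new model_shapes() helper yields, assuming the (B, M, N, K) layout that BlasBenchmark consumes; the NK list is abbreviated here:

import itertools

# Illustrative only: mirrors the revised model_shapes() above.
BS = [2**i for i in range(0, 9, 2)]        # [1, 4, 16, 64, 256] candidate M values
NK = [(1024, 4096), (3584, 3584)]          # abbreviated (N, K) pairs from the list above
shapes = [(4, bs, n, k) for bs, (n, k) in itertools.product(BS, NK)]
print(shapes[0])                           # (4, 1, 1024, 4096): fixed batch 4, M taken from BS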
18 changes: 9 additions & 9 deletions benchmark/core_shapes.yaml
@@ -1,11 +1,3 @@
outer:
shapes:
- [384, 384]
- [1024, 1024]
- [4096, 4096]
- [8192, 8192]
- [10240, 10240] #from perf

randperm:
shapes:
- [64]
@@ -41,13 +33,21 @@ diag:

BlasBenchmark:
shapes:
- [2, 384, 384, 384]
- [2, 4096, 4096, 4096]
- [16, 384, 384, 384]
- [16, 1024, 1024, 1024]
- [16, 2048, 2048, 2048]
- [16, 4096, 4096, 4096]
shape_desc: "B, M, N, K" # shapes are defined as (B, M, N, K)

MvAndOuterBenchmark:
shapes:
- [384, 384]
- [1024, 1024]
- [4096, 4096]
- [8192, 8192]
- [10240, 10240] #from perf

# NORM shapes can be either 3D or 4D:
# - 3D shapes are represented as [batch_size, channels, hidden_size]
# - 4D shapes are represented as [batch_size, channels, height, width]
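As a rough illustration (assuming each YAML row is passed straight through to the input functions in test_blas_perf.py), the BlasBenchmark entry [16, 1024, 1024, 1024] listed above would expand to something like:

import torch

# Hypothetical expansion of one (B, M, N, K) row; mm ignores B, while bmm would use it.
b, m, n, k = 16, 1024, 1024, 1024
inp1 = torch.randn(m, k)
inp2 = torch.randn(k, n)
out = inp1 @ inp2          # the operation being timed for the mm benchmark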
134 changes: 71 additions & 63 deletions benchmark/test_blas_perf.py
@@ -1,12 +1,11 @@
import itertools
from typing import Generator

import pytest
import torch

from .attri_util import DEFAULT_METRICS, FLOAT_DTYPES, BenchLevel, llama_shapes
from .attri_util import DEFAULT_METRICS, FLOAT_DTYPES, BenchLevel, model_shapes
from .conftest import Config
from .performance_utils import Benchmark
from .performance_utils import Benchmark, GenericBenchmark2DOnly


class BlasBenchmark(Benchmark):
@@ -22,83 +21,75 @@ def __init__(self, *args, input_fn, **kwargs):

def get_input_iter(self, cur_dtype) -> Generator:
for b, m, n, k in self.shapes:
yield from self.input_fn(b, m, n, k, cur_dtype, self.device)
# llama shapes
yield from self.input_fn(b, m, n, k, cur_dtype, self.device, False)

if Config.bench_level == BenchLevel.COMPREHENSIVE:
for m, n, k, _ in llama_shapes():
yield from self.input_fn(1, m, n, k, cur_dtype, self.device)
for b, m, n, k in self.shapes:
yield from self.input_fn(b, m, n, k, cur_dtype, self.device, True)

def set_more_shapes(self):
split_k_shapes = [
(1, m, m, k)
for m in [16 * i for i in range(1, 5)]
for k in [4096 * i for i in range(1, 9)]
large_k_shapes = [
(8, 1848, 1536, 151936),
(8, 1848, 1536, 128256),
(8, 1848, 1536, 152064),
]
# 'mv' operations only involve M and N dimensions.
# Shapes with large K values are not suitable for these two operations.
if self.op_name not in ["mv"]:
# B=1 or 4, M= 13, N= 2 , K=2^6..2^15
large_k_shapes = list(
itertools.product([1, 4], [13], [2], [2**i for i in range(6, 15)])
)
return large_k_shapes + split_k_shapes
return split_k_shapes

model_shape_list = model_shapes()
return large_k_shapes + model_shape_list

def get_tflops(self, op, *args, **kwargs):
"""This method is currently not really implemented and serves as a placeholder.
A proper implementation will be developed in the future."""
total_flops = 0
# shapes: (m, k) x (k, n)
# total_flops = 2 * m * n * k
if self.op_name == "mm":
total_flops = args[0].shape[0] * args[0].shape[1] * args[1].shape[1] * 2
# shapes: (m, n) x (n, p), plus bias
# total_flops = m * p * (2 * n + 1)
if self.op_name == "addmm":
elif self.op_name == "addmm":
total_flops = (
args[0].shape[0] * args[1].shape[1] * (args[1].shape[0] * 2 + 1)
)
# shapes: (b, n, m) x (b, m, p)
# total_flops = b * n * p * 2 * m
if self.op_name == "bmm":
elif self.op_name == "bmm":
total_flops = (
args[0].shape[0]
* args[0].shape[1]
* args[1].shape[2]
* 2
* args[0].shape[2]
)
# shapes: (n, m) x (m,)
# total_flops = 2 * n * m
if self.op_name == "mv":
total_flops = args[0].shape[0] * 2 * args[0].shape[1]

return total_flops

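A back-of-the-envelope check of the formulas above (not part of the change): for a square mm the count reduces to 2 * m * n * k.

# Example: FLOP count for mm with m = n = k = 4096, matching the mm branch above.
m = n = k = 4096
total_flops = m * n * 2 * k            # 2 * 4096**3 ≈ 1.374e11 FLOPs
tflops = total_flops / 0.010 / 1e12    # ≈ 13.7 TFLOPS if one call takes 10 ms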

def addmm_input_fn(b, m, n, k, cur_dtype, device):
def addmm_input_fn(b, m, n, k, cur_dtype, device, b_column_major):
inp1 = torch.randn([m, k], dtype=cur_dtype, device=device)
inp2 = torch.randn([k, n], dtype=cur_dtype, device=device)
bias = torch.randn([m, n], dtype=cur_dtype, device=device)
yield bias, inp1, inp2,
if b_column_major:
inp2 = torch.randn([n, k], dtype=cur_dtype, device=device)
yield bias, inp1, inp2.t(),
else:
inp2 = torch.randn([k, n], dtype=cur_dtype, device=device)
yield bias, inp1, inp2,


def bmm_input_fn(b, m, n, k, cur_dtype, device):
def bmm_input_fn(b, m, n, k, cur_dtype, device, b_column_major):
inp1 = torch.randn([b, m, k], dtype=cur_dtype, device=device)
inp2 = torch.randn([b, k, n], dtype=cur_dtype, device=device)
yield inp1, inp2
if b_column_major:
inp2 = torch.randn([b, n, k], dtype=cur_dtype, device=device)
yield inp1, inp2.transpose(1, 2)
else:
inp2 = torch.randn([b, k, n], dtype=cur_dtype, device=device)
yield inp1, inp2


def mm_input_fn(b, m, n, k, cur_dtype, device):
def mm_input_fn(b, m, n, k, cur_dtype, device, b_column_major):
inp1 = torch.randn([m, k], dtype=cur_dtype, device=device)
inp2 = torch.randn([k, n], dtype=cur_dtype, device=device)
yield inp1, inp2


def mv_input_fn(b, m, n, k, cur_dtype, device):
inp1 = torch.randn([m, n], dtype=cur_dtype, device=device)
inp2 = torch.randn([n], dtype=cur_dtype, device=device)
yield inp1, inp2
if b_column_major:
inp2 = torch.randn([n, k], dtype=cur_dtype, device=device)
yield inp1, inp2.t()
else:
inp2 = torch.randn([k, n], dtype=cur_dtype, device=device)
yield inp1, inp2

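The b_column_major branches rely on the fact that .t() / .transpose() only swaps strides, giving a column-major view with identical values; a small standalone check, independent of the benchmark harness:

import torch

# Sketch: the transposed view is column-major (stride-swapped) but numerically identical.
a = torch.randn(64, 32)
b = torch.randn(128, 32)                  # stored as (n, k), as in the column-major branch
assert b.t().shape == (32, 128)
assert b.t().stride() == (1, 32)          # column-major layout, no data copy
assert torch.allclose(a @ b.t(), a @ b.t().contiguous())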

@pytest.mark.parametrize(
@@ -122,12 +113,6 @@ def mv_input_fn(b, m, n, k, cur_dtype, device):
mm_input_fn,
marks=pytest.mark.mm,
),
pytest.param(
"mv",
torch.Tensor.mv,
mv_input_fn,
marks=pytest.mark.mv,
),
],
)
def test_blas_benchmark(op_name, torch_op, input_fn):
@@ -137,9 +122,9 @@ def test_blas_benchmark(op_name, torch_op, input_fn):
bench.run()


class OuterBenchmark(BlasBenchmark):
class MvAndOuterBenchmark(GenericBenchmark2DOnly):
"""
benchmark for outer
Benchmark for MV and Outer operations
"""

def set_more_shapes(self):
@@ -150,17 +135,40 @@ def get_input_iter(self, cur_dtype) -> Generator:
yield from self.input_fn(m, n, cur_dtype, self.device)


@pytest.mark.outer
def test_outer_benchmark():
def outer_input_fn(m, n, cur_dtype, device):
inp1 = torch.randn([m], dtype=cur_dtype, device=device)
inp2 = torch.randn([n], dtype=cur_dtype, device=device)
yield inp1, inp2
def mv_input_fn(m, n, cur_dtype, device):
inp1 = torch.randn([m, n], dtype=cur_dtype, device=device)
inp2 = torch.randn([n], dtype=cur_dtype, device=device)
yield inp1, inp2


def outer_input_fn(m, n, cur_dtype, device):
inp1 = torch.randn([m], dtype=cur_dtype, device=device)
inp2 = torch.randn([n], dtype=cur_dtype, device=device)
yield inp1, inp2


bench = OuterBenchmark(
input_fn=outer_input_fn,
op_name="outer",
torch_op=torch.Tensor.outer,
@pytest.mark.parametrize(
"op_name, torch_op, input_fn",
[
pytest.param(
"mv",
torch.Tensor.mv,
mv_input_fn,
marks=pytest.mark.mv,
),
pytest.param(
"outer",
torch.Tensor.outer,
outer_input_fn,
marks=pytest.mark.outer,
),
],
)
def test_mv_and_outer_benchmark(op_name, torch_op, input_fn):
bench = MvAndOuterBenchmark(
input_fn=input_fn,
op_name=op_name,
torch_op=torch_op,
dtypes=FLOAT_DTYPES,
)
bench.run()
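With the split above, each case can be selected through its pytest marker; for example, pytest benchmark/test_blas_perf.py -m mv should run only the MV benchmark, assuming the mv marker is registered in the project's pytest configuration.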