Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(benchmark): store 'compare' and 'one' perf results in csv files and visualize them #45

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions benchmark/bench_flash_mla.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
import argparse
import math
import random

import flashinfer
import torch
import triton
import triton.language as tl
import argparse

# pip install flashinfer-python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
import flashinfer
from flash_mla import flash_mla_with_kvcache, get_mla_metadata


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
Expand Down Expand Up @@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b


def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
Expand Down Expand Up @@ -501,14 +503,18 @@ def get_args():

if __name__ == "__main__":
args = get_args()
with open("all_perf.csv", "w") as fout:
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
elif args.compare:
compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
elif args.one:
compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
14 changes: 12 additions & 2 deletions benchmark/visualize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import argparse

import matplotlib.pyplot as plt
import pandas as pd

file_path = 'all_perf.csv'

def parse_args():
parser = argparse.ArgumentParser(description='Visualize benchmark results')
parser.add_argument('--file', type=str, default='all_perf.csv',
help='Path to the CSV file with benchmark results (default: all_perf.csv)')
return parser.parse_args()

args = parse_args()
file_path = args.file

df = pd.read_csv(file_path)

Expand All @@ -16,4 +26,4 @@
plt.ylabel('bw (GB/s)')
plt.legend()

plt.savefig('bandwidth_vs_seqlen.png')
plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png')