From b67980309b4140dd18edd8b50919243c4dbf0415 Mon Sep 17 00:00:00 2001
From: "yangsijia.614"
Date: Tue, 25 Feb 2025 23:52:54 +0800
Subject: [PATCH] fix(benchmark): store 'compare' and 'one' perf results in
 csv files and visualize them

---
 benchmark/bench_flash_mla.py | 18 ++++++++++++------
 benchmark/visualize.py       | 14 ++++++++++++--
 2 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py
index 7b0e7b4..14e1352 100644
--- a/benchmark/bench_flash_mla.py
+++ b/benchmark/bench_flash_mla.py
@@ -1,15 +1,16 @@
 # MLA Triton kernel is from: https://github.com/monellz/vllm/commit/feebaa7c063be6bfb590a876741aeef1c5f58cf8#diff-7b2e1c9032522f7266051b9887246a65753871dfb3625a258fee40109fe6e87a
+import argparse
 import math
 import random
 
+import flashinfer
 import torch
 import triton
 import triton.language as tl
-import argparse
 
 # pip install flashinfer-python
-from flash_mla import get_mla_metadata, flash_mla_with_kvcache
-import flashinfer
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+
 
 def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
@@ -443,6 +444,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
     bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
     print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s")
     print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s")
+    return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b
 
 
 def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
@@ -501,7 +503,8 @@ def get_args():
 
 if __name__ == "__main__":
     args = get_args()
-    with open("all_perf.csv", "w") as fout:
+    benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
+    with open(f"{benchmark_type}_perf.csv", "w") as fout:
         fout.write("name,batch,seqlen,head,bw\n")
         for shape in shape_configs:
             if args.all:
@@ -509,6 +512,9 @@
                     perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
                     fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
             elif args.compare:
-                compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                perf_a, perf_b = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf_a:.0f}\n')
+                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf_b:.0f}\n')
             elif args.one:
-                compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
\ No newline at end of file
+                perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
+                fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
\ No newline at end of file
diff --git a/benchmark/visualize.py b/benchmark/visualize.py
index db62519..c1fb37e 100644
--- a/benchmark/visualize.py
+++ b/benchmark/visualize.py
@@ -1,7 +1,17 @@
+import argparse
+
 import matplotlib.pyplot as plt
 import pandas as pd
 
-file_path = 'all_perf.csv'
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Visualize benchmark results')
+    parser.add_argument('--file', type=str, default='all_perf.csv',
+                        help='Path to the CSV file with benchmark results (default: all_perf.csv)')
+    return parser.parse_args()
+
+args = parse_args()
+file_path = args.file
 
 df = pd.read_csv(file_path)
 
@@ -16,4 +26,4 @@
 plt.ylabel('bw (GB/s)')
 plt.legend()
 
-plt.savefig('bandwidth_vs_seqlen.png')
\ No newline at end of file
+plt.savefig(f'{file_path.split(".")[0].split("/")[-1]}_bandwidth_vs_seqlen.png')
\ No newline at end of file
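--
Usage sketch, under stated assumptions: get_args() (not shown in this diff) is
assumed to expose --all/--compare/--one plus --baseline/--target flags matching
the args.* fields referenced above, and "torch" / "flash_mla" are hypothetical
stand-ins for whatever kernel names it actually accepts.

  # single-kernel run; writes <target>_perf.csv
  python benchmark/bench_flash_mla.py --one --target flash_mla

  # A/B run; writes <baseline>_vs_<target>_perf.csv, one row per kernel per shape
  python benchmark/bench_flash_mla.py --compare --baseline torch --target flash_mla

  # plot bandwidth vs. seqlen from a chosen CSV; the PNG name is derived from
  # the CSV stem, here flash_mla_perf_bandwidth_vs_seqlen.png
  python benchmark/visualize.py --file flash_mla_perf.csv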