From ab272bfe09915f84bc4e2439055dd7d0e82e08ca Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 26 Dec 2024 09:24:46 +0100 Subject: [PATCH] matmul_benchmark fix --- benchmark/matmul_benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/matmul_benchmark.py b/benchmark/matmul_benchmark.py index 625441e5..502e44dc 100644 --- a/benchmark/matmul_benchmark.py +++ b/benchmark/matmul_benchmark.py @@ -102,14 +102,14 @@ def benchmark(f, warmup=10, iter=10): matmul = CUDA_KERNEL.code1x16_matmat if args.nbits_per_codebook == 16 else CUDA_KERNEL.code2x8_matmat - output = matmul(input, codes, codebooks, scales) + output = matmul(input, codes, codebooks, scales, None) if args.log_error: print( f"Relative error: {(torch.mean(torch.abs(output_ref - output)) / torch.mean(torch.abs(output_ref))).item():.2e}" ) dense += benchmark(lambda: F.linear(input, weight, out=output_ref), args.warmup_iters, args.benchmark_iters) - quant += benchmark(lambda: matmul(input, codes, codebooks, scales), args.warmup_iters, args.benchmark_iters) + quant += benchmark(lambda: matmul(input, codes, codebooks, scales, None), args.warmup_iters, args.benchmark_iters) print(f"{model}: Dense forward = {dense * 1e6:.0f} mus") print(f"{model}: Quant forward = {quant * 1e6:.0f} mus")