Tflops measurement - habana_main #151

Closed
wants to merge 13 commits
56 changes: 56 additions & 0 deletions vllm/hpu/ops.py
@@ -5,13 +5,16 @@
# LICENSE file in the root directory of this source tree.
###############################################################################
import os
import time
from math import ceil
from typing import Optional

import habana_frameworks.torch as htorch
import torch
import torch.nn.functional as F

from vllm.logger import init_logger
from vllm.worker.profiler import Profiler

logger = init_logger(__name__)
HPUFusedRMSNorm = None
@@ -47,6 +50,11 @@ def paged_attention_v1(query,
                       matmul_av_op=torch.matmul,
                       k_cache_cls=None,
                       v_cache_cls=None) -> None:
    habana_profiler = Profiler()
    torch.hpu.synchronize()
    start_time = time.time()

    htorch.core.mark_step()
    seq_len = block_tables.size(1)
    batch_size, query_heads, _ = query.shape
    _, _, kv_heads, _ = key_cache.shape
@@ -87,6 +95,22 @@ def paged_attention_v1(query,
    if query_heads != kv_heads:
        attn_weights = [a.flatten(1, 2) for a in attn_weights]
    attn_weights = sum(attn_weights)
    htorch.core.mark_step()

    torch.hpu.synchronize()
    end_time = time.time()

    flops = flops_counter_decode(num_att_heads=query.shape[1],
                                 batch_size=batch_size,
                                 query_seq_len=query.shape[2],
                                 max_seq_len=key_cache.shape[2],
                                 block_size=block_size,
                                 query_embedding_dim=query.shape[3],
                                 value_embedding_dim=key_cache.shape[3],
                                 duration=end_time - start_time)
    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
                                   {"PA TFLOPS": flops / 1e12})

    return attn_weights.squeeze(-2)
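HPU work is queued lazily and executed asynchronously, which is presumably why the decode path above brackets the timed region with htorch.core.mark_step() and torch.hpu.synchronize() before reading time.time(): without the sync, the host clock would stop before the device work finished. A minimal sketch of the same pattern as a standalone helper, assuming habana_frameworks.torch is available; the hpu_timed name and the context-manager packaging are made up for illustration:

import time
from contextlib import contextmanager

import habana_frameworks.torch as htorch
import torch


@contextmanager
def hpu_timed(label):
    """Toy timing helper: synchronize around a block and report wall time."""
    htorch.core.mark_step()
    torch.hpu.synchronize()    # drain previously queued device work
    start = time.time()
    yield
    htorch.core.mark_step()
    torch.hpu.synchronize()    # make sure the timed work actually finished
    print(f"{label}: {(time.time() - start) * 1e3:.3f} ms")

(For comparison, the prompt path below takes its timestamps around mark_step() without a synchronize, so in lazy mode its measured duration may not cover the full device execution.)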


@@ -137,6 +161,10 @@ def prompt_attention(
    softmax_op=torch.softmax,
    matmul_av_op=torch.matmul,
) -> torch.Tensor:
    habana_profiler = Profiler()
    start_time = time.time()

    htorch.core.mark_step()
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
@@ -156,4 +184,32 @@ def prompt_attention(
    if query_heads != kv_heads:
        attn_weights = attn_weights.flatten(1, 2)
    attn_weights = attn_weights.transpose(1, 2)
    htorch.core.mark_step()

    end_time = time.time()
    flops = flops_counter_prompt(num_att_heads=query.shape[1],
                                 batch_size=query.shape[0],
                                 query_seq_len=query.shape[2],
                                 max_seq_len=key.shape[2],
                                 query_embedding_dim=query.shape[3],
                                 value_embedding_dim=key.shape[3],
                                 duration=end_time - start_time)
    habana_profiler.record_counter(habana_profiler.get_timestamp_us(),
                                   {"Prompt TFLOPS": flops / 1e12})

    return attn_weights


def flops_counter_decode(num_att_heads, batch_size, query_seq_len, max_seq_len,
                         block_size, query_embedding_dim, value_embedding_dim,
                         duration) -> float:
    return (batch_size * num_att_heads * query_seq_len *
            ceil(max_seq_len / block_size) * block_size * 2 *
            (query_embedding_dim + value_embedding_dim) / duration)


def flops_counter_prompt(num_att_heads, batch_size, query_seq_len, max_seq_len,
                         query_embedding_dim, value_embedding_dim,
                         duration) -> float:
    return (batch_size * num_att_heads * query_seq_len * max_seq_len * 2 *
            (query_embedding_dim + value_embedding_dim) / duration)
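
These counters estimate attention work as roughly 2 * (d_q + d_v) floating-point operations per query-key pair (2 * d_q for the QK^T product and 2 * d_v for the attention-weights * V product), with the decode variant rounding the attended context up to a whole number of cache blocks, so padded block slots are counted as work. A minimal sketch of how the numbers come out; all shapes and the 1 ms duration below are made up for illustration:

from math import ceil

# Hypothetical decode step: batch 32, 32 heads, 1 query token,
# 2048-token context, block_size 128, d_q = d_v = 128, measured over 1 ms.
flops = (32 * 32 * 1 *                # batch_size * num_att_heads * query_seq_len
         ceil(2048 / 128) * 128 *     # context rounded up to whole blocks
         2 * (128 + 128) /            # QK^T plus attn @ V, per query-key pair
         1e-3)                        # duration in seconds
print(f"{flops / 1e12:.2f} TFLOPS")   # -> 1.07 for this toy case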
12 changes: 12 additions & 0 deletions vllm/worker/profiler.py
@@ -47,6 +47,18 @@ def run(self):
outfile.write(content)


def singleton(class_):
    instances = {}

    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]

    return getinstance


@singleton
class Profiler:
    profiling_trace_events: queue.Queue = queue.Queue()
    event_tid = {'counter': 1, 'external': 2, 'internal': 3}
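
The @singleton wrapper is what makes the per-call Profiler() constructions in ops.py cheap: the first call builds the instance and every later call returns the cached one, so all counters presumably land in the same trace. A minimal sketch of that behaviour, reusing the decorator defined above on a made-up toy class:

@singleton
class Toy:
    def __init__(self):
        self.counters = []

a = Toy()
b = Toy()
assert a is b                   # the second "construction" returns the cached object
a.counters.append("PA TFLOPS")
print(b.counters)               # ['PA TFLOPS'] -- state is shared across call sites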