Gemm a8w8 bench #117

Open
wants to merge 58 commits into base: main
Commits (58)
5e408c3
add rocm backend
valarLip Sep 25, 2024
2c55753
slightly speedup weight loading
valarLip Sep 26, 2024
0b582b3
moe_final_v0.6.0_sept24 based fused_moe_kernel
valarLip Sep 27, 2024
d1bd445
add fused_add_rms_norm
valarLip Sep 28, 2024
b66e4bd
enable shuffle/LDS bypass
valarLip Sep 28, 2024
ca690b0
add fused_add_rms_norm
valarLip Sep 28, 2024
f989c1a
enable shuffle/LDS bypass
valarLip Sep 28, 2024
58ece04
added explicit padding
Oct 1, 2024
509f527
add tuned_gemm
valarLip Oct 4, 2024
ad24efc
Merge branch 'vendors/rocm_base_moe' of https://github.com/ROCm/ByteM…
valarLip Oct 4, 2024
43088a3
align default behavior, default enable LDS_BYPASS, disable MOE_PADDING
valarLip Oct 4, 2024
694c9c6
add fused_rope
valarLip Oct 5, 2024
8020097
default enable VLLM_MOE_PADDING
valarLip Oct 5, 2024
3fae254
moe_final_v0.6.0_sept24 based fused_moe_kernel
valarLip Sep 27, 2024
6ffa0ad
add fused_add_rms_norm
valarLip Sep 28, 2024
c4f8b8f
enable shuffle/LDS bypass
valarLip Sep 28, 2024
c7ba512
added explicit padding
Oct 1, 2024
9debce6
add tuned_gemm
valarLip Oct 4, 2024
6196a72
align default behavior, default enable LDS_BYPASS, disable MOE_PADDING
valarLip Oct 4, 2024
76f1333
add fused_rope
valarLip Oct 5, 2024
693ff3b
default enable VLLM_MOE_PADDING
valarLip Oct 5, 2024
7385d8f
Fix test_iter and use sys._exit for profile.
jiaryang Oct 8, 2024
c7ae902
port paged_attn and result&bench ok
Oct 10, 2024
07e9b24
reconstruct kv cache codes
Oct 10, 2024
0a4b64c
slot mappings only calc once
Oct 10, 2024
85c254a
Merge pull request #8 from dummycoderfe/paged_attn_merge_new
shengnxu Oct 10, 2024
e3cc3ef
merge moe_base
Oct 10, 2024
d199c6c
rm useless
Oct 10, 2024
6857b2c
Merge pull request #9 from dummycoderfe/paged_attn_merge_new
valarLip Oct 10, 2024
8388f5e
hot fix pa perf
felixamd Oct 10, 2024
77aafc5
Merge pull request #10 from dummycoderfe/hot_fix_pa_perf
valarLip Oct 10, 2024
0bca731
add moe_sum
valarLip Oct 10, 2024
8fada97
initial add switch for hipgraph (not working yet)
carlushuang Oct 12, 2024
7e3d60d
enable custom_ar
valarLip Oct 15, 2024
2464206
add missing custom_ar code...
valarLip Oct 15, 2024
8818c04
remove runtime cpu2gpu copy
valarLip Oct 15, 2024
3c83649
try hipgraph...
valarLip Oct 16, 2024
84a9550
reduce mem usage for perf test
valarLip Oct 17, 2024
819000b
fuse renorm into topk_softmax
valarLip Oct 19, 2024
1ac69ac
fuse renorm into topk_softmax........ for other case
valarLip Oct 19, 2024
35fac8e
reduce buffer fill
valarLip Oct 20, 2024
0c3fde3
add more MoE tuned config
valarLip Oct 23, 2024
33b1808
add moe int8 kernel
Oct 23, 2024
43fd4ba
remove comments
shengnxu Oct 24, 2024
07facd3
Merge pull request #12 from shengnxu/vendors/rocm_base
shengnxu Oct 24, 2024
dfd2534
fix for new triton version
Oct 24, 2024
a4e8a54
test for moe int8
Oct 24, 2024
363b992
remove redundant files
Oct 24, 2024
e38df42
Merge pull request #13 from shengnxu/vendors/rocm_base
shengnxu Oct 24, 2024
78096be
add ck layernorm_2d backend
valarLip Oct 24, 2024
bfe52ea
code clean
valarLip Oct 24, 2024
731b864
update layernorm test
valarLip Oct 24, 2024
4160ddc
add moe config for 32 E
valarLip Oct 24, 2024
d5bce8f
update config
valarLip Oct 25, 2024
722c0cc
add int8 a8w8 gemm
Oct 28, 2024
5affc61
Merge branch 'vendors/rocm_base' of https://github.com/ROCm/ByteMLPer…
Oct 28, 2024
4b1bcef
code clean up...
valarLip Oct 28, 2024
9f2a214
bench for gemm a8w8
shengnxu Oct 29, 2024
98 changes: 96 additions & 2 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
@@ -133,7 +133,7 @@ def signal_handler(signum, frame):
                logger.info(f"rank {local_rank} received signal {signum}, exiting...")
                if hasattr(model, 'finalize_inference'):
                    model.finalize_inference()
-                os._exit(0)
+                sys.exit(0)

            signal.signal(signal.SIGINT, signal_handler)
            signal.signal(signal.SIGTERM, signal_handler)
@@ -195,4 +195,98 @@ def mp_forward(self, *args):
        output_dict = self._output_queues.get(block=True)

        return output_dict


# ROCM_HIPGRAPH modify
class GpuMpEngineWithGraph(GpuMpEngine):
    def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None:
        super().__init__(world_size, model_impl, xpu_cfg)
        logger.info("@@@@@@@@@@ GpuMpEngineWithGraph")

    @torch.no_grad()
    def mp_loop_worker(
        self,
        local_rank: int,
        world_size: int,
        input_queue: Queue,
        output_queue: Queue,
        model_impl,
        xpu_config
    ):
        try:
            torch.manual_seed(1)

            # set rank and world_size
            os.environ["RANK"] = str(local_rank)
            os.environ["LOCAL_RANK"] = str(local_rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["LOCAL_WORLD_SIZE"] = str(world_size)

            # create and init model based on model_impl and xpu_config
            model = model_impl(xpu_config)
            if hasattr(model, 'init_inference'):
                model.init_inference()

            def signal_handler(signum, frame):
                logger.info(f"rank {local_rank} received signal {signum}, exiting...")
                if hasattr(model, 'finalize_inference'):
                    model.finalize_inference()
                sys.exit(0)

            signal.signal(signal.SIGINT, signal_handler)
            signal.signal(signal.SIGTERM, signal_handler)

            # current rank is ready
            output_queue.put("ready", block=True)
            logger.info(f"{local_rank}/{world_size} rank is ready")

            graph = torch.cuda.CUDAGraph()

            # model process loop
            while True:
                (
                    forward_inputs,
                ) = input_queue.get(block=True)

                # this is the capture phase of graph
                if 'capture' in forward_inputs:
                    graph.reset()  # reset cuda graph each time
                    inputs_dict = self.build_inputs(forward_inputs)
                    # model.forward(inputs_dict)
                    torch.cuda.synchronize()
                    with torch.cuda.graph(graph):
                        model.forward(inputs_dict)
                    torch.cuda.synchronize()
                    continue

                log = forward_inputs.get("log", False)
                workspace = forward_inputs.get("workspace", None)

                forward_inputs["log_file"] = None
                if log and workspace is not None:
                    workspace_dir = workspace / f"rank_{local_rank}"
                    workspace_dir.mkdir(exist_ok=True, parents=True)
                    forward_inputs["log_file"] = open(workspace_dir / "run.log", "w")


                inputs_dict = self.build_inputs(forward_inputs)
                start_time = time.perf_counter_ns()

                # output_dict = model.forward(inputs_dict)
                graph.replay()

                torch.cuda.synchronize()
                end_time = time.perf_counter_ns()
                duration_ms = round((end_time - start_time) / 1e6, 3)
                output_dict = dict()
                output_dict["duration_ms"] = duration_ms

                # TP realization: rank0 send result back to main process
                if local_rank == 0:
                    output_queue.put(output_dict)

                if log and workspace is not None:
                    forward_inputs["log_file"].close()

        except Exception as e:
            logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}")
            output_queue.put(RuntimeError("[BUG] fatal exception in model subprocess"))
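The worker loop above captures one `model.forward()` into a `torch.cuda.CUDAGraph` when a request carries a `'capture'` key, then only times `graph.replay()` for subsequent requests. The standalone sketch below illustrates that capture/replay pattern outside the engine; it is not part of this PR, and the toy `torch.nn.Linear` model, tensor shapes, and side-stream warm-up loop are illustrative assumptions (warm-up before capture follows PyTorch's documented CUDA-graph recipe).

```python
import time
import torch

# Toy model and static input; CUDA graphs need fixed shapes and reused buffers.
model = torch.nn.Linear(1024, 1024).cuda().half()
static_input = torch.randn(8, 1024, device="cuda", dtype=torch.half)

# Warm up on a side stream so lazy allocations happen before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture phase: record the kernels launched by one forward pass.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay phase: re-launch the captured kernels and time them, as the worker loop does.
torch.cuda.synchronize()
start = time.perf_counter_ns()
graph.replay()
torch.cuda.synchronize()
duration_ms = round((time.perf_counter_ns() - start) / 1e6, 3)
print(duration_ms)
```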
51 changes: 51 additions & 0 deletions byte_infer_perf/llm_perf/backends/ROCM/gpu_ckpt_loader.py
@@ -0,0 +1,51 @@
import torch
import torch.distributed as dist

from llm_perf.core.ckpt_loader import CoreCkptLoader

class GpuCkptLoader(CoreCkptLoader):
    def __init__(
        self,
        prefix, model,
        mp_size=1, mp_rank=0,
        ckpt_path: str=""
    ):
        super().__init__(prefix, model, mp_size, mp_rank, ckpt_path)


    def weight_to_device(self, weight : torch.Tensor, non_blocking=False):
        if self.mp_rank == 0:
            weight = weight.cuda(non_blocking=non_blocking)
        else:
            cur_device = torch.cuda.current_device()
            weight = torch.empty_like(weight, device=f"cuda:{cur_device}")
        return weight


    def broadcast_weight(self, key, device='cpu', non_blocking=False):
        if self.mp_rank != 0:
            tensor_shape = self.state_dict[key]["shape"]
            tensor_dtype = self.state_dict[key]["dtype"]
            tensor = torch.empty(tensor_shape, dtype=tensor_dtype)
        else:
            tensor = self.state_dict[key].cpu()
        tensor_gpu = self.weight_to_device(tensor, non_blocking=non_blocking)
        dist.broadcast(tensor_gpu, src=0)
        self.state_dict[key] = tensor_gpu


    def scatter_weight(self, key, dim, split_mode='default', outter=1, device='cpu', non_blocking=False):
        self.broadcast_weight(key, non_blocking=non_blocking)
        weight = self.state_dict[key]

        if split_mode == 'default':
            weight_split = self.split(weight, dim)
        elif split_mode == 'with_outter':
            weight_split = self.with_outter_split(weight, dim, outter)
        elif split_mode == 'split_outter':
            weight_split = self.split(weight, dim, outter)
        else:
            assert False, f"unknown split mode {split_mode}"

        weight_split = [x.contiguous() for x in weight_split]
        self.state_dict[key] = weight_split[self.mp_rank]
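For context, the sketch below reproduces the broadcast-then-shard pattern that `broadcast_weight()` / `scatter_weight()` implement, as a self-contained single-process example (gloo backend, world size 1). It does not instantiate `GpuCkptLoader` itself; the tensor shape, master address/port, and the `'default'` split mode are illustrative assumptions.

```python
import os
import torch
import torch.distributed as dist

# Single-process process group so the example runs anywhere (no GPU required).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mp_rank, mp_size = dist.get_rank(), dist.get_world_size()

# Rank 0 owns the full weight; other ranks would allocate an empty buffer of the same shape.
full_weight = torch.randn(8, 16) if mp_rank == 0 else torch.empty(8, 16)
dist.broadcast(full_weight, src=0)          # every rank now holds the full tensor

# 'default' split mode: chunk along `dim` and keep only this rank's shard.
shards = torch.chunk(full_weight, mp_size, dim=0)
local_shard = shards[mp_rank].contiguous()
print(local_shard.shape)

dist.destroy_process_group()
```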
131 changes: 131 additions & 0 deletions byte_infer_perf/llm_perf/backends/ROCM/gpu_inferencer.py
@@ -0,0 +1,131 @@
import os
from typing import Dict, List, Any
from dataclasses import dataclass

from llm_perf.core.generation import GenerateRequest
from llm_perf.core.inferencer import CoreInferencer
from llm_perf.backends.ROCM.gpu_mp_engine import GpuMpEngine
from llm_perf.utils.logger import logger

class GpuInferencer(CoreInferencer):
    def __init__(self, model_impl, xpu_cfg):
        super().__init__()

        self.tp_size = xpu_cfg["tp_size"]
        self.pad_token_id = xpu_cfg["pad_token_id"]
        self.max_batch_size = xpu_cfg["max_batch_size"]
        self.mp_engine = GpuMpEngine(self.tp_size, model_impl, xpu_cfg)

    def prepare_inputs(
        self,
        tasks: List[CoreInferencer.Task],
        **kwargs
    ):
        input_dict = {
            "input_ids": None,
            "position_ids": None,
            "attention_mask": None,
            "all_q_len": None,
            "all_kv_len": None,
            "is_context": None,
            "valid_slot_ids": None
        }

        is_context = kwargs.get("is_context") if "is_context" in kwargs.keys() else False
        valid_slot_ids = kwargs.get("valid_slot_ids") if "valid_slot_ids" in kwargs.keys() else [i for i in range(self.max_batch_size)]


        get_input_logits = False
        for task in tasks:
            if task.request.generate_config.get_input_logits:
                get_input_logits = True
                break

        input_dict["is_context"] = is_context
        input_dict["valid_slot_ids"] = valid_slot_ids
        input_dict["get_input_logits"] = get_input_logits

        if is_context:
            q_len = len(tasks[0].request.input_ids)
            kv_len = len(tasks[0].request.input_ids)

            input_dict["input_ids"] = [
                tasks[0].request.input_ids
            ]
            input_dict["position_ids"] = [
                [i for i in range(q_len)]
            ]
            input_dict["attention_mask"] = [
                [1 for _ in range(q_len)]
            ]
            input_dict["all_q_len"] = [
                q_len
            ]
            input_dict["all_kv_len"] = [
                kv_len
            ]
        else:
            all_input_ids = []
            all_position_ids = []
            all_attention_mask = []
            all_q_len = []
            all_kv_len = []

            for task in tasks:
                q_len = 1
                kv_len = 0

                if task is None:
                    kv_len = 1

                    input_ids = [
                        self.pad_token_id
                    ]
                    position_ids = [
                        0
                    ]
                    attention_mask = [
                        0
                    ]
                else:
                    kv_len = len(task.request.input_ids) + len(task.generate_ids) - 1

                    input_ids = [
                        task.generate_ids[-1]
                    ]
                    position_ids = [
                        kv_len
                    ]
                    attention_mask = [
                        1
                    ]
                all_input_ids.append(input_ids)
                all_position_ids.append(position_ids)
                all_attention_mask.append(attention_mask)
                all_q_len.append(q_len)
                all_kv_len.append(kv_len)

            input_dict["input_ids"] = all_input_ids
            input_dict["position_ids"] = all_position_ids
            input_dict["attention_mask"] = all_attention_mask
            input_dict["all_q_len"] = all_q_len
            input_dict["all_kv_len"] = all_kv_len

        return input_dict


    def infer(
        self,
        tasks: List[CoreInferencer.Task],
        **kwargs
    ):
        input_dict = self.prepare_inputs(tasks, **kwargs)
        output_dict = self.mp_engine.mp_forward(input_dict)

        logits = output_dict["logits"]
        next_token_logits = logits[:, -1, :].contiguous()
        infer_outputs = {
            "logits": logits,
            "last_logits": next_token_logits
        }
        return infer_outputs
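To make the decode-phase packing in `prepare_inputs()` concrete, here is a toy, self-contained sketch of the same logic: each live slot contributes only its last generated token at position `kv_len`, and empty slots are filled with the pad token and masked out. The `Slot` namedtuple stands in for `CoreInferencer.Task` and is purely illustrative.

```python
from collections import namedtuple

# Illustrative stand-in for CoreInferencer.Task (not part of the PR).
Slot = namedtuple("Slot", ["input_ids", "generate_ids"])
pad_token_id = 0

# Slot 0 is a live request (3 prompt tokens, 2 generated tokens); slot 1 is empty.
tasks = [Slot(input_ids=[5, 6, 7], generate_ids=[11, 12]), None]

all_input_ids, all_position_ids, all_attention_mask, all_kv_len = [], [], [], []
for task in tasks:
    if task is None:                    # padded slot: feed pad token, mask it out
        kv_len = 1
        all_input_ids.append([pad_token_id])
        all_position_ids.append([0])
        all_attention_mask.append([0])
    else:                               # live slot: feed only the last generated token
        kv_len = len(task.input_ids) + len(task.generate_ids) - 1
        all_input_ids.append([task.generate_ids[-1]])
        all_position_ids.append([kv_len])
        all_attention_mask.append([1])
    all_kv_len.append(kv_len)

print(all_input_ids)       # [[12], [0]]
print(all_position_ids)    # [[4], [0]]
print(all_attention_mask)  # [[1], [0]]
print(all_kv_len)          # [4, 1]
```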