benchmark_batch_size.py #1
Hello, could you share the benchmark_batch_size.py script? Thanks.

Here is the code for benchmark_batch_size.py:

```python
import json
import torch
import transformers
from collections import namedtuple
from utils.generation_utils import (
    get_generation_model,
    get_generation_tokenizer,
    get_input_encoding,
    get_memory_constrained_generation,
    get_terminators,
)
from utils.reward_utils import get_reward_model, get_reward_tokenizer
from utils.validation_utils import get_full_model_name

LLM_NAMES = [
    "sft10k",
    "Meta-Llama-3-8B",
    "Meta-Llama-3-8B-Instruct",
]
REWARD_MODEL_NAME = "ArmoRM-Llama3-8B-v0.1"
MODEL_DIR = "{YOUR MODEL DIR}"  # placeholder: set this to your local model directory

batch_size_low = 20
batch_size_high = 10_000
target_length = 10

def get_generation_length(
    batch_size: int,
    generation_model: transformers.LlamaForCausalLM,
    input_ids: torch.LongTensor,
    pad_token_id: int | None,
    terminators: list[int | None],
) -> int:
    input_length = input_ids.size(1)
    batched_input_ids = input_ids.repeat(batch_size, 1)
    Args = namedtuple("Args", ["top_p", "top_k", "max_tokens"])
    args = Args(top_p=1.0, top_k=50, max_tokens=input_length + target_length + 1)
    output_ids = get_memory_constrained_generation(
        generation_model,
        batched_input_ids,
        terminators,
        pad_token_id,
        args,
    )
    output_length = output_ids.size(1)
    generation_length = output_length - input_length
    return generation_length

def find_largest_batch_size(
    batch_size_low: int,
    batch_size_high: int,
    target_length: int,
    generation_model: transformers.LlamaForCausalLM,
    input_ids: torch.LongTensor,
    pad_token_id: int | None,
    terminators: list[int | None],
) -> int:
    """
    Perform a binary search to find the largest batch size that produces
    a generation length equal to the target length.

    Parameters:
        batch_size_low (int): The lower bound of the batch size.
        batch_size_high (int): The upper bound of the batch size.
        target_length (int): The target generation length.

    Returns:
        int: The largest batch size that produces the target generation length.
    """
    result = -1
    while batch_size_low <= batch_size_high:
        batch_size_mid = (batch_size_low + batch_size_high) // 2
        length = get_generation_length(
            batch_size_mid, generation_model, input_ids, pad_token_id, terminators
        )
        print(
            f"input_length: {input_ids.shape[-1]}, batch_size: {batch_size_mid}, length: {length}"
        )
        if length == target_length:
            result = batch_size_mid
            batch_size_low = batch_size_mid + 1
        # if the length is greater than the target length, we need to increase the batch size
        elif length > target_length:
            batch_size_low = batch_size_mid + 1
        else:
            batch_size_high = batch_size_mid - 1
    return result if result != -1 else batch_size_mid

def get_model_data_dict(model_basename: str) -> dict[int, int]:
    data_dict = {}
    LLM_name = get_full_model_name(MODEL_DIR, model_basename)
    reward_model_name = get_full_model_name(MODEL_DIR, REWARD_MODEL_NAME)
    generation_model = get_generation_model(LLM_name, device="cuda")
    generation_tokenizer = get_generation_tokenizer(LLM_name)
    terminators = get_terminators(LLM_name, generation_tokenizer)
    reward_tokenizer = get_reward_tokenizer(reward_model_name)
    # The reward model is loaded but never called here, presumably so that its
    # weights occupy GPU memory during the benchmark as they would at inference time.
    reward_model = get_reward_model(reward_model_name, reward_tokenizer, "cuda")
    sample_prompt = "The"
    start_encoding = get_input_encoding(
        [sample_prompt], generation_model, generation_tokenizer
    )
    for input_length in range(5, 206, 50):
        # repeat start_encoding to get a length of input_length
        input_ids: torch.LongTensor = start_encoding.input_ids.repeat(1, input_length)
        # try out different batch sizes and use a binary search to find the largest batch size that works
        largest_batch_size = find_largest_batch_size(
            batch_size_low,
            batch_size_high,
            target_length,
            generation_model,
            input_ids,
            generation_tokenizer.pad_token_id,
            terminators,
        )
        data_dict[input_length] = largest_batch_size
    return data_dict

def write_to_disk(data_dict: dict[str, dict[int, int]]) -> None:
    with open("benchmark_batch_size.json", "w") as f:
        json.dump(data_dict, f, indent=4)


def main() -> None:
    full_data_dict: dict[str, dict[int, int]] = {}
    for model_basename in LLM_NAMES:
        model_data_dict = get_model_data_dict(model_basename)
        full_data_dict[model_basename] = model_data_dict
    write_to_disk(full_data_dict)


if __name__ == "__main__":
    main()
```
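Once the script finishes, benchmark_batch_size.json maps each model basename to a dictionary from input length to the largest batch size that still reached the target generation length. A minimal sketch for inspecting that file (nothing here beyond what the script above writes; JSON object keys come back as strings, hence the int() cast):

```python
import json

# Load the output written by benchmark_batch_size.py:
# {model_basename: {input_length: largest_batch_size}}
with open("benchmark_batch_size.json") as f:
    data = json.load(f)

for model_name, by_length in data.items():
    # Sort the string keys numerically before printing.
    for input_length, batch_size in sorted(by_length.items(), key=lambda kv: int(kv[0])):
        print(f"{model_name}: input_length={input_length} -> max batch_size={batch_size}")
```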
@preminstrel Thanks!
@preminstrel One more question, if I may: does this work with all models out of the box, or do I need to adapt it myself?
Do you mean this file? You would need to add other models yourself (inside those utility functions). That said, you can actually work this out by hand: compute the memory cost of the KV cache plus the weights of the two models, and that gives you the batch size.
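For reference, a back-of-the-envelope version of that hand calculation might look like the sketch below. The architecture numbers (32 layers, 8 KV heads, head dim 128), the fp16 weights, the 80 GB GPU budget, and the overhead figure are all assumptions for illustration, not values taken from this repository:

```python
# Rough manual estimate of the max batch size from KV-cache cost plus the
# weights of the two models, as suggested above. All constants are assumptions.


def kv_cache_bytes_per_token(
    num_layers: int = 32,      # assumed: Llama-3-8B-style decoder layers
    num_kv_heads: int = 8,     # assumed: grouped-query attention KV heads
    head_dim: int = 128,       # assumed: hidden_size / num_attention_heads
    bytes_per_value: int = 2,  # fp16 / bf16
) -> int:
    # One K and one V vector are cached per layer for every token of every sequence.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def estimate_max_batch_size(
    seq_len: int,
    gpu_mem_gib: float = 80.0,    # assumed GPU memory budget
    policy_params: float = 8e9,   # generation model parameters
    reward_params: float = 8e9,   # reward model parameters
    bytes_per_param: int = 2,     # fp16 / bf16 weights
    overhead_gib: float = 4.0,    # activations, fragmentation, CUDA context (guess)
) -> int:
    weight_bytes = (policy_params + reward_params) * bytes_per_param
    free_bytes = gpu_mem_gib * 2**30 - weight_bytes - overhead_gib * 2**30
    per_sequence_bytes = seq_len * kv_cache_bytes_per_token()
    return max(int(free_bytes // per_sequence_bytes), 0)


if __name__ == "__main__":
    # e.g. a 205-token prompt plus ~11 generated tokens, as in the benchmark above
    print(estimate_max_batch_size(seq_len=205 + 11))
```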
@preminstrel Oh, I see. I was asking about the work as a whole, not just this file.
There shouldn't be any special requirements on the choice of model, and no further adaptation is needed, but the results may depend on the quality of the reward model.
@preminstrel Got it, thanks.
@preminstrel I'd like to try experimenting with Qwen2.5 on Chinese.