Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Llava-SGlang #54

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
AVAILABLE_MODELS = {
"llava": "Llava",
"llava_hf": "LlavaHf",
"llava_sglang": "LlavaSglang",
"qwen_vl": "Qwen_VL",
"fuyu": "Fuyu",
"gpt4v": "GPT4V",
Expand Down
164 changes: 164 additions & 0 deletions lmms_eval/models/llava_sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import torch

torch.backends.cuda.matmul.allow_tf32 = True

import logging
from tqdm import tqdm
from datetime import timedelta

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

from accelerate import Accelerator, InitProcessGroupKwargs
from typing import List, Optional, Union, Tuple
import warnings

warnings.filterwarnings("ignore")
from concurrent.futures import ThreadPoolExecutor, as_completed
import tempfile

eval_logger = logging.getLogger("lmms-eval")

try:
import sglang as sgl
from sglang.lang.chat_template import get_chat_template
except ImportError:
eval_logger.error("SGLang is not installed. If you want to use llava_sglang, please install it using pip install 'sglang[all]' ")

if torch.__version__ > "2.1.2":
best_fit_attn_implementation = "sdpa"
else:
best_fit_attn_implementation = "eager"


@register_model("llava_sglang")
class LlavaSglang(lmms):
"""
Llava Sglang Model
"""

def __init__(
self,
pretrained: str = "liuhaotian/llava-v1.5-7b",
tokenizer: str = "llava-hf/llava-1.5-7b-hf",
tp_size: int = 1,
parallel: Optional[Union[int, str]] = 64,
conv_template="vicuna_v1.1",
**kwargs,
) -> None:
super().__init__()
self.pretrained = pretrained
self.tokenizer = tokenizer
self.tp_size = tp_size
self.conv_template = conv_template
torch.multiprocessing.set_start_method("spawn")

accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
assert accelerator.num_processes == 1, "Llava-sglang does not support multi-processes yet (it does support tensor parallelism)."
self._rank = 0
self._world_size = 1
self.parallel = parallel

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Llava-sglang does not support loglikelihood evaluation yet")

def generate_until(self, requests: List[Instance]) -> List[str]:

runtime = sgl.Runtime(model_path=self.pretrained, tokenizer_path=self.tokenizer, tp_size=self.tp_size)
runtime.endpoint.chat_template = get_chat_template(self.conv_template)
sgl.set_default_backend(runtime)

@sgl.function
def image_qa(s, image_file, question):
s += sgl.user(sgl.image(image_file) + question)
s += sgl.assistant(sgl.gen("answer"))

res = []

def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[0].split(" ")
return -len(toks), x[0]

# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
chunks = re_ords.get_batched(n=self.parallel, batch_fn=None)
num_iters = len(requests) // self.parallel if len(requests) % self.parallel == 0 else len(requests) // self.parallel + 1
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
batched_visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids, task, split, doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)] # [B, N]
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = 1.0
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
if gen_kwargs["top_p"] == 0.0:
gen_kwargs["top_p"] = 1.0
gen_kwargs["temperature"] = 0.0
assert gen_kwargs["num_beams"] == 1

def save_image_to_temp_file(image):
temp_file = tempfile.NamedTemporaryFile(suffix=".jpeg", delete=True)
image.save(temp_file.name)
return temp_file

def prepare_arguments_parallel(contexts, batched_visuals, max_workers=64):
arguments = [None] * len(contexts) # Initialize with placeholders
tmp_files = [None] * len(contexts) # Initialize with placeholders

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Associate each future with its index and content
future_to_info = {executor.submit(save_image_to_temp_file, pil_list[0]): (index, context, pil_list) for index, (context, pil_list) in enumerate(zip(contexts, batched_visuals))}

for future in as_completed(future_to_info):
index, context, pil_list = future_to_info[future]
if len(pil_list) > 1:
eval_logger.warning("Llava-sglang only supports one visual input per question. Using the first visual input.")
try:
temp_file = future.result()
arguments[index] = {
"image_file": temp_file.name,
"question": context,
}
tmp_files[index] = temp_file
except Exception as exc:
print(f"Generated an exception: {exc}")

# Filter out any None values in case of exceptions
arguments = [arg for arg in arguments if arg is not None]
tmp_files = [tmp_file for tmp_file in tmp_files if tmp_file is not None]

return arguments, tmp_files

arguments, tmp_files = prepare_arguments_parallel(contexts, batched_visuals, self.parallel)
states = image_qa.run_batch(arguments, temperature=gen_kwargs["temperature"], max_new_tokens=gen_kwargs["max_new_tokens"], top_p=gen_kwargs["top_p"], num_threads=self.parallel, progress_bar=False)

text_outputs = [state["answer"].strip() for state in states]
# clean up the temporary files
for tmp_file in tmp_files:
tmp_file.close()
res.extend(text_outputs)
pbar.update(1)
# reorder this group of results back to original unsorted form
res = re_ords.get_original(res)

pbar.close()
runtime.shutdown()
return res
Loading