Commit 02dfe3c

IlyasMoutawwakil committed Dec 11, 2024
1 parent 4d18b24 commit 02dfe3c
Showing 6 changed files with 157 additions and 93 deletions.
160 changes: 113 additions & 47 deletions optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -1,10 +1,16 @@
+import os
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any, Dict

+import torch
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from hydra.utils import get_class
+from safetensors.torch import save_file

+from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
+from ..transformers_utils import fast_weights_init
from .config import TRTLLMConfig
from .utils import MODEL_TYPE_TO_TRTLLMMODEL

@@ -25,62 +31,122 @@ def load(self) -> None:
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

-        self.logger.info("\t+ Loading pretrained TRTLLMModel")
-        self.load_trtmodel_from_pretrained()
+        if self.config.no_weights:
+            self.logger.info("\t+ Creating no weights model")
+            self.create_no_weights_model()
+            self.logger.info("\t+ Loading no weights model")
+            self.load_trtllm_with_no_weights()
+        else:
+            self.logger.info("\t+ Downloading pretrained model")
+            self.download_pretrained_model()
+            if self.config.task in TEXT_GENERATION_TASKS:
+                self.logger.info("\t+ Preparing generation config")
+                self.prepare_generation_config()
+            self.logger.info("\t+ Loading pretrained model")
+            self.load_trtllm_from_pretrained()

        self.logger.info("\t+ Cleaning up backend temporary directory")
        self.tmpdir.cleanup()
-    def load_trtmodel_from_pretrained(self) -> None:
+    def download_pretrained_model(self) -> None:
+        with torch.device("meta"):
+            self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs)
+
+    def prepare_generation_config(self) -> None:
+        self.generation_config.eos_token_id = None
+        self.generation_config.pad_token_id = None
+
+        model_cache_folder = f"models/{self.config.model}".replace("/", "--")
+        model_cache_path = f"{HUGGINGFACE_HUB_CACHE}/{model_cache_folder}"
+        snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}"
+        snapshot_ref = open(snapshot_file, "r").read().strip()
+        model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
+        self.generation_config.save_pretrained(save_directory=model_snapshot_path)
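
> Note: the snapshot-path arithmetic in `prepare_generation_config` leans on the hub cache layout, where `refs/<revision>` stores the commit hash that names the snapshot directory. A standalone sketch of that resolution (the helper name is ours, not part of this commit):

```python
import os

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE


def resolve_snapshot_path(model_id: str, revision: str = "main") -> str:
    # models/org/name -> models--org--name, the folder naming used by the hub cache
    model_cache_folder = f"models/{model_id}".replace("/", "--")
    model_cache_path = os.path.join(HUGGINGFACE_HUB_CACHE, model_cache_folder)
    # refs/<revision> contains the commit hash that names the snapshot directory
    with open(os.path.join(model_cache_path, "refs", revision)) as f:
        snapshot_ref = f.read().strip()
    return os.path.join(model_cache_path, "snapshots", snapshot_ref)
```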

+    def create_no_weights_model(self) -> None:
+        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
+        self.logger.info("\t+ Creating no weights model directory")
+        os.makedirs(self.no_weights_model, exist_ok=True)
+        self.logger.info("\t+ Creating no weights model state dict")
+        state_dict = torch.nn.Linear(1, 1).state_dict()
+        self.logger.info("\t+ Saving no weights model safetensors")
+        safetensor = os.path.join(self.no_weights_model, "model.safetensors")
+        save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"})
+        self.logger.info("\t+ Saving no weights model pretrained config")
+        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
+        self.logger.info("\t+ Saving no weights model pretrained processor")
+        self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
+        # unlike Transformers, TRT-LLM won't accept any missing tensors so we need to materialize the model
+        self.logger.info(f"\t+ Loading no weights model from {self.no_weights_model}")
+        with fast_weights_init():
+            self.pretrained_model = self.automodel_loader.from_pretrained(
+                self.no_weights_model, **self.config.model_kwargs, device_map="auto", _fast_init=False
+            )
+        self.logger.info("\t+ Saving no weights model")
+        self.pretrained_model.save_pretrained(save_directory=self.no_weights_model)
+        del self.pretrained_model
+        torch.cuda.empty_cache()
+
+        if self.config.task in TEXT_GENERATION_TASKS:
+            self.logger.info("\t+ Modifying generation config for fixed length generation")
+            self.generation_config.eos_token_id = None
+            self.generation_config.pad_token_id = None
+            self.logger.info("\t+ Saving new pretrained generation config")
+            self.generation_config.save_pretrained(save_directory=self.no_weights_model)
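
> The "no weights" trick starts from a deliberately tiny checkpoint. A minimal sketch of that first step, under the same assumption the inline comment states (the directory path is illustrative; the backend later overwrites this file with a fully materialized, randomly initialized checkpoint because TRT-LLM rejects checkpoints with missing tensors):

```python
import os

import torch
from safetensors.torch import save_file

no_weights_dir = "/tmp/no_weights_model"  # illustrative path
os.makedirs(no_weights_dir, exist_ok=True)

# a 1x1 Linear gives a valid-but-meaningless state dict to stand in for real weights
save_file(
    tensors=torch.nn.Linear(1, 1).state_dict(),
    filename=os.path.join(no_weights_dir, "model.safetensors"),
    metadata={"format": "pt"},
)
```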

+    def load_trtllm_with_no_weights(self) -> None:
+        original_model, self.config.model = self.config.model, self.no_weights_model
+        self.load_trtllm_from_pretrained()
+        self.config.model = original_model

-    def load_trtmodel_from_pretrained(self) -> None:
+    def load_trtllm_from_pretrained(self) -> None:
        self.pretrained_model = self.trtllm_loader.from_pretrained(
            self.config.model,
-            tp=self.config.tp,
-            pp=self.config.pp,
-            dtype=self.config.dtype,
-            use_fp8=self.config.use_fp8,
-            world_size=self.config.world_size,
-            gpus_per_node=self.config.gpus_per_node,
-            use_cuda_graph=self.config.use_cuda_graph,
-            optimization_level=self.config.optimization_level,
-            max_prompt_length=self.config.max_prompt_length,
-            max_batch_size=self.config.max_batch_size,
-            max_new_tokens=self.config.max_new_tokens,
-            max_beam_width=self.config.max_beam_width,
            **self.config.model_kwargs,
+            **self.trtllm_kwargs,
        )
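
> One subtlety with the double unpacking in this call: Python raises `TypeError` if the two dicts share a key, so `model_kwargs` and `trtllm_kwargs` must stay disjoint. A quick self-contained demonstration:

```python
def f(**kwargs):
    return kwargs


a, b = {"x": 1}, {"x": 2}
try:
    f(**a, **b)  # duplicate key across the two unpacked dicts
except TypeError as e:
    print(e)  # f() got multiple values for keyword argument 'x'
```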

+    @property
+    def trtllm_kwargs(self):
+        kwargs = {}
+
+        if self.config.tp is not None:
+            kwargs["tp"] = self.config.tp
+
+        if self.config.pp is not None:
+            kwargs["pp"] = self.config.pp
+
+        if self.config.dtype is not None:
+            kwargs["dtype"] = self.config.dtype
+
+        if self.config.use_fp8 is not None:
+            kwargs["use_fp8"] = self.config.use_fp8
+
+        if self.config.world_size is not None:
+            kwargs["world_size"] = self.config.world_size
+
+        if self.config.gpus_per_node is not None:
+            kwargs["gpus_per_node"] = self.config.gpus_per_node
+
+        if self.config.use_cuda_graph is not None:
+            kwargs["use_cuda_graph"] = self.config.use_cuda_graph
+
+        if self.config.optimization_level is not None:
+            kwargs["optimization_level"] = self.config.optimization_level
+
+        if self.config.max_prompt_length is not None:
+            kwargs["max_prompt_length"] = self.config.max_prompt_length
+
+        if self.config.max_new_tokens is not None:
+            kwargs["max_new_tokens"] = self.config.max_new_tokens
+
+        if self.config.max_beam_width is not None:
+            kwargs["max_beam_width"] = self.config.max_beam_width
+
+        return kwargs
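
> The if-chain is explicit but repetitive; an equivalent comprehension (a possible refactor, not part of this commit; the name tuple and helper are ours) keeps the same semantics of forwarding only explicitly-set knobs:

```python
from typing import Any, Dict

TRTLLM_KWARG_NAMES = (
    "tp", "pp", "dtype", "use_fp8", "world_size", "gpus_per_node",
    "use_cuda_graph", "optimization_level", "max_prompt_length",
    "max_new_tokens", "max_beam_width",
)


def collect_trtllm_kwargs(config: Any) -> Dict[str, Any]:
    # forward only the knobs the user actually set; None means "engine default"
    return {
        name: getattr(config, name)
        for name in TRTLLM_KWARG_NAMES
        if getattr(config, name) is not None
    }
```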

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
-        return self.pretrained_model.generate(
-            input_ids=inputs.get("input_ids"),
-            attention_mask=inputs.get("attention_mask"),
-            min_length=kwargs.get("min_new_tokens", -1),
-            max_new_tokens=kwargs.get("max_new_tokens", -1),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            pad_token_id=kwargs.get("pad_token_id", 0),
-            bos_token_id=kwargs.get("bos_token_id", 1),
-            eos_token_id=kwargs.get("eos_token_id", 2),
-            temperature=kwargs.get("temperature", 1.0),
-            num_beams=kwargs.get("num_beams", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            top_k=kwargs.get("top_k", 50),
-            seed=kwargs.get("seed", 42),
-        )
+        return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs)

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
-        return self.pretrained_model.generate(
-            input_ids=inputs.get("input_ids"),
-            attention_mask=inputs.get("attention_mask"),
-            min_length=kwargs.get("min_new_tokens", -1),
-            max_new_tokens=kwargs.get("max_new_tokens", -1),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            pad_token_id=kwargs.get("pad_token_id", 0),
-            bos_token_id=kwargs.get("bos_token_id", 1),
-            eos_token_id=kwargs.get("eos_token_id", 2),
-            temperature=kwargs.get("temperature", 1.0),
-            num_beams=kwargs.get("num_beams", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            top_k=kwargs.get("top_k", 50),
-            seed=kwargs.get("seed", 42),
-        )
+        return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs)
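
> Both `prefill` and `generate` now forward whatever the scenario supplies in `kwargs` instead of enumerating sampling parameters with hard-coded defaults. A toy model of the pass-through (stub classes, not the real backend):

```python
from typing import Any, Dict


class StubModel:
    def generate(self, inputs, **kwargs):
        # echo back what arrived, to show kwargs pass through untouched
        return {"inputs": inputs, **kwargs}


class StubBackend:
    pretrained_model = StubModel()

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]):
        return self.pretrained_model.generate(inputs=inputs.get("input_ids"), **kwargs)


print(StubBackend().generate({"input_ids": [[1, 2, 3]]}, {"max_new_tokens": 8}))
# {'inputs': [[1, 2, 3]], 'max_new_tokens': 8}
```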
30 changes: 15 additions & 15 deletions optimum_benchmark/backends/tensorrt_llm/config.py
@@ -13,21 +13,21 @@ class TRTLLMConfig(BackendConfig):
    version: Optional[str] = tesnorrt_llm_version()
    _target_: str = "optimum_benchmark.backends.tensorrt_llm.backend.TRTLLMBackend"

-    # build config
-    tp: int = 1
-    pp: int = 1
-    use_fp8: bool = False
-    dtype: str = "float16"
-    optimization_level: int = 2
-    use_cuda_graph: bool = False
-
-    world_size: int = 1
-    gpus_per_node: int = 1
-
-    max_prompt_length: int = 128
-    max_new_tokens: int = -1
-    max_batch_size: int = 1
-    max_beam_width: int = 1
+    no_weights: bool = False
+
+    # trtllm kwargs
+    tp: Optional[int] = None
+    pp: Optional[int] = None
+    dtype: Optional[str] = None
+    use_fp8: Optional[bool] = None
+    world_size: Optional[int] = None
+    gpus_per_node: Optional[int] = None
+    use_cuda_graph: Optional[bool] = None
+    optimization_level: Optional[int] = None
+    max_prompt_length: Optional[int] = None
+    max_new_tokens: Optional[int] = None
+    max_batch_size: Optional[int] = None
+    max_beam_width: Optional[int] = None

    def __post_init__(self) -> None:
        super().__post_init__()
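
> The `_target_` string above is how hydra locates the backend class at run time; a minimal illustration with hydra's own utility (assumes optimum-benchmark and its TRT-LLM dependencies are importable in the current environment):

```python
from hydra.utils import get_class

# the same dotted path carried by the config's _target_ field
backend_cls = get_class("optimum_benchmark.backends.tensorrt_llm.backend.TRTLLMBackend")
print(backend_cls.__name__)  # TRTLLMBackend
```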
14 changes: 6 additions & 8 deletions optimum_benchmark/backends/torch_ort/backend.py
@@ -39,19 +39,17 @@ def load(self) -> None:
        self.tmpdir.cleanup()

    def load_automodel_with_no_weights(self) -> None:
+        original_model, self.config.model = self.config.model, self.no_weights_model
+
        with fast_weights_init():
-            original_model, self.config.model = self.config.model, self.no_weights_model
            self.load_automodel_from_pretrained()

-        self.logger.info("\t+ Tying model weights")
-        self.pretrained_model.tie_weights()
-
-        self.config.model = original_model
+        self.pretrained_model.tie_weights()
+        self.config.model = original_model

    def load_automodel_from_pretrained(self) -> None:
        self.pretrained_model = self.automodel_loader.from_pretrained(
-            self.config.model, **self.automodel_kwargs, **self.config.model_kwargs
+            self.config.model,
+            **self.config.model_kwargs,
+            **self.automodel_kwargs,
        ).to(self.config.device)

    @property
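
> The save/restore dance around `self.config.model` appears in several backends; it could equally be packaged as a context manager. A sketch of that alternative (not part of this commit):

```python
from contextlib import contextmanager


@contextmanager
def swapped_attr(obj, name, value):
    # temporarily replace obj.name, restoring it even if the body raises
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


class Cfg:
    model = "real-model"


cfg = Cfg()
with swapped_attr(cfg, "model", "no_weights_model"):
    print(cfg.model)  # no_weights_model
print(cfg.model)  # real-model
```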
3 changes: 1 addition & 2 deletions optimum_benchmark/backends/torch_ort/config.py
@@ -14,8 +14,7 @@ class TorchORTConfig(BackendConfig):
    # load options
    no_weights: bool = False
    torch_dtype: Optional[str] = None
-    # sdpa, which has became default of many architectures, fails with torch ort
-    attn_implementation: Optional[str] = "eager"
+    attn_implementation: Optional[str] = None

    # peft options
    peft_type: Optional[str] = None
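
> The removed comment explained why `eager` used to be forced: SDPA, now the default attention implementation for many architectures, fails under torch-ort. With the default now `None`, anyone who still hits that failure can opt back in explicitly; the knob maps onto the standard transformers argument (model name is illustrative):

```python
from transformers import AutoModel

# explicitly request eager attention instead of the SDPA default
model = AutoModel.from_pretrained("gpt2", attn_implementation="eager")
```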
35 changes: 22 additions & 13 deletions optimum_benchmark/backends/vllm/backend.py
@@ -6,7 +6,10 @@
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from safetensors.torch import save_file
-from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
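
> In the vLLM versions we are aware of, the explicit submodule paths resolve to the same objects as the old top-level import; they just avoid pulling in the whole top-level namespace. A quick check (our assumption, not from the commit):

```python
from vllm import EngineArgs as TopLevelEngineArgs
from vllm.engine.arg_utils import EngineArgs

# both names should be the very same class object
assert EngineArgs is TopLevelEngineArgs
```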

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
@@ -32,15 +35,15 @@ def load(self) -> None:
            self.logger.info("\t+ Creating no weights model")
            self.create_no_weights_model()
            self.logger.info("\t+ Loading no weights model")
-            self.load_model_with_no_weights()
+            self.load_vllm_with_no_weights()
        else:
            self.logger.info("\t+ Downloading pretrained model")
            self.download_pretrained_model()
            if self.config.task in TEXT_GENERATION_TASKS:
                self.logger.info("\t+ Preparing generation config")
                self.prepare_generation_config()
            self.logger.info("\t+ Loading pretrained model")
-            self.load_model_from_pretrained()
+            self.load_vllm_from_pretrained()

        self.logger.info("\t+ Cleaning up backend temporary directory")
        self.tmpdir.cleanup()
@@ -52,13 +55,11 @@ def download_pretrained_model(self) -> None:
    def prepare_generation_config(self) -> None:
        self.generation_config.eos_token_id = None
        self.generation_config.pad_token_id = None

        model_cache_folder = f"models/{self.config.model}".replace("/", "--")
        model_cache_path = f"{HUGGINGFACE_HUB_CACHE}/{model_cache_folder}"
        snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}"
        snapshot_ref = open(snapshot_file, "r").read().strip()
        model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
-        self.logger.info("\t+ Saving new pretrained generation config")
        self.generation_config.save_pretrained(save_directory=model_snapshot_path)

    def create_no_weights_model(self) -> None:
@@ -92,19 +93,27 @@ def create_no_weights_model(self) -> None:
        self.logger.info("\t+ Saving new pretrained generation config")
        self.generation_config.save_pretrained(save_directory=self.no_weights_model)

def load_model_with_no_weights(self) -> None:
def load_vllm_with_no_weights(self) -> None:
original_model, self.config.model = self.config.model, self.no_weights_model
self.logger.info("\t+ Loading no weights model")
self.load_model_from_pretrained()
self.load_vllm_from_pretrained()
self.config.model = original_model

def load_model_from_pretrained(self) -> None:
def load_vllm_from_pretrained(self) -> None:
if self.config.serving_mode == "offline":
self.pretrained_model = LLMEngine.from_engine_args(EngineArgs(**self.config.to_engine_args()))
self.pretrained_model = LLMEngine.from_engine_args(EngineArgs(**self.vllm_kwargs))
else:
self.pretrained_model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.config.to_engine_args()))

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
self.pretrained_model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.vllm_kwargs))

@property
def vllm_kwargs(self):
return {
"model": self.config.model,
"tokenizer": self.config.processor,
"device": self.config.device,
**self.config.engine_args,
}

def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.config.task in TEXT_GENERATION_TASKS:
inputs = {"prompts": self.pretrained_processor.batch_decode(inputs["input_ids"])}
else:
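
> In offline mode the new `vllm_kwargs` property feeds `EngineArgs` directly, so `to_engine_args` on the config (removed below) is no longer needed. A sketch of the offline path (model choice is illustrative; requires a working vLLM install and a GPU):

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# the same shape of dict the property returns: model, tokenizer, device, plus engine_args
vllm_kwargs = {"model": "facebook/opt-125m", "tokenizer": "facebook/opt-125m", "device": "auto"}
engine = LLMEngine.from_engine_args(EngineArgs(**vllm_kwargs))
```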
8 changes: 0 additions & 8 deletions optimum_benchmark/backends/vllm/config.py
@@ -54,11 +54,3 @@ def __post_init__(self):
        if self.serving_mode == "online":
            if self.engine_args.get("disable_log_requests", None) is None:
                self.engine_args["disable_log_requests"] = True
-
-    def to_engine_args(self) -> Dict[str, Any]:
-        return dict(
-            model=self.model,
-            tokenizer=self.processor,
-            device=self.device,
-            **self.engine_args,
-        )
