diff --git a/requirements.txt b/requirements.txt
old mode 100644
new mode 100755
index 8dc3a7e..321fa31
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,7 +20,7 @@ pypdf2==3.0.1
 sentence-transformers==2.2.2
 sentencepiece==0.1.99
 tiktoken==0.3.3
-tokenizers==0.13.3
-transformers==4.31.0
+tokenizers==0.15.2
+transformers==4.36.1
 faiss-cpu==1.7.4
-mpi4py==3.1.5
\ No newline at end of file
+mpi4py==3.1.5
diff --git a/trt_llama_api.py b/trt_llama_api.py
old mode 100644
new mode 100755
index 09fecd7..29a587d
--- a/trt_llama_api.py
+++ b/trt_llama_api.py
@@ -47,6 +47,14 @@ from pathlib import Path
 import uuid
 import time

+from tensorrt_llm.logger import logger
+from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
+
+if PYTHON_BINDINGS:
+    from tensorrt_llm.runtime import ModelRunnerCpp
+
+from utils import (DEFAULT_HF_MODEL_DIRS, DEFAULT_PROMPT_TEMPLATES,
+                   load_tokenizer, read_model_name, throttle_generator)

 EOS_TOKEN = 2
 PAD_TOKEN = 2
@@ -80,6 +88,9 @@ class TrtLlmAPI(CustomLLM):
     _max_new_tokens = PrivateAttr()
     _sampling_config = PrivateAttr()
     _verbose = PrivateAttr()
+    _max_input_len = PrivateAttr()
+    _model = PrivateAttr()
+    _remove_input_padding = PrivateAttr()

     def __init__(
         self,
@@ -116,58 +127,41 @@
         # config function
         with open(config_path, 'r') as f:
             config = json.load(f)
-        use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
-        remove_input_padding = config['plugin_config']['remove_input_padding']
-        tp_size = config['builder_config']['tensor_parallel']
-        pp_size = config['builder_config']['pipeline_parallel']
-        world_size = tp_size * pp_size
-        assert world_size == tensorrt_llm.mpi_world_size(), \
-            f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
-        num_heads = config['builder_config']['num_heads'] // tp_size
-        hidden_size = config['builder_config']['hidden_size'] // tp_size
-        vocab_size = config['builder_config']['vocab_size']
-        num_layers = config['builder_config']['num_layers']
-        num_kv_heads = config['builder_config'].get('num_kv_heads', num_heads)
-        paged_kv_cache = config['plugin_config']['paged_kv_cache']
-        if config['builder_config'].get('multi_query_mode', False):
-            tensorrt_llm.logger.warning(
-                "`multi_query_mode` config is deprecated. Please rebuild the engine."
-            )
-            num_kv_heads = 1
-        num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
-
-        self._model_config = ModelConfig(num_heads=num_heads,
-                                         num_kv_heads=num_kv_heads,
-                                         hidden_size=hidden_size,
-                                         vocab_size=vocab_size,
-                                         num_layers=num_layers,
-                                         gpt_attention_plugin=use_gpt_attention_plugin,
-                                         paged_kv_cache=paged_kv_cache,
-                                         remove_input_padding=remove_input_padding)
-
-        assert pp_size == 1, 'Python runtime does not support pipeline parallelism'
-        world_size = tp_size * pp_size
+        use_gpt_attention_plugin = config['build_config']['plugin_config']['gpt_attention_plugin']
+        remove_input_padding = config['build_config']['plugin_config']['remove_input_padding']
+        num_heads = config['pretrained_config']['num_attention_heads']  # // tp_size
+        hidden_size = config['pretrained_config']['hidden_size']  # // tp_size
+        vocab_size = config['pretrained_config']['vocab_size']
+        num_layers = config['pretrained_config']['num_hidden_layers']
+        paged_kv_cache = config['build_config']['plugin_config']['paged_kv_cache']
+        max_batch_size = config['build_config']['max_batch_size']
+        max_beam_width = config['build_config']['max_beam_width']

         runtime_rank = tensorrt_llm.mpi_rank()
-        runtime_mapping = tensorrt_llm.Mapping(world_size,
-                                               runtime_rank,
-                                               tp_size=tp_size,
-                                               pp_size=pp_size)
-        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
+        logger.set_level('error')
+
         self._tokenizer = LlamaTokenizer.from_pretrained(tokenizer_dir, legacy=False)
-        self._sampling_config = SamplingConfig(end_id=EOS_TOKEN,
-                                               pad_id=PAD_TOKEN,
-                                               num_beams=1,
-                                               temperature=temperature)
-
-        serialize_path = engine_dir_path / engine_name
-        with open(serialize_path, 'rb') as f:
-            engine_buffer = f.read()
-        decoder = tensorrt_llm.runtime.GenerationSession(self._model_config,
-                                                         engine_buffer,
-                                                         runtime_mapping,
-                                                         debug_mode=False)
-        self._model = decoder
+        self._remove_input_padding = remove_input_padding
+
+        runner_cls = ModelRunner if not PYTHON_BINDINGS else ModelRunnerCpp
+        runner_kwargs = dict(engine_dir=engine_dir,
+                             lora_dir=None,
+                             rank=runtime_rank,
+                             debug_mode=False,
+                             lora_ckpt_source="hf")
+        runner_kwargs.update(
+            max_batch_size=max_batch_size,
+            max_input_len=config['build_config']['max_input_len'],
+            max_output_len=1024,
+            max_beam_width=max_beam_width,
+            max_attention_window_size=None,
+            sink_token_length=None
+        )
+        runner = runner_cls.from_dir(**runner_kwargs)
+
+        self._max_input_len = config['build_config']['max_input_len']
+
+        self._model = runner

         messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
         completion_to_prompt = completion_to_prompt or (lambda x: x)
@@ -220,14 +214,32 @@ def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         input_text = prompt

         input_ids, input_lengths = self.parse_input(input_text, self._tokenizer,
                                                     EOS_TOKEN,
-                                                    self._model_config)
-
-        max_input_length = torch.max(input_lengths).item()
-        self._model.setup(input_lengths.size(0), max_input_length, self._max_new_tokens, 1)  # beam size is set to 1
-        if self._verbose:
-            start_time = time.time()
-
-        output_ids = self._model.decode(input_ids, input_lengths, self._sampling_config)
+                                                    self._remove_input_padding)
+
+        if self._verbose:
+            start_time = time.time()
+
+        outputs = self._model.generate(
+            input_ids,
+            max_new_tokens=self._max_new_tokens,
+            max_attention_window_size=None,
+            sink_token_length=None,
+            end_id=EOS_TOKEN,
+            pad_id=PAD_TOKEN,
+            temperature=1.0,
+            top_k=1,
+            top_p=0.0,
+            num_beams=1,
+            length_penalty=1.0,
+            repetition_penalty=1.0,
+            presence_penalty=0.0,
+            frequency_penalty=0.0,
+            stop_words_list=None,
+            bad_words_list=None,
+            lora_uids=None,
+            prompt_table_path=None,
+            prompt_tasks=None,
+            streaming=False,
+            output_sequence_lengths=True,
+            return_dict=True,
+            medusa_choices=None)

         torch.cuda.synchronize()

         elapsed_time = None
@@ -236,7 +248,7 @@
             elapsed_time = end_time - start_time

-        output_txt, output_token_ids = self.get_output(output_ids,
+        output_txt, output_token_ids = self.get_output(outputs['output_ids'],
                                                        input_lengths,
                                                        self._max_new_tokens,
                                                        self._tokenizer)
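
For readers following the migration: the patch replaces the manual GenerationSession setup (ModelConfig, SamplingConfig, reading the serialized engine buffer) with TensorRT-LLM's ModelRunner abstraction. The standalone sketch below shows that new flow end to end, under stated assumptions: the engine and tokenizer paths, the prompt, and max_new_tokens=128 are illustrative placeholders, while the end/pad token ids and the greedy sampling settings (temperature=1.0, top_k=1) mirror the values hardcoded in the patch. It is a minimal sketch, not the wrapper class itself.

import torch
from tensorrt_llm.runtime import ModelRunner
from transformers import LlamaTokenizer

ENGINE_DIR = "./model/engine"        # hypothetical: directory holding the built engine + config.json
TOKENIZER_DIR = "./model/tokenizer"  # hypothetical: directory holding the HF tokenizer files
EOS_TOKEN = 2                        # Llama end-of-sequence id, same constant the patch uses
PAD_TOKEN = 2

tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_DIR, legacy=False)

# from_dir reads the engine's config.json itself, so the build_config /
# pretrained_config plumbing in __init__ is only needed for values the
# wrapper wants to keep around (max_input_len, batch size, beam width).
runner = ModelRunner.from_dir(engine_dir=ENGINE_DIR, rank=0)

prompt = "What is TensorRT-LLM?"
input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.int32)

outputs = runner.generate(
    [input_ids],                     # generate() takes a batch: a list of 1-D token-id tensors
    max_new_tokens=128,
    end_id=EOS_TOKEN,
    pad_id=PAD_TOKEN,
    temperature=1.0,
    top_k=1,                         # greedy decoding, matching the patch
    num_beams=1,
    output_sequence_lengths=True,
    return_dict=True)
torch.cuda.synchronize()

# output_ids is [batch, beam, seq]; strip the prompt to keep only generated tokens
seq_len = outputs["sequence_lengths"][0][0]
new_tokens = outputs["output_ids"][0][0][input_ids.shape[0]:seq_len]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))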
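
One design note on the runner selection: `runner_cls = ModelRunner if not PYTHON_BINDINGS else ModelRunnerCpp` prefers the C++ runtime bindings whenever the installed wheel ships them and falls back to the pure-Python session otherwise; both classes expose the same from_dir/generate surface, which is why complete() needs no further branching. Note also that generate() now hardcodes temperature=1.0 even though __init__ still accepts a temperature argument; since top_k=1 makes decoding greedy, the temperature value is effectively unused.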