[python] Add tool calling support
xyang16 committed Jan 30, 2025
1 parent 7b93555 commit 1ff6552
Showing 13 changed files with 326 additions and 20 deletions.
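For context, this commit wires vLLM tool calling into the chat completions path. The sketch below illustrates the kind of OpenAI-style request the change enables; the endpoint URL, model, and tool definition are placeholders rather than values taken from this commit, and the exact payload a given DJL Serving deployment accepts may differ.

# Illustrative client-side request with tools (URL, tool schema, and city are placeholders).
import requests

payload = {
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    "tool_choice": "auto",  # "required" is rejected by parse_chat_completions_request_vllm below
    "stream": False,
}
response = requests.post("http://localhost:8080/invocations", json=payload)
print(response.json()["choices"][0])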
@@ -14,6 +14,7 @@

from djl_python.chat_completions.vllm_chat_properties import ChatProperties
from djl_python.properties_manager.properties import Properties
from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
@@ -30,7 +31,6 @@ def parse_chat_completions_request_vllm(
rolling_batch,
tokenizer,
chat_template: Optional[str] = None,
image_token: Optional[str] = None,
configs: Properties = None,
is_mistral_tokenizer: bool = False,
):
@@ -47,10 +47,30 @@ def parse_chat_completions_request_vllm(
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")

tool_parser = rolling_batch.get_tool_parser()
chat_params = ChatProperties(**input_map)

if chat_params.tool_choice == "required":
raise ValueError("tool_choice = \"required\" is not supported!")

if is_mistral_tokenizer:
maybe_serialize_tool_calls(chat_params)
elif chat_params.tool_choice == "auto" and tool_parser is None:
raise ValueError(
"\"auto\" tool choice requires tool_call_parser to be available")

should_parse_tools = tool_parser is not None and (hasattr(
chat_params, "tool_choice") and chat_params.tool_choice != "none")
if should_parse_tools:
chat_params = tool_parser.adjust_request(request=chat_params)

exclude = {"messages"}
param = chat_params.model_dump(exclude_none=True, exclude=exclude)

tool_dicts = None if chat_params.tools is None else [
tool.model_dump() for tool in chat_params.tools
]

conversation, mm_data = parse_chat_messages(
chat_params.messages, rolling_batch.get_model_config(), tokenizer)

@@ -61,18 +81,22 @@ def parse_chat_completions_request_vllm(
messages=chat_params.messages,
chat_template=chat_template,
add_generation_prompt=True,
tools=tool_dicts,
)
else:
text_inputs = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=True,
tools=tool_dicts,
)

param["details"] = True # Enable details for chat completions
param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
param["tool_parser"] = tool_parser
param["chat_params"] = chat_params

if mm_data:
param["mm_data"] = mm_data
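For orientation, the parser above attaches a few extra entries to the request parameters that the chat output formatters consume later in this commit. A rough sketch of those keys, with placeholder values rather than real objects:

# Rough sketch (placeholders, not real objects) of the keys added by
# parse_chat_completions_request_vllm for the chat output formatters.
param = {
    "details": True,                  # always enabled for chat completions
    "output_formatter": "json_chat",  # "jsonlines_chat" when stream=True
    "tool_parser": None,              # vLLM ToolParser instance when auto tool choice is configured
    "chat_params": None,              # the validated ChatProperties object
}
# "mm_data" is added only when the conversation carries multimodal content.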
1 change: 0 additions & 1 deletion engines/python/setup/djl_python/input_parser.py
@@ -150,7 +150,6 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
kwargs.get("is_rolling_batch"),
rolling_batch,
tokenizer,
image_token=image_token,
configs=configs,
is_mistral_tokenizer=is_mistral_tokenizer,
)
101 changes: 94 additions & 7 deletions engines/python/setup/djl_python/output_formatter.py
@@ -110,7 +110,7 @@ def _json_output_formatter(request_output: TextGenerationOutput):
request_output.best_sequence_index]
# TODO: Fix this so it is not required. Right now, this call is needed to
# advance the token iterator, which is needed for rolling batch to work properly
next_token, _, is_last_token = best_sequence.get_next_token()
next_token, _, _, is_last_token = best_sequence.get_next_token()
if not request_output.finished:
return ""
details = get_details_dict(request_output, include_tokens=True)
@@ -141,7 +141,7 @@ def _json_3p_output_formatter(request_output: TextGenerationOutput):
request_output.best_sequence_index]
# TODO: Fix this so it is not required. Right now, this call is needed to
# advance the token iterator, which is needed for rolling batch to work properly
next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token()
if not request_output.finished:
return ""

@@ -221,7 +221,7 @@ def _jsonlines_output_formatter(request_output: TextGenerationOutput):
parameters = request_output.input.parameters
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, _, last_token = best_sequence.get_next_token()
next_token, _, _, last_token = best_sequence.get_next_token()
# with chunked prefill, we don't generate any tokens until the full prompt has been processed.
# that means we sometimes don't have a token to return
if next_token is None:
@@ -242,7 +242,7 @@ def _jsonlines_3p_output_formatter(request_output: TextGenerationOutput):
def _jsonlines_3p_output_formatter(request_output: TextGenerationOutput):
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token()
# with chunked prefill, we don't generate any tokens until the full prompt has been processed.
# that means we sometimes don't have a token to return
if next_token is None:
@@ -282,6 +282,8 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
:return: formatted output
"""
parameters = request_output.input.parameters
chat_params = parameters.get("chat_params")
tool_parser = parameters.get("tool_parser")
best_sequence = request_output.sequences[
request_output.best_sequence_index]
generated_text = get_generated_text(best_sequence, request_output)
@@ -299,6 +301,51 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
"logprobs": None,
"finish_reason": best_sequence.finish_reason,
}
if chat_params and chat_params.tool_choice and type(
chat_params.tool_choice
).__name__ == "ChatCompletionNamedToolChoiceParam":
tool_calls = [{
"id": f"chatcmpl-tool-{id(request_output)}",
"type": "function",
"function": {
"name": chat_params.tool_choice.function.name,
"arguments": generated_text
}
}]
choice = {
"index": 0,
"message": {
"role": "assistant",
"content": "",
},
"tool_calls": tool_calls,
"logprobs": None,
"finish_reason": best_sequence.finish_reason,
}
elif parameters.get("tools") and (parameters.get("tool_choice") == "auto"
or parameters.get("tool_choice") is None
) and parameters.get("tool_parser"):
try:
tool_call_info = tool_parser.extract_tool_calls(
generated_text, request=chat_params)
auto_tools_called = tool_call_info.tools_called
if auto_tools_called:
tool_calls = [
t.model_dump() for t in tool_call_info.tool_calls
]
choice = {
"index": 0,
"message": {
"role": "assistant",
"content": tool_call_info.content,
},
"tool_calls": tool_calls,
"logprobs": None,
"finish_reason": "tool_calls",
}
except RuntimeError:
logging.exception("Failed to invoke tool parser")

if parameters.get("logprobs"):
logprobs = {
"content": [
@@ -317,6 +364,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
]
}
choice["logprobs"] = logprobs

prompt_tokens = len(request_output.prompt_tokens_details)
completion_tokens = len(best_sequence.tokens)
usage = {
@@ -341,16 +389,54 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
:return: formatted output
"""
parameters = request_output.input.parameters
chat_params = parameters.get("chat_params")
tool_parser = parameters.get("tool_parser")
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token()
# with chunked prefill, we don't generate any tokens until the full prompt has been processed.
# that means we sometimes don't have a token to return
if next_token is None:
return ""

created = int(time.time())
delta = {"content": next_token.text}

if chat_params and chat_params.tool_choice and type(
chat_params.tool_choice
).__name__ == "ChatCompletionNamedToolChoiceParam":
tool_calls = [{
"index": 0,
"function": {
"name": chat_params.tool_choice.function.name,
"arguments": next_token.text
}
}]
delta = {"tool_calls": tool_calls}
elif parameters.get("tools") and (parameters.get("tool_choice") == "auto"
or parameters.get("tool_choice") is None
) and parameters.get("tool_parser"):
current_text = get_generated_text(best_sequence, request_output)
previous_text = current_text[0:-len(next_token.text)]
current_token_ids = [t.id for t in best_sequence.tokens]
previous_token_ids = current_token_ids[:-1]
tool_call_info = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=next_token.text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=[next_token.id],
request=chat_params)
if tool_call_info is None:
return ""
tool_calls = [
t.model_dump(exclude_none=True) for t in tool_call_info.tool_calls
]
delta = {
"tool_calls": tool_calls,
}
else:
delta = {"content": next_token.text}
if first_token:
delta["role"] = "assistant"

@@ -423,7 +509,8 @@ def adapt_legacy_output_formatter(request_output: TextGenerationOutput) -> str:
elif best_sequence.finish_reason == "error":
details_dict["finish_reason"] = best_sequence.finish_reason

next_token, first_token, last_token = best_sequence.get_next_token()
next_token, index, first_token, last_token = best_sequence.get_next_token(
)
if last_token:
for token in best_sequence.tokens:
generated_text += token.text
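For reference, a rough sketch of the output shapes these chat formatters produce when tool calls are detected; every value below is illustrative and not taken from this commit's tests.

# Illustrative non-streaming "choices" entry from _json_chat_output_formatter
# when the tool parser reports tool calls (ids, names, and arguments are placeholders).
example_choice = {
    "index": 0,
    "message": {"role": "assistant", "content": None},
    "tool_calls": [{
        "id": "chatcmpl-tool-140234",
        "type": "function",
        "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"},
    }],
    "logprobs": None,
    "finish_reason": "tool_calls",
}

# Illustrative streaming counterpart: _jsonlines_chat_output_formatter emits deltas
# whose "tool_calls" carry incremental argument fragments instead of "content".
example_delta = {
    "tool_calls": [{"index": 0, "function": {"name": "get_weather", "arguments": "{\"ci"}}],
}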
engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
@@ -11,13 +11,12 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import ast
import json
from enum import Enum
from typing import Optional, Any, Mapping, Tuple, Dict

from pydantic import field_validator, model_validator

from djl_python.properties_manager.properties import Properties
from vllm.entrypoints.openai.tool_parsers import ToolParserManager


class VllmRbProperties(Properties):
@@ -74,6 +73,10 @@ class VllmRbProperties(Properties):
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None

# Tool calling properties
enable_auto_tool_choice: Optional[bool] = False
tool_call_parser: Optional[str] = None

@field_validator('engine')
def validate_engine(cls, engine):
if engine != "Python":
@@ -147,3 +150,12 @@ def validate_pipeline_parallel(self):
"Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
)
return self

@model_validator(mode='after')
def validate_tool_call_parser(self):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if self.enable_auto_tool_choice \
and self.tool_call_parser not in valid_tool_parses:
raise ValueError(
f"Invalid tool call parser: {self.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})")
4 changes: 2 additions & 2 deletions engines/python/setup/djl_python/request_io.py
@@ -118,8 +118,8 @@ def get_next_token(self) -> (Token, bool, bool):
index = self._tokens_iterator.next_index()
first_token = index == 0
last_token = index == self._last_token_index
return self.tokens[index], first_token, last_token
return None, False, False
return self.tokens[index], index, first_token, last_token
return None, 0, False, False

def get_last_token(self) -> Optional[Token]:
if self._last_token_index is not None:
@@ -120,6 +120,12 @@ def use_vllm_chat_completions(self):
"""
return False

def get_tool_parser(self):
"""
:return: the tool call parser if available
"""
return None

@abstractmethod
def inference(self, new_requests: List[Request]) -> List:
"""
engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -290,3 +290,19 @@ def get_prompt_inputs(request: Request):
if multi_modal_data is not None:
prompt["multi_modal_data"] = multi_modal_data
return prompt


def maybe_serialize_tool_calls(request):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.0/vllm/transformers_utils/tokenizers/mistral.py#L34-L68
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get("tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break

request.messages[i]["tool_calls"] = validated_tool_calls
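A small, self-contained sketch of what this helper does: for assistant messages it drains whatever iterator backs "tool_calls" and stores a plain list back on the message before the Mistral chat template is applied (see the call in parse_chat_completions_request_vllm above). The request object and message content here are stand-ins.

# Illustrative usage; FakeRequest stands in for the validated chat request object.
from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls


class FakeRequest:
    def __init__(self, messages):
        self.messages = messages


request = FakeRequest(messages=[
    {"role": "user", "content": "What is the weather in Paris?"},
    {"role": "assistant", "content": None,
     "tool_calls": iter([{"id": "call_0", "type": "function",
                          "function": {"name": "get_weather",
                                       "arguments": "{\"city\": \"Paris\"}"}}])},
])
maybe_serialize_tool_calls(request)
assert isinstance(request.messages[1]["tool_calls"], list)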
@@ -10,11 +10,12 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict, defaultdict

from vllm import LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid, AtomicCounter

from djl_python.request import Request
@@ -23,7 +24,7 @@
update_request_cache_with_output, create_lora_request, get_lora_request,
get_engine_args_from_config, get_prompt_inputs)
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List, Optional
from typing import Callable, List, Optional

# FIXME: Once all vllm versions are past 0.6.0 we can move to just struct_fields
VLLM_GENERATION_PARAMS = set(SamplingParams().__struct_fields__) if hasattr(
@@ -55,6 +56,18 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = {}
self.is_mistral_tokenizer = self.vllm_configs.tokenizer_mode == 'mistral'
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.vllm_configs.enable_auto_tool_choice:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
self.vllm_configs.tool_call_parser)
self.tool_parser = self.tool_parser(
self.engine.tokenizer.tokenizer)
except Exception as e:
raise TypeError(
"Error: option.enable_auto_tools requires "
f"tool call parser:'{self.vllm_configs.tool_call_parser}' which has not "
"been registered") from e

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
@@ -68,6 +81,9 @@ def get_huggingface_model_config(self):
def use_vllm_chat_completions(self):
return True

def get_tool_parser(self):
return self.tool_parser

def reset(self) -> None:
"""
Aborts all requests
@@ -1047,7 +1047,7 @@ def custom_fmt_wait(request_output: TextGenerationOutput):
sequence_index = request_output.best_sequence_index
best_sequence = request_output.sequences[
request_output.best_sequence_index]
_, _, last_token = best_sequence.get_next_token()
_, _, _, last_token = best_sequence.get_next_token()
if last_token:
tokens = best_sequence.tokens
generated_text = ""