From fde61a80a58de0401fdecdee7408db53e17ca4f4 Mon Sep 17 00:00:00 2001 From: Morteza Date: Wed, 10 Jan 2024 02:51:17 -0700 Subject: [PATCH] Add CFG to vllm serving --- docs/reference/vllm.md | 13 +++++- outlines/serve/serve.py | 4 ++ outlines/serve/vllm.py | 100 ++++++++++++++++++++++++++-------------- 3 files changed, 81 insertions(+), 36 deletions(-) diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md index f3232ca10..38bbc21d0 100644 --- a/docs/reference/vllm.md +++ b/docs/reference/vllm.md @@ -24,8 +24,9 @@ You can then query the model in shell by passing a prompt and either 1. a [JSON Schema][jsonschema]{:target="_blank"} specification or 2. a [Regex][regex]{:target="_blank"} pattern +3. an EBNF grammar -with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained. +with the `schema`, `regex`, or `cfg` parameters, respectively, to the `/generate` endpoint. If more than one is specified, the schema takes precedence. If none is specified, the generated text will be unconstrained. For example, to generate a string that matches the schema `{"type": "string"}` (any string): @@ -47,6 +48,16 @@ curl http://127.0.0.1:8000/generate \ -d '{ "prompt": "What is Pi? Give me the first 15 digits: ", }' ``` +To generate a string that matches the grammar `start: /[0-9.]+/`: + +```bash +curl http://127.0.0.1:8000/generate \ + -d '{ + "prompt": "What is Pi? Give me the first 15 digits: ", + "cfg": "start: /[0-9.]+/" + }' +``` + Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program. Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs. 
diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index a669c5b50..7bcf13af8 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -25,6 +25,7 @@ from vllm.utils import random_uuid from .vllm import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor, _patched_apply_logits_processors, @@ -65,10 +66,13 @@ async def generate(request: Request) -> Response: json_schema = request_dict.pop("schema", None) regex_string = request_dict.pop("regex", None) + cfg_string = request_dict.pop("cfg", None) if json_schema is not None: logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] elif regex_string is not None: logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)] + elif cfg_string is not None: + logits_processors = [CFGLogitsProcessor(cfg_string, engine.engine)] else: logits_processors = [] diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py index bbf0a50c3..ed63a5f3f 100644 --- a/outlines/serve/vllm.py +++ b/outlines/serve/vllm.py @@ -2,14 +2,50 @@ import json import math from collections import defaultdict -from typing import DefaultDict, List +from typing import Callable, DefaultDict, List import torch -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_object +def _adapt_tokenizer(tokenizer): + """Adapt vLLM's tokenizer for use in compiling the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces in + Llama's tokenizer to be able to compile FSMs for this model. 
+ + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], str] + ) -> Callable[[List[int]], List[str]]: + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + + return tokenizer + + def _patched_apply_logits_processors( logits, sampling_metadata, @@ -39,21 +75,9 @@ def _patched_apply_logits_processors( return logits -class RegexLogitsProcessor: - def __init__(self, regex_string, llm): - """Compile the FSM that drives the regex-guided generation. - - Parameters - ---------- - regex_string - A string that represents a regular expression - llm - An instance of `vllm.LLM` - - """ - tokenizer = self.adapt_tokenizer(llm.tokenizer) - - fsm = RegexFSM(regex_string, tokenizer) +class FSMLogitsProcessor: + def __init__(self): + fsm = FSM() self.fsm = fsm def __call__( @@ -77,31 +101,37 @@ def __call__( return biased_scores - def adapt_tokenizer(self, tokenizer): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. In addition we need to handle the missing spaces to - Llama's tokenizer to be able to compile FSMs for this model. 
- """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) +class RegexLogitsProcessor(FSMLogitsProcessor): + def __init__(self, regex_string, llm): + """Compile the FSM that drives the regex-guided generation. - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE + Parameters + ---------- + regex_string + A string that represents a regular expression + llm + An instance of `vllm.LLM` - string = tokenizer.convert_tokens_to_string([token]) + """ + fsm = RegexFSM(regex_string, llm.tokenizer) + self.fsm = fsm - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - return string +class CFGLogitsProcessor(FSMLogitsProcessor): + def __init__(self, cfg_string, llm): + """Compile the FSM that drives the cfg-guided generation. - tokenizer.convert_token_to_string = convert_token_to_string + Parameters + ---------- + regex_string + A string that represents a regular expression + llm + An instance of `vllm.LLM` - return tokenizer + """ + fsm = CFGFSM(cfg_string, llm.tokenizer) + self.fsm = fsm class JSONLogitsProcessor(RegexLogitsProcessor):