From 0c6e6b467c13784e39256b37a3786d9762f6a14d Mon Sep 17 00:00:00 2001
From: Henry Chen <1474479+chenhunghan@users.noreply.github.com>
Date: Fri, 3 Nov 2023 18:59:43 +0200
Subject: [PATCH] Add support for OpenChat 3.5/Zephyr 7B β, improve fallbacks
 of `repetition_penalty`, support multiple messages in request body (#82)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Hung-Han (Henry) Chen
---
 .github/workflows/helm_lint_test.yaml |  2 +-
 .github/workflows/smoke_test.yaml     | 18 ++---
 README.md                             | 63 ++++++++++++++++-
 charts/ialacol/Chart.yaml             |  4 +-
 get_config.py                         | 13 +++-
 main.py                               | 99 ++++++++++++++++++---------
 request_body.py                       |  4 +-
 7 files changed, 152 insertions(+), 51 deletions(-)

diff --git a/.github/workflows/helm_lint_test.yaml b/.github/workflows/helm_lint_test.yaml
index 67d9ec2..ad99741 100644
--- a/.github/workflows/helm_lint_test.yaml
+++ b/.github/workflows/helm_lint_test.yaml
@@ -22,7 +22,7 @@ jobs:
           check-latest: true

       - name: Set up chart-testing
-        uses: helm/chart-testing-action@v2.4.0
+        uses: helm/chart-testing-action@v2.6.1

       - name: Run chart-testing (list-changed)
         id: list-changed
diff --git a/.github/workflows/smoke_test.yaml b/.github/workflows/smoke_test.yaml
index fe465e9..393dcbe 100644
--- a/.github/workflows/smoke_test.yaml
+++ b/.github/workflows/smoke_test.yaml
@@ -55,7 +55,7 @@ jobs:
           push: true
           tags: |
             ${{ env.REGISTRY }}/${{ env.REPO_ORG_NAME }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
-  build-gptq-cuda12-image:
+  build-gptq-image:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
@@ -74,7 +74,7 @@ jobs:
         uses: docker/build-push-action@v4
         with:
           context: .
-          file: ./Dockerfile.cuda12
+          file: ./Dockerfile.gptq
           push: true
           tags: |
             ${{ env.REGISTRY }}/${{ env.REPO_ORG_NAME }}/${{ env.IMAGE_NAME }}:${{ env.GPTQ_IMAGE_TAG }}
@@ -124,7 +124,7 @@ jobs:
           helm install $LLAMA_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol

           echo "Wait for the pod to be ready, it takes about 36s to download a 1.93GB model (~50MB/s)"
-          sleep 40
+          sleep 120
       - if: always()
         run: |
           kubectl get pods -n $HELM_NAMESPACE
@@ -215,7 +215,7 @@ jobs:
           helm install $GPT_NEOX_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol

           echo "Wait for the pod to be ready, it takes about 36s to download a 1.93GB model (~50MB/s)"
-          sleep 40
+          sleep 120
       - if: always()
         run: |
           kubectl get pods -n $HELM_NAMESPACE
@@ -283,7 +283,7 @@ jobs:
           helm install $STARCODER_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol

           echo "Wait for the pod to be ready"
-          sleep 20
+          sleep 120
       - if: always()
         run: |
           kubectl get pods -n $HELM_NAMESPACE
@@ -303,7 +303,7 @@ jobs:
           kubectl logs --tail=200 --selector app.kubernetes.io/name=$STARCODER_HELM_RELEASE_NAME -n $HELM_NAMESPACE
   gptq-smoke-test:
     runs-on: ubuntu-latest
-    needs: build-gptq-cuda12-image
+    needs: build-gptq-image
     steps:
       - name: Create k8s Kind Cluster
         uses: helm/kind-action@v1.7.0
@@ -323,7 +323,7 @@ jobs:
           cat > values.yaml <
\nYou are a friendly chatbot who always responds in the style of a pirate.\n",
+    "userMessageToken": "<|user|>\n",
+    "userMessageEndToken": "\n",
+    "assistantMessageToken": "<|assistant|>\n",
+    "assistantMessageEndToken": "\n",
+    "parameters": {
+      "temperature": 0.1,
+      "top_p": 0.95,
+      "repetition_penalty": 1.2,
+      "top_k": 50,
+      "max_new_tokens": 4096,
+      "truncate": 999999
+    },
+    "endpoints" : [{
+      "type": "openai",
+      "baseURL": "http://localhost:8000/v1",
+      "completion": "chat_completions"
+    }]
+  }
+]
+```
+
+[openchat_3.5.Q4_K_M.gguf](https://huggingface.co/openchat/openchat_3.5)
+```shell
+MODELS=`[
+  {
+    "name": "openchat_3.5.Q4_K_M.gguf",
+    "displayName": "OpenChat 3.5",
+    "preprompt": "",
+    "userMessageToken": "GPT4 User: ",
+    "userMessageEndToken": "<|end_of_turn|>",
+    "assistantMessageToken": "GPT4 Assistant: ",
+    "assistantMessageEndToken": "<|end_of_turn|>",
+    "parameters": {
+      "temperature": 0.1,
+      "top_p": 0.95,
+      "repetition_penalty": 1.2,
+      "top_k": 50,
+      "max_new_tokens": 4096,
+      "truncate": 999999,
+      "stop": ["<|end_of_turn|>"]
+    },
+    "endpoints" : [{
+      "type": "openai",
+      "baseURL": "http://localhost:8000/v1",
+      "completion": "chat_completions"
+    }]
+  }
+]`
+```
+
 ## Blogs

 - [Use Code Llama (and other open LLMs) as Drop-In Replacement for Copilot Code Completion](https://dev.to/chenhunghan/use-code-llama-and-other-open-llms-as-drop-in-replacement-for-copilot-code-completion-58hg)
"http://localhost:8000/v1", + "completion": "chat_completions" + }] + } +] +``` + +[openchat_3.5.Q4_K_M.gguf](https://huggingface.co/openchat/openchat_3.5) +```shell +MODELS=`[ + { + "name": "openchat_3.5.Q4_K_M.gguf", + "displayName": "OpenChat 3.5", + "preprompt": "", + "userMessageToken": "GPT4 User: ", + "userMessageEndToken": "<|end_of_turn|>", + "assistantMessageToken": "GPT4 Assistant: ", + "assistantMessageEndToken": "<|end_of_turn|>", + "parameters": { + "temperature": 0.1, + "top_p": 0.95, + "repetition_penalty": 1.2, + "top_k": 50, + "max_new_tokens": 4096, + "truncate": 999999, + "stop": ["<|end_of_turn|>"] + }, + "endpoints" : [{ + "type": "openai", + "baseURL": "http://localhost:8000/v1", + "completion": "chat_completions" + }] + } +]` +``` + ## Blogs - [Use Code Llama (and other open LLMs) as Drop-In Replacement for Copilot Code Completion](https://dev.to/chenhunghan/use-code-llama-and-other-open-llms-as-drop-in-replacement-for-copilot-code-completion-58hg) diff --git a/charts/ialacol/Chart.yaml b/charts/ialacol/Chart.yaml index 5d1d626..dcbb032 100644 --- a/charts/ialacol/Chart.yaml +++ b/charts/ialacol/Chart.yaml @@ -1,6 +1,6 @@ apiVersion: v2 -appVersion: 0.12.0 +appVersion: 0.13.0 description: A Helm chart for ialacol name: ialacol type: application -version: 0.12.0 +version: 0.13.0 diff --git a/get_config.py b/get_config.py index 9c08c48..02c1cc3 100644 --- a/get_config.py +++ b/get_config.py @@ -8,6 +8,7 @@ THREADS = int(get_env("THREADS", str(get_default_thread()))) + def get_config( body: CompletionRequestBody | ChatCompletionRequestBody, ) -> Config: @@ -28,8 +29,10 @@ def get_config( # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-max_tokens MAX_TOKENS = int(get_env("MAX_TOKENS", DEFAULT_MAX_TOKENS)) CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH)) - if (MAX_TOKENS > CONTEXT_LENGTH): - log.warning("MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH") + if MAX_TOKENS > CONTEXT_LENGTH: + log.warning( + "MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH" + ) # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-stop STOP = get_env_or_none("STOP") @@ -48,7 +51,11 @@ def get_config( top_p = body.top_p if body.top_p else TOP_P temperature = body.temperature if body.temperature else TEMPERATURE repetition_penalty = ( - body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY + body.frequency_penalty + if body.frequency_penalty + else ( + body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY + ) ) last_n_tokens = body.last_n_tokens if body.last_n_tokens else LAST_N_TOKENS seed = body.seed if body.seed else SEED diff --git a/main.py b/main.py index 6b936cd..d69b3ec 100644 --- a/main.py +++ b/main.py @@ -31,9 +31,7 @@ DEFAULT_MODEL_HG_REPO_ID = get_env( "DEFAULT_MODEL_HG_REPO_ID", "TheBloke/Llama-2-7B-Chat-GGML" ) -DEFAULT_MODEL_HG_REPO_REVISION = get_env( - "DEFAULT_MODEL_HG_REPO_REVISION", "main" -) +DEFAULT_MODEL_HG_REPO_REVISION = get_env("DEFAULT_MODEL_HG_REPO_REVISION", "main") DEFAULT_MODEL_FILE = get_env("DEFAULT_MODEL_FILE", "llama-2-7b-chat.ggmlv3.q4_0.bin") log.info("DEFAULT_MODEL_HG_REPO_ID: %s", DEFAULT_MODEL_HG_REPO_ID) @@ -70,13 +68,17 @@ def set_loading_model(boolean: bool): app = FastAPI() + # https://github.com/tiangolo/fastapi/issues/3361 @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: 
diff --git a/main.py b/main.py
index 6b936cd..d69b3ec 100644
--- a/main.py
+++ b/main.py
@@ -31,9 +31,7 @@
 DEFAULT_MODEL_HG_REPO_ID = get_env(
     "DEFAULT_MODEL_HG_REPO_ID", "TheBloke/Llama-2-7B-Chat-GGML"
 )
-DEFAULT_MODEL_HG_REPO_REVISION = get_env(
-    "DEFAULT_MODEL_HG_REPO_REVISION", "main"
-)
+DEFAULT_MODEL_HG_REPO_REVISION = get_env("DEFAULT_MODEL_HG_REPO_REVISION", "main")
 DEFAULT_MODEL_FILE = get_env("DEFAULT_MODEL_FILE", "llama-2-7b-chat.ggmlv3.q4_0.bin")

 log.info("DEFAULT_MODEL_HG_REPO_ID: %s", DEFAULT_MODEL_HG_REPO_ID)
@@ -70,13 +68,17 @@ def set_loading_model(boolean: bool):

 app = FastAPI()

+
 # https://github.com/tiangolo/fastapi/issues/3361
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     exc_str = f"{exc}".replace("\n", " ").replace("   ", " ")
     log.error("%s: %s", request, exc_str)
     content = {"status_code": 10422, "message": exc_str, "data": None}
-    return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
+    return JSONResponse(
+        content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
+    )
+

 @app.on_event("startup")
 async def startup_event():
@@ -370,36 +372,67 @@ async def chat_completions(
     assistant_start = "### Assistant:\n"
     user_start = "### User:\n"
     user_end = "\n\n"
-
-    user_message = next(
-        (message for message in body.messages if message.role == "user"), None
-    )
-    user_message_content = user_message.content if user_message else ""
-    assistant_message = next(
-        (message for message in body.messages if message.role == "assistant"), None
-    )
-    assistant_message_content = (
-        f"{assistant_start}{assistant_message.content}{assistant_end}"
-        if assistant_message
-        else ""
-    )
-    system_message = next(
-        (message for message in body.messages if message.role == "system"), None
-    )
-    system_message_content = system_message.content if system_message else system
-    # avoid duplicate user start token in prompt if user message already includes it
-    if len(user_start) > 0 and user_start in user_message_content:
-        user_start = ""
-    # avoid duplicate user end token in prompt if user message already includes it
-    if len(user_end) > 0 and user_end in user_message_content:
-        user_end = ""
-    # avoid duplicate assistant start token in prompt if user message already includes it
-    if len(assistant_start) > 0 and assistant_start in user_message_content:
-        assistant_start = ""
-    # avoid duplicate system_start token in prompt if system_message_content already includes it
-    if len(system_start) > 0 and system_start in system_message_content:
+    # openchat_3.5 https://huggingface.co/openchat/openchat_3.5
+    if "openchat" in body.model.lower():
         system_start = ""
-    prompt = f"{system_start}{system_message_content}{system_end}{assistant_message_content}{user_start}{user_message_content}{user_end}{assistant_start}"
+        system = ""
+        system_end = ""
+        assistant_start = "GPT4 Assistant: "
+        assistant_end = "<|end_of_turn|>"
+        user_start = "GPT4 User: "
+        user_end = "<|end_of_turn|>"
+    # HG's zephyr https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
+    if "zephyr" in body.model.lower():
+        system_start = "<|system|>\n"
+        system = ""
+        system_end = "\n"
+        assistant_start = "<|assistant|>"
+        assistant_end = "\n"
+        user_start = "<|user|>\n"
+        user_end = "\n"
+
+    prompt = ""
+    for message in body.messages:
+        # Check for system message
+        if message.role == "system":
+            system_message_content = message.content if message else ""
+
+            # avoid duplicate system_start token in prompt if system_message_content already includes it
+            if len(system_start) > 0 and system_start in system_message_content:
+                system_start = ""
+            # avoid duplicate system_end token in prompt if system_message_content already includes it
+            if len(system_end) > 0 and system_end in system_message_content:
+                system_end = ""
+            prompt = f"{system_start}{system_message_content}{system_end}"
+        elif message.role == "user":
+            user_message_content = message.content if message else ""
+
+            # avoid duplicate user start token in prompt if user_message_content already includes it
+            if len(user_start) > 0 and user_start in user_message_content:
+                user_start = ""
+            # avoid duplicate user end token in prompt if user_message_content already includes it
+            if len(user_end) > 0 and user_end in user_message_content:
+                user_end = ""
+
+            prompt = f"{prompt}{user_start}{user_message_content}{user_end}"
+        elif message.role == "assistant":
+            assistant_message_content = message.content if message else ""
+
+            # avoid duplicate assistant start token in prompt if assistant message already includes it
+            if (
+                len(assistant_start) > 0
+                and assistant_start in assistant_message_content
+            ):
+                assistant_start = ""
+            # avoid duplicate assistant end token in prompt if assistant message already includes it
+            if len(assistant_end) > 0 and assistant_end in assistant_message_content:
+                assistant_end = ""
+
+            prompt = (
+                f"{prompt}{assistant_start}{assistant_message_content}{assistant_end}"
+            )
+
+    prompt = f"{prompt}{assistant_start}"
     model_name = body.model
     llm = request.app.state.llm
     if body.stream is True:
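To make the new prompt assembly concrete, here is a hand-worked example of what the loop above produces for an OpenChat 3.5 style request (not code from the patch; the message list is illustrative, and the real handler reads `message.role`/`message.content` from the parsed request body):

```python
# Hand-worked illustration of the prompt built by the loop above when
# body.model contains "openchat". The messages below are example data only.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

# system_start/system_end are empty for OpenChat, so the system message passes
# through unchanged; each turn is wrapped in its start token and closed with
# <|end_of_turn|>; the prompt ends with the assistant start token so the model
# generates the next assistant turn.
expected_prompt = (
    "You are a helpful assistant."
    "GPT4 User: Hello!<|end_of_turn|>"
    "GPT4 Assistant: Hi, how can I help?<|end_of_turn|>"
    "GPT4 User: Tell me a joke.<|end_of_turn|>"
    "GPT4 Assistant: "
)
```

For a Zephyr model the same conversation is wrapped in `<|system|>`, `<|user|>` and `<|assistant|>` markers with newline terminators instead, and for any other model the generic `### User:`/`### Assistant:` defaults still apply.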
diff --git a/request_body.py b/request_body.py
index d6e3859..ce2999e 100644
--- a/request_body.py
+++ b/request_body.py
@@ -24,6 +24,7 @@ class CompletionRequestBody(BaseModel):
     # llama.cpp specific parameters
     top_k: Optional[int]
     repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]
     last_n_tokens: Optional[int]
     seed: Optional[int]
     batch_size: Optional[int]
@@ -32,7 +33,6 @@ class CompletionRequestBody(BaseModel):
     # ignored or currently unsupported
     suffix: Any
     presence_penalty: Any
-    frequency_penalty: Any
     echo: Any
     n: Any
     logprobs: Any
@@ -73,6 +73,7 @@ class ChatCompletionRequestBody(BaseModel):
     # llama.cpp specific parameters
     top_k: Optional[int]
     repetition_penalty: Optional[float]
+    frequency_penalty: Optional[float]
     last_n_tokens: Optional[int]
     seed: Optional[int]
     batch_size: Optional[int]
@@ -83,7 +84,6 @@ class ChatCompletionRequestBody(BaseModel):
     logit_bias: Any
     user: Any
     presence_penalty: Any
-    frequency_penalty: Any

     class Config:
         arbitrary_types_allowed = True
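Taken together, the request models now accept an OpenAI-style `frequency_penalty` and the chat handler walks the full message list. A quick way to exercise both against a running deployment is a plain HTTP call (a sketch; it assumes ialacol is reachable on localhost:8000, is serving `openchat_3.5.Q4_K_M.gguf`, and returns the usual OpenAI response schema):

```python
# Sketch: send a multi-message conversation plus an OpenAI-style
# frequency_penalty to a local ialacol deployment. Host, port and model file
# name are assumptions; adjust them to match your values.yaml.
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "openchat_3.5.Q4_K_M.gguf",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi, how can I help?"},
            {"role": "user", "content": "Summarize our conversation so far."},
        ],
        "frequency_penalty": 1.2,  # mapped to repetition_penalty server-side
        "stream": False,
    },
    timeout=600,
)
print(response.json()["choices"][0]["message"]["content"])
```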