
Add support for OpenChat 3.5/Zephyr 7B β, improve fallbacks of `repetition_penalty`, support multiple messages in request body (#82)

Signed-off-by: Hung-Han (Henry) Chen <[email protected]>
chenhunghan authored Nov 3, 2023
1 parent d84b08e commit 0c6e6b4
Showing 7 changed files with 152 additions and 51 deletions.
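
The request-body change summarised above means a client can now send a whole conversation, plus an OpenAI-style `frequency_penalty`, in a single call. A minimal sketch of such a request, assuming ialacol is serving the Zephyr GGUF model on localhost:8000 (the host, port and model file name are assumptions, not taken from this commit):

```python
# Hedged sketch: a chat completion request exercising the two behaviours this
# commit adds -- several messages in one body, and `frequency_penalty` acting
# as a fallback source for `repetition_penalty`.
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed local ialacol instance
    json={
        "model": "zephyr-7b-beta.Q4_K_M.gguf",  # assumed model file
        "messages": [
            {"role": "system", "content": "You are a friendly chatbot."},
            {"role": "user", "content": "Ahoy! Who are you?"},
            {"role": "assistant", "content": "A pirate chatbot, matey."},
            {"role": "user", "content": "Tell me a short joke."},
        ],
        "frequency_penalty": 1.1,
    },
    timeout=120,
)
print(response.json()["choices"][0]["message"]["content"])
```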
2 changes: 1 addition & 1 deletion .github/workflows/helm_lint_test.yaml
@@ -22,7 +22,7 @@ jobs:
check-latest: true

- name: Set up chart-testing
uses: helm/chart-testing-action@v2.4.0
uses: helm/chart-testing-action@v2.6.1

- name: Run chart-testing (list-changed)
id: list-changed
18 changes: 9 additions & 9 deletions .github/workflows/smoke_test.yaml
@@ -55,7 +55,7 @@ jobs:
push: true
tags: |
${{ env.REGISTRY }}/${{ env.REPO_ORG_NAME }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
build-gptq-cuda12-image:
build-gptq-image:
runs-on: ubuntu-latest
steps:
- name: Checkout
@@ -74,7 +74,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ./Dockerfile.cuda12
file: ./Dockerfile.gptq
push: true
tags: |
${{ env.REGISTRY }}/${{ env.REPO_ORG_NAME }}/${{ env.IMAGE_NAME }}:${{ env.GPTQ_IMAGE_TAG }}
@@ -124,7 +124,7 @@ jobs:
helm install $LLAMA_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol
echo "Wait for the pod to be ready, it takes about 36s to download a 1.93GB model (~50MB/s)"
sleep 40
sleep 120
- if: always()
run: |
kubectl get pods -n $HELM_NAMESPACE
@@ -215,7 +215,7 @@ jobs:
helm install $GPT_NEOX_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol
echo "Wait for the pod to be ready, it takes about 36s to download a 1.93GB model (~50MB/s)"
sleep 40
sleep 120
- if: always()
run: |
kubectl get pods -n $HELM_NAMESPACE
@@ -283,7 +283,7 @@ jobs:
helm install $STARCODER_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol
echo "Wait for the pod to be ready"
sleep 20
sleep 120
- if: always()
run: |
kubectl get pods -n $HELM_NAMESPACE
@@ -303,7 +303,7 @@ jobs:
kubectl logs --tail=200 --selector app.kubernetes.io/name=$STARCODER_HELM_RELEASE_NAME -n $HELM_NAMESPACE
gptq-smoke-test:
runs-on: ubuntu-latest
needs: build-gptq-cuda12-image
needs: build-gptq-image
steps:
- name: Create k8s Kind Cluster
uses: helm/[email protected]
@@ -323,7 +323,7 @@ jobs:
cat > values.yaml <<EOF
replicas: 1
deployment:
image: ${{ env.REGISTRY }}/${{ env.REPO_ORG_NAME }}/${{ env.IMAGE_NAME }}:${{ ${{ env.GPTQ_IMAGE_TAG }}
image: ${{ env.REGISTRY }}/${{ env.REPO_ORG_NAME }}/${{ env.IMAGE_NAME }}:${{ env.GPTQ_IMAGE_TAG }}
env:
DEFAULT_MODEL_HG_REPO_ID: $GPTQ_MODEL_HG_REPO_ID
DEFAULT_MODEL_HG_REPO_REVISION: $GPTQ_MODEL_HG_REVISION
@@ -347,8 +347,8 @@ jobs:
EOF
helm install $GPTQ_HELM_RELEASE_NAME -f values.yaml --namespace $HELM_NAMESPACE ./charts/ialacol
echo "Wait for the pod to be ready, it takes about 36s to download a 1.93GB model (~50MB/s)"
sleep 40
echo "Wait for the pod to be ready, GPTQ image is around 1GB"
sleep 240
- if: always()
run: |
kubectl get pods -n $HELM_NAMESPACE
63 changes: 62 additions & 1 deletion README.md
@@ -22,7 +22,7 @@ ialacol is inspired by other similar projects like [LocalAI](https://github.com/

See [Receipts](#receipts) below for instructions of deployments.

- [LLaMa 2 variants](https://huggingface.co/meta-llama), including [OpenLLaMA](https://github.com/openlm-research/open_llama), [Mistral](https://huggingface.co/mistralai/Mistral-7B-v0.1).
- [LLaMa 2 variants](https://huggingface.co/meta-llama), including [OpenLLaMA](https://github.com/openlm-research/open_llama), [Mistral](https://huggingface.co/mistralai/Mistral-7B-v0.1), [openchat_3.5](https://huggingface.co/openchat/openchat_3.5) and [zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta).
- [StarCoder variants](https://huggingface.co/bigcode/starcoder)
- [WizardCoder](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- [StarChat variants](https://huggingface.co/HuggingFaceH4/starchat-beta)
@@ -32,6 +32,67 @@ See [Receipts](#receipts) below for instructions of deployments.

And all LLMs supported by [ctransformers](https://github.com/marella/ctransformers/tree/main/models/llms).

## UI

`ialacol` does not ship with a UI; however, it is compatible with any web UI that supports the OpenAI API, for example [chat-ui](https://github.com/huggingface/chat-ui) after [PR #541](https://github.com/huggingface/chat-ui/pull/541) was merged.

Assuming `ialacol` is running on port 8000, you can configure [chat-ui](https://github.com/huggingface/chat-ui) to use [`zephyr-7b-beta.Q4_K_M.gguf`](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) served by `ialacol`:
```shell
MODELS=`[
{
"name": "zephyr-7b-beta.Q4_K_M.gguf",
"displayName": "Zephyr 7B β",
"preprompt": "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.</s>\n",
"userMessageToken": "<|user|>\n",
"userMessageEndToken": "</s>\n",
"assistantMessageToken": "<|assistant|>\n",
"assistantMessageEndToken": "\n",
"parameters": {
"temperature": 0.1,
"top_p": 0.95,
"repetition_penalty": 1.2,
"top_k": 50,
"max_new_tokens": 4096,
"truncate": 999999
},
"endpoints" : [{
"type": "openai",
"baseURL": "http://localhost:8000/v1",
"completion": "chat_completions"
}]
}
]`
```

Similarly, for [openchat_3.5.Q4_K_M.gguf](https://huggingface.co/openchat/openchat_3.5):
```shell
MODELS=`[
{
"name": "openchat_3.5.Q4_K_M.gguf",
"displayName": "OpenChat 3.5",
"preprompt": "",
"userMessageToken": "GPT4 User: ",
"userMessageEndToken": "<|end_of_turn|>",
"assistantMessageToken": "GPT4 Assistant: ",
"assistantMessageEndToken": "<|end_of_turn|>",
"parameters": {
"temperature": 0.1,
"top_p": 0.95,
"repetition_penalty": 1.2,
"top_k": 50,
"max_new_tokens": 4096,
"truncate": 999999,
"stop": ["<|end_of_turn|>"]
},
"endpoints" : [{
"type": "openai",
"baseURL": "http://localhost:8000/v1",
"completion": "chat_completions"
}]
}
]`
```

## Blogs

- [Use Code Llama (and other open LLMs) as Drop-In Replacement for Copilot Code Completion](https://dev.to/chenhunghan/use-code-llama-and-other-open-llms-as-drop-in-replacement-for-copilot-code-completion-58hg)
4 changes: 2 additions & 2 deletions charts/ialacol/Chart.yaml
@@ -1,6 +1,6 @@
apiVersion: v2
appVersion: 0.12.0
appVersion: 0.13.0
description: A Helm chart for ialacol
name: ialacol
type: application
version: 0.12.0
version: 0.13.0
13 changes: 10 additions & 3 deletions get_config.py
@@ -8,6 +8,7 @@

THREADS = int(get_env("THREADS", str(get_default_thread())))


def get_config(
body: CompletionRequestBody | ChatCompletionRequestBody,
) -> Config:
@@ -28,8 +29,10 @@ def get_config(
# OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-max_tokens
MAX_TOKENS = int(get_env("MAX_TOKENS", DEFAULT_MAX_TOKENS))
CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", DEFAULT_CONTEXT_LENGTH))
if (MAX_TOKENS > CONTEXT_LENGTH):
log.warning("MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH")
if MAX_TOKENS > CONTEXT_LENGTH:
log.warning(
"MAX_TOKENS is greater than CONTEXT_LENGTH, setting MAX_TOKENS < CONTEXT_LENGTH"
)
# OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-stop
STOP = get_env_or_none("STOP")

@@ -48,7 +51,11 @@ def get_config(
top_p = body.top_p if body.top_p else TOP_P
temperature = body.temperature if body.temperature else TEMPERATURE
repetition_penalty = (
body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
body.frequency_penalty
if body.frequency_penalty
else (
body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
)
)
last_n_tokens = body.last_n_tokens if body.last_n_tokens else LAST_N_TOKENS
seed = body.seed if body.seed else SEED
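
A minimal sketch of the fallback order introduced in the hunk above: an OpenAI-style `frequency_penalty` wins, then the llama.cpp-style `repetition_penalty`, then the server default. The helper name and the default value are illustrative, not part of the change:

```python
# Hedged sketch of the repetition_penalty fallback chain in get_config.
REPETITION_PENALTY = 1.1  # assumed server-side default


def resolve_repetition_penalty(frequency_penalty, repetition_penalty):
    if frequency_penalty:
        return frequency_penalty
    if repetition_penalty:
        return repetition_penalty
    return REPETITION_PENALTY


print(resolve_repetition_penalty(None, None))  # 1.1 -- server default
print(resolve_repetition_penalty(None, 1.3))   # 1.3 -- repetition_penalty
print(resolve_repetition_penalty(1.2, 1.3))    # 1.2 -- frequency_penalty wins
```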
99 changes: 66 additions & 33 deletions main.py
@@ -31,9 +31,7 @@
DEFAULT_MODEL_HG_REPO_ID = get_env(
"DEFAULT_MODEL_HG_REPO_ID", "TheBloke/Llama-2-7B-Chat-GGML"
)
DEFAULT_MODEL_HG_REPO_REVISION = get_env(
"DEFAULT_MODEL_HG_REPO_REVISION", "main"
)
DEFAULT_MODEL_HG_REPO_REVISION = get_env("DEFAULT_MODEL_HG_REPO_REVISION", "main")
DEFAULT_MODEL_FILE = get_env("DEFAULT_MODEL_FILE", "llama-2-7b-chat.ggmlv3.q4_0.bin")

log.info("DEFAULT_MODEL_HG_REPO_ID: %s", DEFAULT_MODEL_HG_REPO_ID)
@@ -70,13 +68,17 @@ def set_loading_model(boolean: bool):

app = FastAPI()


# https://github.com/tiangolo/fastapi/issues/3361
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
log.error("%s: %s", request, exc_str)
content = {"status_code": 10422, "message": exc_str, "data": None}
return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
return JSONResponse(
content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
)


@app.on_event("startup")
async def startup_event():
@@ -370,36 +372,67 @@ async def chat_completions(
assistant_start = "### Assistant:\n"
user_start = "### User:\n"
user_end = "\n\n"

user_message = next(
(message for message in body.messages if message.role == "user"), None
)
user_message_content = user_message.content if user_message else ""
assistant_message = next(
(message for message in body.messages if message.role == "assistant"), None
)
assistant_message_content = (
f"{assistant_start}{assistant_message.content}{assistant_end}"
if assistant_message
else ""
)
system_message = next(
(message for message in body.messages if message.role == "system"), None
)
system_message_content = system_message.content if system_message else system
# avoid duplicate user start token in prompt if user message already includes it
if len(user_start) > 0 and user_start in user_message_content:
user_start = ""
# avoid duplicate user end token in prompt if user message already includes it
if len(user_end) > 0 and user_end in user_message_content:
user_end = ""
# avoid duplicate assistant start token in prompt if user message already includes it
if len(assistant_start) > 0 and assistant_start in user_message_content:
assistant_start = ""
# avoid duplicate system_start token in prompt if system_message_content already includes it
if len(system_start) > 0 and system_start in system_message_content:
# openchat_3.5 https://huggingface.co/openchat/openchat_3.5
if "openchat" in body.model.lower():
system_start = ""
prompt = f"{system_start}{system_message_content}{system_end}{assistant_message_content}{user_start}{user_message_content}{user_end}{assistant_start}"
system = ""
system_end = ""
assistant_start = "GPT4 Assistant: "
assistant_end = "<|end_of_turn|>"
user_start = "GPT4 User: "
user_end = "<|end_of_turn|>"
# HG's zephyr https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
if "zephyr" in body.model.lower():
system_start = "<|system|>\n"
system = ""
system_end = "</s>\n"
assistant_start = "<|assistant|>"
assistant_end = "\n"
user_start = "<|user|>\n"
user_end = "</s>\n"

prompt = ""
for message in body.messages:
# Check for system message
if message.role == "system":
system_message_content = message.content if message else ""

# avoid duplicate system_start token in prompt if system_message_content already includes it
if len(system_start) > 0 and system_start in system_message_content:
system_start = ""
# avoid duplicate system_end token in prompt if system_message_content already includes it
if len(system_end) > 0 and system_end in system_message_content:
system_end = ""
prompt = f"{system_start}{system_message_content}{system_end}"
elif message.role == "user":
user_message_content = message.content if message else ""

# avoid duplicate user start token in prompt if user_message_content already includes it
if len(user_start) > 0 and user_start in user_message_content:
user_start = ""
# avoid duplicate user end token in prompt if user_message_content already includes it
if len(user_end) > 0 and user_end in user_message_content:
user_end = ""

prompt = f"{prompt}{user_start}{user_message_content}{user_end}"
elif message.role == "assistant":
assistant_message_content = message.content if message else ""

# avoid duplicate assistant start token in prompt if user message already includes it
if (
len(assistant_start) > 0
and assistant_start in assistant_message_content
):
assistant_start = ""
# avoid duplicate assistant start token in prompt if user message already includes it
if len(assistant_end) > 0 and assistant_end in assistant_message_content:
assistant_end = ""

prompt = (
f"{prompt}{assistant_start}{assistant_message_content}{assistant_end}"
)

prompt = f"{prompt}{assistant_start}"
model_name = body.model
llm = request.app.state.llm
if body.stream is True:
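
A standalone sketch of the new multi-message prompt assembly, using the Zephyr template tokens from the hunk above; the duplicate-token stripping and the OpenChat branch are omitted for brevity:

```python
# Hedged sketch: render an OpenAI-style message list into a Zephyr prompt,
# mirroring the loop added to chat_completions above.
messages = [
    {"role": "system", "content": "You are a friendly chatbot."},
    {"role": "user", "content": "Ahoy! Who are you?"},
    {"role": "assistant", "content": "A pirate chatbot, matey."},
    {"role": "user", "content": "Tell me a short joke."},
]

system_start, system_end = "<|system|>\n", "</s>\n"
user_start, user_end = "<|user|>\n", "</s>\n"
assistant_start, assistant_end = "<|assistant|>", "\n"

prompt = ""
for message in messages:
    if message["role"] == "system":
        prompt = f"{system_start}{message['content']}{system_end}"
    elif message["role"] == "user":
        prompt = f"{prompt}{user_start}{message['content']}{user_end}"
    elif message["role"] == "assistant":
        prompt = f"{prompt}{assistant_start}{message['content']}{assistant_end}"

# End with the assistant token so the model continues as the assistant.
prompt = f"{prompt}{assistant_start}"
print(prompt)
```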
4 changes: 2 additions & 2 deletions request_body.py
@@ -24,6 +24,7 @@ class CompletionRequestBody(BaseModel):
# llama.cpp specific parameters
top_k: Optional[int]
repetition_penalty: Optional[float]
frequency_penalty: Optional[float]
last_n_tokens: Optional[int]
seed: Optional[int]
batch_size: Optional[int]
@@ -32,7 +33,6 @@ class CompletionRequestBody(BaseModel):
# ignored or currently unsupported
suffix: Any
presence_penalty: Any
frequency_penalty: Any
echo: Any
n: Any
logprobs: Any
@@ -73,6 +73,7 @@ class ChatCompletionRequestBody(BaseModel):
# llama.cpp specific parameters
top_k: Optional[int]
repetition_penalty: Optional[float]
frequency_penalty: Optional[float]
last_n_tokens: Optional[int]
seed: Optional[int]
batch_size: Optional[int]
@@ -83,7 +84,6 @@ class ChatCompletionRequestBody(BaseModel):
logit_bias: Any
user: Any
presence_penalty: Any
frequency_penalty: Any

class Config:
arbitrary_types_allowed = True
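
A minimal sketch of what this schema change means for request parsing, with the model trimmed to the relevant fields and the values illustrative only:

```python
# Hedged sketch: `frequency_penalty` is now declared on the pydantic models,
# so it survives validation instead of being discarded as an ignored field.
from typing import Optional

from pydantic import BaseModel


class ChatCompletionRequestBody(BaseModel):  # trimmed to the relevant fields
    model: str
    repetition_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None


body = ChatCompletionRequestBody(
    model="openchat_3.5.Q4_K_M.gguf", frequency_penalty=1.2
)
print(body.frequency_penalty)  # 1.2 -- available for the fallback in get_config
```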
