Merge pull request #148 from stanford-oval/dev-fix-vllm
[Bug Fix] Fix `VLLMClient` to reflect the recent update in vllm.
shaoyijia authored Aug 23, 2024
2 parents 493d796 + 71f4cd9 commit 4076b0e
Showing 1 changed file with 73 additions and 35 deletions.
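
Reviewer context: the old VLLMClient wrapped dspy.HFClientVLLM and posted raw payloads to /v1/completions through dspy's send_hfvllm_request_v00 helper; the updated class talks to vLLM's OpenAI-compatible server through the official openai client instead. A minimal connectivity check, assuming a vLLM server is already running locally (the port and api_key value below are illustrative, not part of this commit):

from openai import OpenAI

# Illustrative values; vLLM's OpenAI-compatible server typically ignores the key unless one was configured at launch.
client = OpenAI(base_url="http://localhost:8000/v1/", api_key="null")
print([m.id for m in client.models.list().data])  # lists the model(s) the server is serving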
108 changes: 73 additions & 35 deletions knowledge_storm/lm.py
@@ -9,7 +9,8 @@
 import requests
 from dsp import ERRORS, backoff_hdlr, giveup_hdlr
 from dsp.modules.hf import openai_to_hf
-from dsp.modules.hf_client import send_hfvllm_request_v00, send_hftgi_request_v01_wrapped
+from dsp.modules.hf_client import send_hftgi_request_v01_wrapped
+from openai import OpenAI
 from transformers import AutoTokenizer

 try:
@@ -123,7 +124,8 @@ def __init__(
         self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
         self.api_base = api_base
         if not self.api_key:
-            raise ValueError("DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY")
+            raise ValueError(
+                "DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY")

     def log_usage(self, response):
         """Log the total tokens from the DeepSeek API response."""
@@ -251,7 +253,8 @@ def __init__(
         self.api_key = api_key or os.getenv("GROQ_API_KEY")
         self.api_base = api_base
         if not self.api_key:
-            raise ValueError("Groq API key must be provided either as an argument or as an environment variable GROQ_API_KEY")
+            raise ValueError(
+                "Groq API key must be provided either as an argument or as an environment variable GROQ_API_KEY")

     def log_usage(self, response):
         """Log the total tokens from the Groq API response."""
@@ -466,49 +469,84 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
         return completions


-class VLLMClient(dspy.HFClientVLLM):
-    """A wrapper class for dspy.HFClientVLLM."""
-
-    def __init__(self, model, port, url="http://localhost", **kwargs):
-        """Copied from dspy/dsp/modules/hf_client.py with the addition of storing additional kwargs."""
-
-        super().__init__(model=model, port=port, url=url, **kwargs)
-        # Store additional kwargs for the generate method.
-        self.kwargs = {**self.kwargs, **kwargs}
-
-    def _generate(self, prompt, **kwargs):
-        """Copied from dspy/dsp/modules/hf_client.py with the addition of passing kwargs to VLLM server."""
-        kwargs = {**self.kwargs, **kwargs}
-
-        # payload = {
-        #     "model": kwargs["model"],
-        #     "prompt": prompt,
-        #     "max_tokens": kwargs["max_tokens"],
-        #     "temperature": kwargs["temperature"],
-        # }
-        payload = {
-            "prompt": prompt,
-            **kwargs
-        }
-
-        response = send_hfvllm_request_v00(
-            f"{self.url}/v1/completions",
-            json=payload,
-            headers=self.headers,
-        )
-
-        try:
-            json_response = response.json()
-            completions = json_response["choices"]
-            response = {
-                "prompt": prompt,
-                "choices": [{"text": c["text"]} for c in completions],
-            }
-            return response
-
-        except Exception as e:
-            print("Failed to parse JSON response:", response.text)
-            raise Exception("Received invalid JSON response from server")
+class VLLMClient(dspy.dsp.LM):
+    """A client compatible with vLLM HTTP server.
+
+    vLLM HTTP server is designed to be compatible with the OpenAI API. Use OpenAI client to interact with the server.
+    """
+
+    def __init__(self, model, port, model_type: Literal["chat", "text"] = "text", url="http://localhost",
+                 api_key="null", **kwargs):
+        """Check out https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html for more information."""
+        super().__init__(model=model)
+        # Store additional kwargs for the generate method.
+        self.kwargs = {**self.kwargs, **kwargs}
+        self.model = model
+        self.base_url = f"{url}:{port}/v1/"
+        if model_type == "chat":
+            self.base_url += "chat/"
+        self.client = OpenAI(base_url=self.base_url, api_key=api_key)
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self._token_usage_lock = threading.Lock()
+
+    def basic_request(self, prompt, **kwargs):
+        completion = self.client.chat.completions.create(
+            **kwargs,
+            messages=[{"role": "user", "content": prompt}],
+        )
+        return completion
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=1000,
+        on_backoff=backoff_hdlr,
+    )
+    def request(self, prompt: str, **kwargs):
+        return self.basic_request(prompt, **kwargs)
+
+    def log_usage(self, response):
+        """Log the total tokens from the response."""
+        usage_data = response.usage
+        if usage_data:
+            with self._token_usage_lock:
+                self.prompt_tokens += usage_data.prompt_tokens
+                self.completion_tokens += usage_data.completion_tokens
+
+    def get_usage_and_reset(self):
+        """Get the total tokens used and reset the token usage."""
+        usage = {
+            self.kwargs.get('model') or self.kwargs.get('engine'):
+                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+        }
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+
+        return usage
+
+    def __call__(self, prompt: str, **kwargs):
+        kwargs = {**self.kwargs, **kwargs}
+
+        try:
+            response = self.request(prompt, **kwargs)
+        except Exception as e:
+            print(f"Failed to generate completion: {e}")
+            raise Exception(e)
+
+        self.log_usage(response)
+
+        choices = response.choices
+        completions = [choice.message.content for choice in choices]
+
+        history = {
+            "prompt": prompt,
+            "response": response,
+            "kwargs": kwargs,
+        }
+        self.history.append(history)
+
+        return completions


 class OllamaClient(dspy.OllamaLocal):

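A minimal usage sketch of the updated client, assuming a vLLM OpenAI-compatible server is already serving a model on localhost:8000 (the model name, port, and sampling kwargs are illustrative, not part of this commit):

from knowledge_storm.lm import VLLMClient

# Hypothetical model/port; they must match whatever the vLLM server was launched with.
lm = VLLMClient(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    port=8000,
    max_tokens=500,
    temperature=1.0,
)

completions = lm("Summarize retrieval-augmented generation in one sentence.")
print(completions[0])

# Token accounting introduced in this revision: returns running totals and resets them.
print(lm.get_usage_and_reset())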