Skip to content

Commit

Permalink
General implementations for core functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
markurtz committed Jun 11, 2024
1 parent 7b57e18 commit 297d480
Show file tree
Hide file tree
Showing 31 changed files with 1,839 additions and 114 deletions.
4 changes: 2 additions & 2 deletions src/guidellm/backend/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Iterator, List, Optional, Type
from typing import Iterator, List, Optional, Type, Union
from dataclasses import dataclass
import uuid
from loguru import logger
Expand Down Expand Up @@ -52,7 +52,7 @@ def inner_wrapper(wrapped_class: Type["Backend"]):
return inner_wrapper

@staticmethod
def create_backend(backend_type: BackendTypes, **kwargs) -> "Backend":
def create_backend(backend_type: Union[str, BackendTypes], **kwargs) -> "Backend":
"""
Factory method to create a backend based on the backend type.
Expand Down
82 changes: 49 additions & 33 deletions src/guidellm/backend/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import openai
from typing import Iterator, List, Optional, Dict, Any
from urllib.parse import urlparse
from transformers import AutoTokenizer
from loguru import logger
from guidellm.backend import Backend, BackendTypes, GenerativeResponse
Expand All @@ -24,8 +23,10 @@ class OpenAIBackend(Backend):
:type path: Optional[str]
:param model: The OpenAI model to use, defaults to the first available model.
:type model: Optional[str]
:param model_args: Additional model arguments for the request.
:type model_args: Optional[Dict[str, Any]]
:param api_key: The OpenAI API key to use.
:type api_key: Optional[str]
:param request_args: Optional arguments for the OpenAI request.
:type request_args: Dict[str, Any]
"""

def __init__(
Expand All @@ -35,21 +36,30 @@ def __init__(
port: Optional[int] = None,
path: Optional[str] = None,
model: Optional[str] = None,
**model_args,
api_key: Optional[str] = None,
**request_args,
):
if target:
parsed_url = urlparse(target)
self.host = parsed_url.hostname
self.port = parsed_url.port
self.path = parsed_url.path
else:
self.host = host
self.port = port
self.path = path
self.target = target
self.model = model
self.model_args = model_args
openai.api_key = model_args.get("api_key", None)
logger.info(f"Initialized OpenAIBackend with model: {self.model}")
self.request_args = request_args

if not self.target:
if not host:
raise ValueError("Host is required if target is not provided.")

port_incl = f":{port}" if port else ""
path_incl = path if path else ""
self.target = f"http://{host}{port_incl}{path_incl}"

openai.api_base = self.target
openai.api_key = api_key

if not model:
self.model = self.default_model()

logger.info(
f"Initialized OpenAIBackend with target: {self.target} and model: {self.model}"
)

def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse]:
"""
Expand All @@ -61,14 +71,20 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
:rtype: Iterator[GenerativeResponse]
"""
logger.debug(f"Making request to OpenAI backend with prompt: {request.prompt}")
num_gen_tokens = request.params.get("generated_tokens", None)
request_args = {
"n": 1,
}

if num_gen_tokens:
request_args["max_tokens"] = num_gen_tokens
request_args["stop"] = None

if self.request_args:
request_args.update(self.request_args)

response = openai.Completion.create(
engine=self.model or self.default_model(),
prompt=request.prompt,
max_tokens=request.params.get("max_tokens", 100),
n=request.params.get("n", 1),
stop=request.params.get("stop", None),
stream=True,
**self.model_args,
engine=self.model, prompt=request.prompt, stream=True, **request_args,
)

for chunk in response:
Expand All @@ -80,8 +96,16 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
type_="final",
output=choice["text"],
prompt=request.prompt,
prompt_token_count=self._token_count(request.prompt),
output_token_count=self._token_count(choice["text"]),
prompt_token_count=(
request.token_count
if request.token_count
else self._token_count(request.prompt)
),
output_token_count=(
num_gen_tokens
if num_gen_tokens
else self._token_count(choice["text"])
),
)
break
else:
Expand Down Expand Up @@ -133,14 +157,6 @@ def model_tokenizer(self, model: str) -> Optional[Any]:
return None

def _token_count(self, text: str) -> int:
    """
    Estimate how many tokens ``text`` contains.

    Uses a simple whitespace split as a cheap proxy for real tokenization.

    :param text: The text whose tokens should be counted.
    :type text: str
    :return: The estimated token count.
    :rtype: int
    """
    tokens = text.split()
    token_count = len(tokens)
    logger.debug(f"Token count for text '{text}': {token_count}")
    return token_count
36 changes: 33 additions & 3 deletions src/guidellm/core/request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Any, Optional
import uuid


__all__ = ["BenchmarkRequest"]
Expand All @@ -10,11 +11,18 @@ class BenchmarkRequest:
:param prompt: The input prompt for the benchmark request.
:type prompt: str
:param token_count: The number of tokens to generate, defaults to None.
:type token_count: Optional[int]
:param params: Optional parameters for the benchmark request, defaults to None.
:type params: Optional[Dict[str, Any]]
"""

def __init__(self, prompt: str, params: Optional[Dict[str, Any]] = None):
def __init__(
self,
prompt: str,
token_count: Optional[int] = None,
params: Optional[Dict[str, Any]] = None,
):
"""
Initialize the BenchmarkRequest with a prompt and optional parameters.
Expand All @@ -23,9 +31,21 @@ def __init__(self, prompt: str, params: Optional[Dict[str, Any]] = None):
:param params: Optional parameters for the benchmark request, defaults to None.
:type params: Optional[Dict[str, Any]]
"""
self._id = str(uuid.uuid4())
self._prompt = prompt
self._token_count = token_count
self._params = params or {}

@property
def id(self) -> str:
"""
Get the unique identifier for the benchmark request.
:return: The unique identifier.
:rtype: str
"""
return self._id

@property
def prompt(self) -> str:
"""
Expand All @@ -36,6 +56,16 @@ def prompt(self) -> str:
"""
return self._prompt

@property
def token_count(self) -> Optional[int]:
"""
Get the number of tokens to generate for the benchmark request.
:return: The number of tokens to generate.
:rtype: Optional[int]
"""
return self._token_count

@property
def params(self) -> Dict[str, Any]:
"""
Expand All @@ -53,7 +83,7 @@ def __str__(self) -> str:
:return: String representation of the BenchmarkRequest.
:rtype: str
"""
return f"BenchmarkRequest(prompt={self._prompt}, params={self._params})"
return f"BenchmarkRequest(id={self.id}, prompt={self._prompt}, params={self._params})"

def __repr__(self) -> str:
"""
Expand All @@ -62,4 +92,4 @@ def __repr__(self) -> str:
:return: Unambiguous string representation of the BenchmarkRequest.
:rtype: str
"""
return f"BenchmarkRequest(prompt={self._prompt}, params={self._params})"
return f"BenchmarkRequest(id={self.id}, prompt={self._prompt}, params={self._params})"
Loading

0 comments on commit 297d480

Please sign in to comment.