Skip to content

Commit

Permalink
Add Openai compatible Completion class (#60)
Browse files Browse the repository at this point in the history
* Add openai compatible Completion class

* Add pydantic dependency
  • Loading branch information
orangetin authored Nov 22, 2023
1 parent c562a64 commit 4fdf6b1
Show file tree
Hide file tree
Showing 8 changed files with 444 additions and 115 deletions.
359 changes: 255 additions & 104 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ requests = "^2.31.0"
tqdm = "^4.66.1"
sseclient-py = "^1.7.2"
tabulate = "^0.9.0"
pydantic = "^2.5.0"

[tool.poetry.group.quality]
optional = true
Expand Down
3 changes: 2 additions & 1 deletion src/together/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

min_samples = 100

from .complete import Complete
from .complete import Complete, Completion
from .embeddings import Embeddings
from .error import *
from .files import Files
Expand All @@ -54,6 +54,7 @@
"default_embedding_model",
"Models",
"Complete",
"Completion",
"Files",
"Finetune",
"Image",
Expand Down
1 change: 1 addition & 0 deletions src/together/commands/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def do_say(self, arg: str) -> None:
top_k=self.args.top_k,
repetition_penalty=self.args.repetition_penalty,
):
assert isinstance(token, str)
print(token, end="", flush=True)
output += token
except together.AuthenticationError:
Expand Down
1 change: 1 addition & 0 deletions src/together/commands/complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _run_complete(args: argparse.Namespace) -> None:
except together.AuthenticationError:
logger.critical(together.MISSING_API_KEY_MESSAGE)
exit(0)
assert isinstance(response, dict)
no_streamer(args, response)
else:
try:
Expand Down
78 changes: 70 additions & 8 deletions src/together/complete.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, Union

import together
from together.types import TogetherResponse
from together.utils import create_post_request, get_logger, sse_client


Expand All @@ -21,7 +22,9 @@ def create(
top_k: Optional[int] = 50,
repetition_penalty: Optional[float] = None,
logprobs: Optional[int] = None,
) -> Dict[str, Any]:
api_key: Optional[str] = None,
cast: bool = False,
) -> Union[Dict[str, Any], TogetherResponse]:
if model == "":
model = together.default_text_model

Expand All @@ -39,14 +42,18 @@ def create(

# send request
response = create_post_request(
url=together.api_base_complete, json=parameter_payload
url=together.api_base_complete, json=parameter_payload, api_key=api_key
)

try:
response_json = dict(response.json())

except Exception as e:
raise together.JSONError(e, http_status=response.status_code)

if cast:
return TogetherResponse(**response_json)

return response_json

@classmethod
Expand All @@ -55,13 +62,15 @@ def create_streaming(
prompt: str,
model: Optional[str] = "",
max_tokens: Optional[int] = 128,
stop: Optional[str] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = 0.7,
top_p: Optional[float] = 0.7,
top_k: Optional[int] = 50,
repetition_penalty: Optional[float] = None,
raw: Optional[bool] = False,
) -> Iterator[str]:
api_key: Optional[str] = None,
cast: Optional[bool] = False,
) -> Union[Iterator[str], Iterator[TogetherResponse]]:
"""
Prints streaming responses and returns the completed text.
"""
Expand All @@ -83,19 +92,25 @@ def create_streaming(

# send request
response = create_post_request(
url=together.api_base_complete, json=parameter_payload, stream=True
url=together.api_base_complete,
json=parameter_payload,
api_key=api_key,
stream=True,
)

output = ""
client = sse_client(response)
for event in client.events():
if raw:
if cast:
if event.data != "[DONE]":
yield TogetherResponse(**json.loads(event.data))
elif raw:
yield str(event.data)
elif event.data != "[DONE]":
json_response = dict(json.loads(event.data))
if "error" in json_response.keys():
raise together.ResponseError(
json_response["error"]["error"],
json_response["error"],
request_id=json_response["error"]["request_id"],
)
elif "choices" in json_response.keys():
Expand All @@ -106,3 +121,50 @@ def create_streaming(
raise together.ResponseError(
f"Unknown error occured. Received unhandled response: {event.data}"
)


class Completion:
    """OpenAI-compatible completion interface.

    Thin wrapper around ``Complete`` that always requests typed
    ``TogetherResponse`` objects (``cast=True``) and switches between
    blocking and streaming modes via the ``stream`` flag.
    """

    @classmethod
    def create(
        # First parameter of a classmethod is conventionally `cls`, not `self`.
        cls,
        prompt: str,
        model: Optional[str] = "",
        max_tokens: Optional[int] = 128,
        # `None` instead of a mutable `[]` default: mutable defaults are shared
        # across calls, and `None` matches Complete.create_streaming's default.
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.7,
        top_k: Optional[int] = 50,
        repetition_penalty: Optional[float] = None,
        logprobs: Optional[int] = None,
        api_key: Optional[str] = None,
        stream: bool = False,
    ) -> Union[
        TogetherResponse, Iterator[TogetherResponse], Iterator[str], Dict[str, Any]
    ]:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text to complete.
            model: Model name; an empty string selects the library default.
            max_tokens: Maximum number of tokens to generate.
            stop: Stop sequences that end generation early.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            top_k: Top-k sampling cutoff.
            repetition_penalty: Penalty applied to repeated tokens.
            logprobs: Number of log-probabilities to return
                (forwarded on the non-streaming path only).
            api_key: Per-call API key overriding the module-level key.
            stream: If True, return an iterator of ``TogetherResponse``
                chunks; otherwise return a single ``TogetherResponse``.

        Returns:
            A ``TogetherResponse`` when ``stream`` is False, or an iterator
            of ``TogetherResponse`` chunks when ``stream`` is True.
        """
        if stream:
            # NOTE: `logprobs` is intentionally not forwarded here —
            # Complete.create_streaming does not accept it.
            return Complete.create_streaming(
                prompt=prompt,
                model=model,
                max_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                api_key=api_key,
                cast=True,
            )
        return Complete.create(
            prompt=prompt,
            model=model,
            max_tokens=max_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            logprobs=logprobs,
            api_key=api_key,
            cast=True,
        )
111 changes: 111 additions & 0 deletions src/together/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import typing
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


# Decoder input (prompt) tokens, returned when input details are requested
class InputToken(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Log-probability of the token.
    # Optional since the logprob of the first token cannot be computed.
    # Explicit `= None` is required under pydantic v2 (pinned ^2.5.0),
    # where an Optional annotation alone no longer implies a default.
    logprob: Optional[float] = None


# Generated tokens
class Token(BaseModel):
    # Token ID
    id: int
    # Log-probability of the token.
    # Explicit `= None` is required under pydantic v2 (pinned ^2.5.0),
    # where an Optional annotation alone no longer implies a default.
    logprob: Optional[float] = None
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool


# Generation finish reason
class FinishReason(str, Enum):
    """Reason why text generation stopped."""

    # Generation hit the `max_new_tokens` budget.
    Length = "length"
    # The model emitted its end-of-sequence token.
    EndOfSequenceToken = "eos_token"
    # The output matched one of the configured `stop_sequences`.
    StopSequence = "stop_sequence"


# Additional sequences when using the `best_of` parameter
class BestOfSequence(BaseModel):
    # Generated text
    generated_text: str
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated; absent otherwise, so it needs
    # an explicit `= None` default under pydantic v2 (pinned ^2.5.0)
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]


# `generate` details
class Details(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated; absent otherwise, so it needs
    # an explicit `= None` default under pydantic v2 (pinned ^2.5.0)
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Additional sequences when using the `best_of` parameter; absent when
    # `best_of` is not used, so it also needs an explicit default under
    # pydantic v2
    best_of_sequences: Optional[List[BestOfSequence]] = None


# `generate` return value (non-streaming)
class Response(BaseModel):
    # Full generated text of the completion
    generated_text: str
    # Generation details (finish reason, token counts, per-token info)
    details: Details


# One completion choice in an OpenAI-style `choices` list
class Choice(BaseModel):
    # Generated text for this choice
    text: str
    # Why generation stopped; None while a stream is still in progress
    finish_reason: Optional[FinishReason] = None
    # Per-token log probabilities — presumably populated when the request
    # sets `logprobs`; TODO confirm against the API response shape
    logprobs: Optional[List[float]] = None


# `generate_stream` details
class StreamDetails(BaseModel):
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated; absent otherwise, so it needs
    # an explicit `= None` default under pydantic v2 (pinned ^2.5.0)
    seed: Optional[int] = None


# `generate_stream` return value
class TogetherResponse(BaseModel):
    """Typed envelope for Together API responses (streaming and blocking).

    All fields default to ``None`` because different response kinds populate
    different subsets: blocking completions carry ``choices``, stream chunks
    carry ``token``/``choices``, error payloads carry ``error``/``error_type``.
    """

    # OpenAI-style list of completion choices
    choices: Optional[List[Choice]] = None
    # Request/response identifier, when provided by the API
    id: Optional[str] = None
    # Single generated token (streaming chunks)
    token: Optional[Token] = None
    # Error message — NOTE(review): Complete.create_streaming treats the raw
    # `error` payload as a dict (`error["error"]`, `error["request_id"]`);
    # confirm a plain string is the right type here
    error: Optional[str] = None
    # Error category/type string, when provided
    error_type: Optional[str] = None
    # Full generated text, only present once generation has finished
    generated_text: Optional[str] = None
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails] = None

    def __init__(self, **kwargs: Any) -> None:
        """Normalize the raw API payload before pydantic validation.

        Responses that nest their data under an ``output`` key are flattened
        so that ``choices`` is available at the top level.
        """
        if kwargs.get("output"):
            # Hoist `choices` out of the nested `output` object
            kwargs["choices"] = typing.cast(Dict[str, Any], kwargs["output"])["choices"]
            # NOTE(review): this reads `details` from the TOP-LEVEL kwargs,
            # not from `output` — a no-op when the key is absent (default is
            # None anyway); confirm whether `output["details"]` was intended
            kwargs["details"] = kwargs.get("details")
        super().__init__(**kwargs)
5 changes: 3 additions & 2 deletions src/together/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,14 @@ def create_post_request(
json: Optional[Dict[Any, Any]] = None,
stream: Optional[bool] = False,
check_auth: Optional[bool] = True,
api_key: Optional[str] = None,
) -> requests.Response:
if check_auth:
if check_auth and api_key is None:
verify_api_key()

if not headers:
headers = {
"Authorization": f"Bearer {together.api_key}",
"Authorization": f"Bearer {api_key or together.api_key}",
"Content-Type": "application/json",
"User-Agent": together.user_agent,
}
Expand Down

0 comments on commit 4fdf6b1

Please sign in to comment.