This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit 07efeef: Merge branch 'dev'

filopedraz committed Feb 16, 2024
2 parents: 2a371fb + cd486c1

Showing 10 changed files with 139 additions and 12 deletions.
e2e.py: 3 additions & 0 deletions

@@ -14,6 +14,7 @@
     mistral,
     octoai,
     openai,
+    prem,
     replicate,
     together,
 )
@@ -69,6 +70,8 @@ def main():
                 api_key=os.environ["DEEP_INFRA_API_KEY"],
                 base_url="https://api.deepinfra.com/v1/openai",
             )
+        elif connector["provider"] == "prem":
+            connector_object = prem.PremConnector(api_key=os.environ["PREMAI_BEARER_TOKEN"])
         else:
             print(f"No connector for {connector['provider']}")
prem_utils/connectors/azure.py: 7 additions & 1 deletion

@@ -161,7 +161,13 @@ def embeddings(
             custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
             raise custom_exception(error, provider="azure", model=model, provider_message=str(error))
         return {
-            "data": [embedding.embedding for embedding in response.data],
+            "data": [
+                {
+                    "index": embedding.index,
+                    "embedding": embedding.embedding,
+                }
+                for embedding in response.data
+            ],
             "model": response.model,
             "usage": {
                 "prompt_tokens": response.usage.prompt_tokens,
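Note: every embeddings diff in this commit makes the same change, replacing the bare list of vectors under "data" with a list of {"index": ..., "embedding": ...} objects. A minimal sketch of consuming the normalized shape (the payload below is invented for illustration, not taken from any provider):

# Illustrative payload in the normalized shape; values and model name are made up.
response = {
    "data": [
        {"index": 1, "embedding": [0.4, 0.5]},
        {"index": 0, "embedding": [0.1, 0.2]},
    ],
    "model": "example-embedding-model",  # placeholder, not a real model slug
    "usage": None,
}

# Every item now carries its position, so callers can restore input order
# the same way for every provider.
vectors = [item["embedding"] for item in sorted(response["data"], key=lambda d: d["index"])]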
prem_utils/connectors/cloudflare.py: 3 additions & 1 deletion

@@ -145,7 +145,9 @@ def embeddings(
         )
         response = response.json()
         return {
-            "data": response["result"]["data"],
+            "data": [
+                {"index": index, "embedding": embedding} for index, embedding in enumerate(response["result"]["data"])
+            ],
             "model": model,
             "usage": None,
             "provider_name": "Cloudflare",
prem_utils/connectors/cohere.py: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ def embeddings(
             custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
             raise custom_exception(error, provider="cohere", model=model, provider_message=str(error))
         return {
-            "data": response.embeddings,
+            "data": [{"index": index, "embedding": embedding} for index, embedding in enumerate(response.embeddings)],
             "model": model,
             "usage": None,
             "provider_name": "Cohere",
prem_utils/connectors/lamini.py: 7 additions & 1 deletion

@@ -149,7 +149,13 @@ def embeddings(
         embed = Embedding()
         embeddings = [list(emb) for emb in embed.generate(input)]
         return {
-            "data": embeddings,
+            "data": [
+                {
+                    "index": index,
+                    "embedding": embedding,
+                }
+                for index, embedding in enumerate(embeddings)
+            ],
             "model": None,
             "usage": None,
             "provider_name": "LaMini",
prem_utils/connectors/mistral.py: 5 additions & 5 deletions

@@ -94,8 +94,8 @@ def chat_completion(
                 ],
                 "created": None,
                 "model": response.model,
-                "provider_name": "Anthropic",
-                "provider_id": "anthropic",
+                "provider_name": "Mistral",
+                "provider_id": "mistralai",
                 "usage": {
                     "completion_tokens": response.usage.completion_tokens,
                     "prompt_tokens": response.usage.prompt_tokens,
@@ -120,15 +120,15 @@ def embeddings(
                 input=input if type(input) is list else [input],
             )
             return {
-                "data": [emb.embedding for emb in response.data],
+                "data": [{"index": emb.index, "embedding": emb.embedding} for emb in response.data],
                 "model": response.model,
                 "usage": {
                     "completion_tokens": response.usage.completion_tokens,
                     "prompt_tokens": response.usage.prompt_tokens,
                     "total_tokens": response.usage.total_tokens,
                 },
-                "provider_name": "Anthropic",
-                "provider_id": "anthropic",
+                "provider_name": "Mistral",
+                "provider_id": "mistralai",
             }
         except (MistralAPIException, MistralConnectionException) as error:
             custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
prem_utils/connectors/openai.py: 7 additions & 1 deletion

@@ -186,7 +186,13 @@ def embeddings(
             custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
             raise custom_exception(error, provider="openai", model=model, provider_message=str(error))
         return {
-            "data": [embedding.embedding for embedding in response.data],
+            "data": [
+                {
+                    "index": embedding.index,
+                    "embedding": embedding.embedding,
+                }
+                for embedding in response.data
+            ],
             "model": response.model,
             "usage": {
                 "prompt_tokens": response.usage.prompt_tokens,
prem_utils/connectors/prem.py: 78 additions & 0 deletions (new file)

@@ -0,0 +1,78 @@
from typing import Any

import requests

from prem_utils import errors
from prem_utils.connectors.base import BaseConnector


class PremConnector(BaseConnector):
    def __init__(self, api_key: str, prompt_template: str | None = None) -> None:
        super().__init__(prompt_template=prompt_template)
        self._api_key = api_key

    def parse_chunk(self, chunk) -> dict[str, Any]:
        # TODO: understand how this is used.
        pass

    def build_messages(self, messages: list[dict]) -> list[str]:
        # TODO: decide whether this can be used with the current providers.
        pass

    def preprocess_messages(self, messages):
        # TODO: understand whether and how to use this.
        pass

    def chat_completion(
        self,
        model_name: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
    ):
        if model_name not in ("phi-1-5", "phi-2", "tinyllama", "mamba-chat"):
            raise ValueError("Models other than 'phi-1-5', 'phi-2', 'tinyllama', 'mamba-chat' are not supported")

        # messages arrive in the shape [{'role': 'user', 'content': ...}]
        if model_name == "mamba-chat":
            _base_url = "https://mamba.compute.premai.io/v1/chat/completions"
            data = {
                "model": model_name,
                "messages": messages,
                "max_length": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }
        else:
            _base_url = f"https://{model_name}.compute.premai.io/mii/default"
            data = {
                "prompts": [message["content"] for message in messages],
                "max_length": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }

        _headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self._api_key}"}

        try:
            response = requests.post(_base_url, json=data, headers=_headers)
        except (
            errors.PremProviderAPIErrror,
            errors.PremProviderPermissionDeniedError,
            errors.PremProviderUnprocessableEntityError,
            errors.PremProviderInternalServerError,
            errors.PremProviderAuthenticationError,
            errors.PremProviderBadRequestError,
            errors.PremProviderNotFoundError,
            errors.PremProviderRateLimitError,
            errors.PremProviderAPIResponseValidationError,
            errors.PremProviderConflictError,
            errors.PremProviderAPIStatusError,
            errors.PremProviderAPITimeoutError,
            errors.PremProviderAPIConnectionError,
        ) as error:
            raise error

        return response.text
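For orientation, a minimal usage sketch of the new connector (hypothetical, not part of the diff; it assumes PREMAI_BEARER_TOKEN is set in the environment, as in e2e.py above, and network access to the compute endpoints):

import os

from prem_utils.connectors import prem

# Hypothetical usage; PREMAI_BEARER_TOKEN matches the env var used in e2e.py.
connector = prem.PremConnector(api_key=os.environ["PREMAI_BEARER_TOKEN"])
raw = connector.chat_completion(
    model_name="mamba-chat",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=128,
)
print(raw)  # chat_completion returns the raw response body as text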
prem_utils/connectors/replicate.py: 1 addition & 1 deletion

@@ -181,7 +181,7 @@ def embeddings(
             input={"text": input},
         )
         return {
-            "data": [response[0]["embedding"]],
+            "data": [{"index": index, "embedding": data["embedding"]} for index, data in enumerate(response)],
            "model": model,
            "provider_name": "Replicate",
            "provider_id": "replicate",
prem_utils/models.json: 27 additions & 1 deletion

@@ -470,7 +470,8 @@
           "model_type": "text2text",
           "context_window": 4096,
           "input_cost_per_token": 0.0000009,
-          "output_cost_per_token": 0.0000009
+          "output_cost_per_token": 0.0000009,
+          "deprecated": true
         },
         {
           "slug": "Open-Orca/Mistral-7B-OpenOrca",
@@ -512,6 +513,31 @@
           "input_cost_per_token": 0.00000001
         }
       ]
+    },
+    {
+      "provider": "prem",
+      "models": [
+        {
+          "slug": "phi-1-5",
+          "model_type": "text2text",
+          "context_window": 2048
+        },
+        {
+          "slug": "phi-2",
+          "model_type": "text2text",
+          "context_window": 2048
+        },
+        {
+          "slug": "tinyllama",
+          "model_type": "text2text",
+          "context_window": 2048
+        },
+        {
+          "slug": "mamba-chat",
+          "model_type": "text2text",
+          "context_window": 2048
+        }
+      ]
     }
   ]
 }
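A hedged sketch of reading the new provider entry back out of models.json (the top-level key name is not shown in this diff, so the lookup below deliberately avoids assuming it; the file path assumes running from the repository root):

import json

# Load the catalog; path is an assumption based on the file's location in this diff.
with open("prem_utils/models.json") as fp:
    catalog = json.load(fp)

# The diff shows a top-level object wrapping a list of provider entries but not
# its key name, so take the first value generically.
providers = next(iter(catalog.values()))
prem_entry = next(p for p in providers if p["provider"] == "prem")
print([m["slug"] for m in prem_entry["models"]])
# Expected: ['phi-1-5', 'phi-2', 'tinyllama', 'mamba-chat']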
