Skip to content
This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3 from premAI-io/feat/integrated-dalle3
Browse files Browse the repository at this point in the history
Added dalle3 integration
  • Loading branch information
filopedraz authored Dec 28, 2023
2 parents 21797ad + 1dfce2e commit c78b4f0
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 1 deletion.
18 changes: 18 additions & 0 deletions e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ def main():

text2text_models = [model for model in connector["models"] if model["model_type"] == "text2text"]
text2vector_models = [model for model in connector["models"] if model["model_type"] == "text2vector"]
text2image_models = [model for model in connector["models"] if model["model_type"] == "text2image"]

if len(text2text_models) > 0:
model_object = text2text_models[0]

parameters = {}
parameters["model"] = model_object["slug"]
messages = [{"role": "user", "content": "Hello, how is it going?"}]
Expand All @@ -92,8 +95,23 @@ def main():
print(connector_object.parse_chunk(text))
print(f"\n\n\n Model {model_object['slug']} succeeed with streaming 🚀 \n\n\n")

if len(text2image_models) > 0:
model_object = text2image_models[0]

parameters = {}
parameters["model"] = model_object["slug"]
parameters["prompt"] = "A cute baby sea otter"
parameters["size"] = "1024x1024"
parameters["n"] = 1

print(f"Testing model {model_object['slug']} from {connector['provider']} connector \n\n\n")
response = connector_object.generate_image(**parameters)
print(response)
print(f"\n\n\n Model {model_object['slug']} succeeed 🚀 \n\n\n")

if len(text2vector_models) > 0:
model_object = text2vector_models[0]

input = "Hello, how is it going?"
print(f"Testing model {model_object['slug']} from {connector['provider']} connector")
response = connector_object.embeddings(model=model_object["slug"], input=input)
Expand Down
3 changes: 3 additions & 0 deletions prem_utils/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def embeddings(
"provider_id": "cohere",
}

def generate_image(self):
raise NotImplementedError

def finetuning(
self, model: str, training_data: list[dict], validation_data: list[dict] | None = None, num_epochs: int = 3
) -> str:
Expand Down
32 changes: 31 additions & 1 deletion prem_utils/connectors/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def finetuning(
) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

return response.id

def get_finetuning_job(self, job_id) -> dict[str, any]:
Expand Down Expand Up @@ -297,3 +296,34 @@ def _upload_and_transform_data(self, data: list[dict], size: int) -> str:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=None, provider_message=str(error))
return response.id

def generate_image(
self,
model: str,
prompt: str,
size: str = "1024x1024",
n: int = 1,
quality: str = "standard",
style: str = "vivid",
):
try:
response = self.client.images.generate(
model=model, prompt=prompt, n=n, size=size, quality=quality, style=style
)
except (
NotFoundError,
APIResponseValidationError,
ConflictError,
APIStatusError,
APITimeoutError,
RateLimitError,
BadRequestError,
APIConnectionError,
AuthenticationError,
InternalServerError,
PermissionDeniedError,
UnprocessableEntityError,
) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))
return {"created": None, "data": [{"url": image.url} for image in response.data]}
5 changes: 5 additions & 0 deletions prem_utils/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@
"context_window" : 8191,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004
},
{
"slug" : "dall-e-3",
"model_type" : "text2image",
"cost_per_image": 0.120
}
]
},
Expand Down

0 comments on commit c78b4f0

Please sign in to comment.