diff --git a/src/suql/prompt_continuation.py b/src/suql/prompt_continuation.py index f82e2d1..6e204e1 100644 --- a/src/suql/prompt_continuation.py +++ b/src/suql/prompt_continuation.py @@ -10,7 +10,6 @@ import openai from openai import OpenAI -client = OpenAI() import os import time import traceback @@ -35,14 +34,8 @@ lstrip_blocks=True, line_comment_prefix="#", ) -# # uncomment if using Azure OpenAI -openai.api_type == "open_ai" -# openai.api_type = "azure" -# openai.api_base = "https://ovalopenairesource.openai.azure.com/" -# openai.api_version = "2023-05-15" ENABLE_CACHING = False - if ENABLE_CACHING: import pymongo mongo_client = pymongo.MongoClient("localhost", 27017) @@ -82,6 +75,12 @@ def _model_name_to_cost(model_name: str) -> float: def openai_chat_completion_with_backoff(**kwargs): + client = OpenAI() + # # uncomment if using Azure OpenAI + openai.api_type == "open_ai" + # openai.api_type = "azure" + # openai.api_base = "https://ovalopenairesource.openai.azure.com/" + # openai.api_version = "2023-05-15" global total_cost ret = client.chat.completions.create(**kwargs) num_prompt_tokens = ret.usage.prompt_tokens