diff --git a/docs/install_pip.md b/docs/install_pip.md
index 9922d40..5192f82 100644
--- a/docs/install_pip.md
+++ b/docs/install_pip.md
@@ -73,7 +73,7 @@ embedding_store.start_embedding_server(host = host, port = port)
 - You can add more columns as needed using `embedding_store.add()`;
 - This will be set up on port 8501, which matches the default keyword argument `embedding_server_address` in `suql_execute`. Make sure both addresses match if you modify it.
-5. Set up the backend server for the `answer`, `summary` functions. In a separate terminal, first set up OpenAI API key with `export OPENAI_API_KEY=[your OpenAI API key here]`. Write the following content into a Python script and execute in that terminal:
+5. Set up the backend server for the `answer`, `summary` functions. In a separate terminal, first set up your LLM API key environment variable following [the litellm provider doc](https://docs.litellm.ai/docs/providers) (e.g., for OpenAI, run `export OPENAI_API_KEY=[your OpenAI API key here]`). Write the following content into a Python script and execute in that terminal:
 ```python
 from suql.free_text_fcns_server import start_free_text_fncs_server
@@ -84,7 +84,7 @@ start_free_text_fncs_server(host=host, port=port)
 # Test with the entry point
-You should be good to go! In a separate terminal, run `export OPENAI_API_KEY=[your OpenAI API key here]`, and test with
+You should be good to go! In a separate terminal, set up your LLM API key environment variable following [the litellm provider doc](https://docs.litellm.ai/docs/providers) (e.g., for OpenAI, run `export OPENAI_API_KEY=[your OpenAI API key here]`), and test with
 ```python
 >>> from suql import suql_execute
diff --git a/docs/install_source.md b/docs/install_source.md
index 0fde4bb..58f3960 100644
--- a/docs/install_source.md
+++ b/docs/install_source.md
@@ -73,7 +73,7 @@ under `if __name__ == "__main__":` to match your database with its column names.
 - For instance, this line instructs the SUQL compiler to set up an embedding server for the `restaurants` database, which has the `_id` column as the unique row identifier, for the `popular_dishes` column (such columns need to be of type `TEXT` or `TEXT[]`, or other fixed-length strings/lists of strings) under table `restaurants`. This is executed with user privilege `user="select_user"` and `password="select_user"`;
 - By default, this will be set up on port 8501, which is then called by `src/suql/execute_free_text_sql.py`. In case you need to use another port, please change both addresses.
-5. Set up the backend server for the `answer`, `summary` functions. In a separate terminal, first set up OpenAI API key with `export OPENAI_API_KEY=[your OpenAI API key here]`. Then, run `python src/suql/free_text_fcns_server.py`.
+5. Set up the backend server for the `answer`, `summary` functions. In a separate terminal, first set up your LLM API key environment variable following [the litellm provider doc](https://docs.litellm.ai/docs/providers) (e.g., for OpenAI, run `export OPENAI_API_KEY=[your OpenAI API key here]`). Then, run `python src/suql/free_text_fcns_server.py`.
 - As you probably noticed, the code in `custom_functions.sql` is just making queries to this server, which handles the LLM API calls. If you changed the address in `custom_functions.sql`, then also update the address under `if __name__ == "__main__":`.
 ## Write 2 few-shot prompts
@@ -88,7 +88,7 @@ We are very close to a fully-working LLM-powered agent!
 - If you decide to keep this, then modify the examples to match your domain;
 - If you decide to delete this, then simply set the line `enable_classifier=True` to be `enable_classifier=False`.
-8. In a separate terminal from the two servers above, run `export OPENAI_API_KEY=[your OpenAI API key here]`. Test with `python src/suql/agent.py`. You should be able to interact with your agent on your CLI!
+8. In a separate terminal from the two servers above, set up your LLM API key environment variable following [the litellm provider doc](https://docs.litellm.ai/docs/providers) (e.g., for OpenAI, run `export OPENAI_API_KEY=[your OpenAI API key here]`). Test with `python src/suql/agent.py`. You should be able to interact with your agent on your CLI!
 # Set up with Chainlit
diff --git a/requirements.txt b/requirements.txt
index fccee24..1523409 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-openai==1.3.2
 Jinja2==3.1.2
 Flask==2.3.2
 Flask-Cors==4.0.0
@@ -8,4 +7,5 @@ spacy==3.7.4
 tiktoken==0.4.0
 psycopg2-binary==2.9.7 # you can also install from source if it works
 pglast==5.3
-FlagEmbedding==1.2.5
\ No newline at end of file
+FlagEmbedding==1.2.5
+litellm==1.34.34
\ No newline at end of file
diff --git a/src/suql/prompt_continuation.py b/src/suql/prompt_continuation.py
index 6e204e1..58d5736 100644
--- a/src/suql/prompt_continuation.py
+++ b/src/suql/prompt_continuation.py
@@ -7,9 +7,6 @@ from functools import partial
 from typing import List
-import openai
-from openai import OpenAI
-
 import os
 import time
 import traceback
@@ -19,6 +16,8 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape
 from suql.utils import num_tokens_from_string
+from litellm import completion, completion_cost
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -41,55 +40,17 @@ mongo_client = pymongo.MongoClient("localhost", 27017)
 prompt_cache_db = mongo_client["open_ai_prompts"]["caches"]
-# inference_cost_per_1000_tokens = {'ada': 0.0004, 'babbage': 0.0005, 'curie': 0.002, 'davinci': 0.02, 'turbo': 0.003, 'gpt-4': 0.03} # for Azure
-inference_input_cost_per_1000_tokens = {
-    "gpt-4": 0.03,
-    "gpt-3.5-turbo-0613": 0.0010,
-    "gpt-3.5-turbo-1106": 0.0010,
-    "gpt-4-1106-preview": 0.01,
-} # for OpenAI
-inference_output_cost_per_1000_tokens = {
-    "gpt-4": 0.06,
-    "gpt-3.5-turbo-0613": 0.0010,
-    "gpt-3.5-turbo-1106": 0.0020,
-    "gpt-4-1106-preview": 0.03,
-} # for OpenAI
-total_cost = 0 # in USD
-
+total_cost = 0 # in USD
 def get_total_cost():
     global total_cost
     return total_cost
-def _model_name_to_cost(model_name: str) -> float:
-    if (
-        model_name in inference_input_cost_per_1000_tokens
-        and model_name in inference_output_cost_per_1000_tokens
-    ):
-        return (
-            inference_input_cost_per_1000_tokens[model_name],
-            inference_output_cost_per_1000_tokens[model_name],
-        )
-    raise ValueError("Did not recognize GPT model name %s" % model_name)
-
-
-def openai_chat_completion_with_backoff(**kwargs):
-    client = OpenAI()
-    # # uncomment if using Azure OpenAI
-    openai.api_type == "open_ai"
-    # openai.api_type = "azure"
-    # openai.api_base = "https://ovalopenairesource.openai.azure.com/"
-    # openai.api_version = "2023-05-15"
+def chat_completion_with_backoff(**kwargs):
     global total_cost
-    ret = client.chat.completions.create(**kwargs)
-    num_prompt_tokens = ret.usage.prompt_tokens
-    num_completion_tokens = ret.usage.completion_tokens
-    prompt_cost, completion_cost = _model_name_to_cost(kwargs["model"])
-    total_cost += (
-        num_prompt_tokens / 1000 * prompt_cost
-        + num_completion_tokens / 1000 * completion_cost
-    ) # TODO: update this
+    ret = completion(**kwargs)
+    total_cost += completion_cost(ret)
     return ret.choices[0].message.content
@@ -126,6 +87,7 @@ def _generate(
     no_line_break_start = ""
     no_line_break_length = 0
     kwargs = {
+        "model": engine,
         "messages": [
             {"role": "system", "content": filled_prompt + no_line_break_start}
         ],
@@ -136,28 +98,8 @@
         "temperature": temperature,
         "max_tokens": max_tokens,
         "top_p": top_p,
         "frequency_penalty": frequency_penalty,
         "presence_penalty": presence_penalty,
         "stop": stop_tokens,
     }
-    if openai.api_type == "azure":
-        kwargs.update({"engine": engine})
-    else:
-        engine_model_map = {
-            "gpt-4": "gpt-4",
-            "gpt-35-turbo": "gpt-3.5-turbo-1106",
-            "gpt-3.5-turbo": "gpt-3.5-turbo-1106",
-            "gpt-4-turbo": "gpt-4-1106-preview",
-        }
-        # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
-        # https://platform.openai.com/docs/models/model-endpoint-compatibility
-        kwargs.update(
-            {
-                "model": (
-                    engine_model_map[engine]
-                    if engine in engine_model_map
-                    else engine
-                )
-            }
-        )
-
-    generation_output = openai_chat_completion_with_backoff(**kwargs)
+
+    generation_output = chat_completion_with_backoff(**kwargs)
     generation_output = no_line_break_start + generation_output
     logger.info("LLM output = %s", generation_output)
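
For reviewers who have not used litellm before, here is a minimal, self-contained sketch of the two calls the new `chat_completion_with_backoff` relies on. It is not part of the diff; the model name and prompt are illustrative only, and the matching provider API key must be exported per [the litellm provider doc](https://docs.litellm.ai/docs/providers).

```python
# Sketch only (not part of this diff): exercises litellm's completion() and
# completion_cost(), which replace the direct OpenAI client calls removed above.
from litellm import completion, completion_cost

response = completion(
    model="gpt-3.5-turbo",  # illustrative; any litellm-supported provider/model works
    messages=[{"role": "system", "content": "Say hello in one short sentence."}],
    temperature=0.0,
    max_tokens=32,
)

# litellm returns an OpenAI-style response object, so the existing
# `ret.choices[0].message.content` access in prompt_continuation.py keeps working.
print(response.choices[0].message.content)

# completion_cost() estimates the USD cost of the call from litellm's built-in
# pricing table, replacing the hand-maintained per-1000-token cost dicts.
print(completion_cost(response))
```

Because litellm normalizes responses to the OpenAI schema and tracks per-model pricing itself, the `inference_*_cost_per_1000_tokens` dicts, `_model_name_to_cost`, and the Azure/OpenAI `engine_model_map` branch can all be dropped, as the diff above does.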