diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 76ca9e97..d4ed1c50 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -207,7 +207,7 @@ def create_database_connection( alias=database_connection_request.alias, uri=database_connection_request.connection_uri, path_to_credentials_file=database_connection_request.path_to_credentials_file, - llm_credentials=database_connection_request.llm_credentials, + llm_api_key=database_connection_request.llm_api_key, use_ssh=database_connection_request.use_ssh, ssh_settings=database_connection_request.ssh_settings, ) @@ -241,7 +241,7 @@ def update_database_connection( alias=database_connection_request.alias, uri=database_connection_request.connection_uri, path_to_credentials_file=database_connection_request.path_to_credentials_file, - llm_credentials=database_connection_request.llm_credentials, + llm_api_key=database_connection_request.llm_api_key, use_ssh=database_connection_request.use_ssh, ssh_settings=database_connection_request.ssh_settings, ) diff --git a/dataherald/model/base_models.py b/dataherald/model/base_models.py index 507e40dc..4fb52b80 100644 --- a/dataherald/model/base_models.py +++ b/dataherald/model/base_models.py @@ -25,11 +25,9 @@ def get_model( model_name="davinci-003", **kwargs: Any ) -> Any: - if database_connection.llm_credentials is not None: + if database_connection.llm_api_key is not None: fernet_encrypt = FernetEncrypt() - api_key = fernet_encrypt.decrypt( - database_connection.llm_credentials.api_key - ) + api_key = fernet_encrypt.decrypt(database_connection.llm_api_key) if model_family == "openai": self.openai_api_key = api_key elif model_family == "anthropic": diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index 6e55ad4a..2b19ebbb 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -21,11 +21,9 @@ def get_model( model_name="gpt-4-32k", **kwargs: Any ) -> Any: - if database_connection.llm_credentials is not None: + if database_connection.llm_api_key is not None: fernet_encrypt = FernetEncrypt() - api_key = fernet_encrypt.decrypt( - database_connection.llm_credentials.api_key - ) + api_key = fernet_encrypt.decrypt(database_connection.llm_api_key) if model_family == "openai": os.environ["OPENAI_API_KEY"] = api_key elif model_family == "anthropic": diff --git a/dataherald/scripts/migrate_v004_to_v005.py b/dataherald/scripts/migrate_v004_to_v005.py new file mode 100644 index 00000000..987cadc3 --- /dev/null +++ b/dataherald/scripts/migrate_v004_to_v005.py @@ -0,0 +1,17 @@ +import dataherald.config +from dataherald.config import System +from dataherald.db import DB + +if __name__ == "__main__": + settings = dataherald.config.Settings() + system = System(settings) + system.start() + storage = system.instance(DB) + + for db_connection in storage.find_all("database_connections"): + if "llm_credentials" in db_connection and db_connection["llm_credentials"]: + db_connection["llm_api_key"] = db_connection["llm_credentials"]["api_key"] + db_connection["llm_credentials"] = None + storage.update_or_create( + "database_connections", {"_id": db_connection["_id"]}, db_connection + ) diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 3fd25377..43973baf 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -57,7 +57,7 @@ class DatabaseConnection(BaseModel): use_ssh: bool = False uri: str | None path_to_credentials_file: str | None - llm_credentials: LLMCredentials | None = None + llm_api_key: str | None = None ssh_settings: SSHSettings | None = None @validator("uri", pre=True, always=True) diff --git a/dataherald/types.py b/dataherald/types.py index 93a6cf5a..c85ef9ad 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -5,7 +5,7 @@ from bson.objectid import ObjectId from pydantic import BaseModel, Field, validator -from dataherald.sql_database.models.types import LLMCredentials, SSHSettings +from dataherald.sql_database.models.types import SSHSettings class DBConnectionValidation(BaseModel): @@ -115,7 +115,7 @@ class DatabaseConnectionRequest(BaseModel): use_ssh: bool = False connection_uri: str | None path_to_credentials_file: str | None - llm_credentials: LLMCredentials | None + llm_api_key: str | None ssh_settings: SSHSettings | None diff --git a/docs/api.create_database_connection.rst b/docs/api.create_database_connection.rst index dec0ebec..1d2143bf 100644 --- a/docs/api.create_database_connection.rst +++ b/docs/api.create_database_connection.rst @@ -24,10 +24,7 @@ You can find additional details on how to connect to each of the supported data "use_ssh": false, "connection_uri": "string", "path_to_credentials_file": "string", - "llm_credentials": { - "organization_id": "string", - "api_key": "string" - }, + "llm_api_key": "string", "ssh_settings": { "db_name": "string", "host": "string", @@ -70,10 +67,7 @@ HTTP 201 code response "use_ssh": false, "uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", "path_to_credentials_file": "string", - "llm_credentials": { - "organization_id": "gAAAAABlCz5TvOWQQ9TeSKgtCbaisl343oG3SaBlSniTsqs9R8aTIrptvzQq7b2a13ocBPuV6kGw17bximFbqAF_yaHmJF-Psw==", - "api_key": "gAAAAABlCz5TeU0ym4hW3bf9u21dz7B9tlnttOGLRDt8gq2ykkblNvpp70ZjT9FeFcoyMv-Csvp3GNQfw66eYvQBrcBEPsLokkLO2Jc2DD-Q8Aw6g_8UahdOTxJdT4izA6MsiQrf7GGmYBGZqbqsjTdNmcq661wF9Q==" - }, + "llm_api_key": "gAAAAABlCz5TeU0ym4hW3bf9u21dz7B9tlnttOGLRDt8gq2ykkblNvpp70ZjT9FeFcoyMv-Csvp3GNQfw66eYvQBrcBEPsLokkLO2Jc2DD-Q8Aw6g_8UahdOTxJdT4izA6MsiQrf7GGmYBGZqbqsjTdNmcq661wF9Q==", "ssh_settings": { "db_name": "string", "host": "string", @@ -150,10 +144,7 @@ With a SSH connection and LLM credentials -d '{ "alias": "my_db_alias", "use_ssh": true, - "llm_credentials": { - "organization_id": "organization_id", - "api_key": "api_key" - }, + "llm_api_key": "api_key", "ssh_settings": { "db_name": "db_name", "host": "string", diff --git a/docs/api.list_database_connections.rst b/docs/api.list_database_connections.rst index c35a5572..7e01abc9 100644 --- a/docs/api.list_database_connections.rst +++ b/docs/api.list_database_connections.rst @@ -21,7 +21,7 @@ HTTP 200 code response "use_ssh": false, "uri": "foooAABk91Q4wjoR2h07GR7_72BdQnxi8Rm6i_EjyS-mzz_o2c3RAWaEqnlUvkK5eGD5kUfE5xheyivl1Wfbk_EM7CgV4SvdLmOOt7FJV-3kG4zAbar=", "path_to_credentials_file": null, - "llm_credentials": null, + "llm_api_key": null, "ssh_settings": null }, { @@ -30,10 +30,7 @@ HTTP 200 code response "use_ssh": true, "uri": null, "path_to_credentials_file": null, - "llm_credentials": { - "organization_id": "gAAAAABlCz5TvOWQQ9TeSKgtCbaisl343oG3SaBlSniTsqs9R8aTIrptvzQq7b2a13ocBPuV6kGw17bximFbqAF_yaHmJF-Psw==", - "api_key": "gAAAAABlCz5TeU0ym4hW3bf9u21dz7B9tlnttOGLRDt8gq2ykkblNvpp70ZjT9FeFcoyMv-Csvp3GNQfw66eYvQBrcBEPsLokkLO2Jc2DD-Q8Aw6g_8UahdOTxJdT4izA6MsiQrf7GGmYBGZqbqsjTdNmcq661wF9Q==" - }, + "llm_api_key": "gAAAAABlCz5TeU0ym4hW3bf9u21dz7B9tlnttOGLRDt8gq2ykkblNvpp70ZjT9FeFcoyMv-Csvp3GNQfw66eYvQBrcBEPsLokkLO2Jc2DD-Q8Aw6g_8UahdOTxJdT4izA6MsiQrf7GGmYBGZqbqsjTdNmcq661wF9Q==", "ssh_settings": { "db_name": "string", "host": "string", diff --git a/docs/api.rst b/docs/api.rst index 23c30b04..554fc2f9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -30,10 +30,7 @@ Related endpoints are: "use_ssh": false, "connection_uri": "string", "path_to_credentials_file": "string", - "llm_credentials": { - "organization_id": "string", - "api_key": "string" - }, + "llm_api_key": "string", "ssh_settings": { "db_name": "string", "host": "string", diff --git a/docs/api.update_database_connection.rst b/docs/api.update_database_connection.rst index 506efa52..d640db0a 100644 --- a/docs/api.update_database_connection.rst +++ b/docs/api.update_database_connection.rst @@ -24,10 +24,7 @@ This endpoint is used to update a Database connection "use_ssh": true, "connection_uri": "string", "path_to_credentials_file": "string", - "llm_credentials": { - "organization_id": "string", - "api_key": "string" - }, + "llm_api_key": "string", "ssh_settings": { "db_name": "string", "host": "string", @@ -53,10 +50,7 @@ HTTP 200 code response "use_ssh": false, "uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", "path_to_credentials_file": "string", - "llm_credentials": { - "organization_id": "string", - "api_key": "string" - }, + "llm_api_key": "string", "ssh_settings": { "db_name": "string", "host": "string", @@ -133,10 +127,7 @@ With a SSH connection and LLM credentials -d '{ "alias": "my_db_alias", "use_ssh": true, - "llm_credentials": { - "organization_id": "organization_id", - "api_key": "api_key" - }, + "llm_api_key": "api_key", "ssh_settings": { "db_name": "db_name", "host": "string",