From 0246975e94a7db14875bf28a8a9db36fe500c251 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 2 Apr 2024 10:06:38 -0400 Subject: [PATCH] DH5669/store db dialect in database connection collection --- dataherald/sql_database/models/types.py | 34 +++++++++++++++++++++---- docs/api.create_database_connection.rst | 1 + docs/api.list_database_connections.rst | 2 ++ docs/api.update_database_connection.rst | 1 + 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index c4ed5b5c..05ecba5b 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -1,6 +1,7 @@ import os import re from datetime import datetime +from enum import Enum from typing import Any from pydantic import BaseModel, BaseSettings, Extra, Field, validator @@ -75,9 +76,23 @@ class InvalidURIFormatError(Exception): pass +class SupportedDialects(Enum): + POSTGRES = "postgresql" + MYSQL = "mysql" + MSSQL = "mssql" + DATABRICKS = "databricks" + SNOWFLAKE = "snowflake" + CLICKHOUSE = "clickhouse" + AWSATHENA = "awsathena" + DUCKDB = "duckdb" + BIGQUERY = "bigquery" + SQLITE = "sqlite" + + class DatabaseConnection(BaseModel): id: str | None alias: str + dialect: SupportedDialects | None use_ssh: bool = False connection_uri: str | None path_to_credentials_file: str | None @@ -93,16 +108,25 @@ def validate_uri(cls, input_string): match = re.match(pattern, input_string) if not match: raise InvalidURIFormatError(f"Invalid URI format: {input_string}") + return match.group(1) + + @classmethod + def set_dialect(cls, input_string): + for dialect in SupportedDialects: + if dialect.value in input_string: + return dialect.value + return None @validator("connection_uri", pre=True, always=True) - def connection_uri_format(cls, value: str): + def connection_uri_format(cls, value: str, values): fernet_encrypt = FernetEncrypt() try: - fernet_encrypt.decrypt(value) - return value + 
decrypted_value = fernet_encrypt.decrypt(value) + dialect_prefix = cls.validate_uri(decrypted_value) except Exception: - cls.validate_uri(value) - return fernet_encrypt.encrypt(value) + dialect_prefix = cls.validate_uri(value) + value = fernet_encrypt.encrypt(value) + values["dialect"] = cls.set_dialect(dialect_prefix) return value @validator("llm_api_key", pre=True, always=True) diff --git a/docs/api.create_database_connection.rst b/docs/api.create_database_connection.rst index 79f06a92..e7ca37be 100644 --- a/docs/api.create_database_connection.rst +++ b/docs/api.create_database_connection.rst @@ -85,6 +85,7 @@ HTTP 201 code response { "id": "64f251ce9614e0e94b0520bc", "alias": "string_999", + "dialect": "postgresql", "use_ssh": true, "connection_uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", "path_to_credentials_file": "string", diff --git a/docs/api.list_database_connections.rst b/docs/api.list_database_connections.rst index 02c58823..1396ee23 100644 --- a/docs/api.list_database_connections.rst +++ b/docs/api.list_database_connections.rst @@ -18,6 +18,7 @@ HTTP 200 code response { "id": "64dfa0e103f5134086f7090c", "alias": "databricks", + "dialect": "databricks", "use_ssh": false, "connection_uri": "foooAABk91Q4wjoR2h07GR7_72BdQnxi8Rm6i_EjyS-mzz_o2c3RAWaEqnlUvkK5eGD5kUfE5xheyivl1Wfbk_EM7CgV4SvdLmOOt7FJV-3kG4zAbar=", "path_to_credentials_file": null, @@ -27,6 +28,7 @@ HTTP 200 code response { "id": "64e52c5f7d6dc4bc510d6d28", "alias": "postgres", + "dialect": "postgresql", "use_ssh": true, "connection_uri": null, "path_to_credentials_file": "bar-LWxPdFcjQw9lU7CeK_2ELR3jGBq0G_uQ7E2rfPLk2RcFR4aDO9e2HmeAQtVpdvtrsQ_0zjsy9q7asdsadXExYJ0g==", diff --git a/docs/api.update_database_connection.rst b/docs/api.update_database_connection.rst index 05fe2e32..f3aaa1bd 100644 --- a/docs/api.update_database_connection.rst +++ b/docs/api.update_database_connection.rst @@ -42,6 +42,7 @@ HTTP 200 code response { "id": 
"64f251ce9614e0e94b0520bc", "alias": "string_999", + "dialect": "sqlite", "use_ssh": false, "connection_uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", "path_to_credentials_file": "string",