diff --git a/README.md b/README.md
index fcfdd1aa..c2a4b768 100644
--- a/README.md
+++ b/README.md
@@ -178,6 +178,24 @@ curl -X 'POST' \
 }'
 ```
 
+##### Connecting multiple schemas
+You can connect multiple schemas through a single database connection if you want to create SQL joins across schemas.
+Currently only `BigQuery`, `Snowflake`, `Databricks` and `Postgres` support this feature.
+To use multiple schemas, set them in the `schemas` param instead of sending the `schema` in the `connection_uri`, like this:
+
+```
+curl -X 'POST' \
+  '/api/v1/database-connections' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "alias": "my_db_alias",
+  "use_ssh": false,
+  "connection_uri": "snowflake://:@-/",
+  "schemas": ["schema_1", "schema_2", ...]
+}'
+```
+
 ##### Connecting to supported Data warehouses and using SSH
 You can find the details on how to connect to the supported data warehouses in the [docs](https://dataherald.readthedocs.io/en/latest/api.create_database_connection.html)
 
@@ -194,7 +212,8 @@ While only the Database scan part is required to start generating SQL, adding ve
 #### Scanning the Database
 The database scan is used to gather information about the database including table and column names and identifying low cardinality columns and their values to be stored in the context store and used in the prompts to the LLM. In addition, it retrieves logs, which consist of historical queries associated with each database table. These records are then stored within the query_history collection. The historical queries retrieved encompass data from the past three months and are grouped based on query and user.
-db_connection_id is the id of the database connection you want to scan, which is returned when you create a database connection.
+The `db_connection_id` param is the id of the database connection you want to scan, which is returned when you create a database connection.
+The `ids` param is the list of table description ids that you want to scan.
 You can trigger a scan of a database from the `POST /api/v1/table-descriptions/sync-schemas` endpoint. Example below
 
@@ -205,11 +224,11 @@ curl -X 'POST' \
   '/api/v1/table-descriptions/sync-schemas' \
   -H 'accept: application/json' \
   -H 'Content-Type: application/json' \
   -d '{
   "db_connection_id": "db_connection_id",
-  "table_names": ["table_name"]
+  "ids": ["", "", ...]
 }'
 ```
 
-Since the endpoint identifies low cardinality columns (and their values) it can take time to complete. Therefore while it is possible to trigger a scan on the entire DB by not specifying the `table_names`, we recommend against it for large databases.
+Since the endpoint identifies low cardinality columns (and their values), it can take time to complete.
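A hedged Python sketch of this flow end to end, assuming a server at `http://localhost`, the `requests` library, and that the list-table-descriptions endpoint accepts a `db_connection_id` query param and returns objects with `id` fields (assumptions based on the docs in this PR, not verified API details):

```python
# Hedged sketch: create a multi-schema connection, then scan by
# table-description ids. Host, placeholder URI, and the list-endpoint
# query param are assumptions.
import requests

BASE = "http://localhost/api/v1"  # assumed local deployment

# 1. Create the connection; the schemas go in `schemas`, not in the URI.
conn = requests.post(
    f"{BASE}/database-connections",
    json={
        "alias": "my_db_alias",
        "use_ssh": False,
        "connection_uri": "snowflake://<user>:<password>@<org>-<account>/<database>",
        "schemas": ["schema_1", "schema_2"],
    },
).json()

# 2. Creating a connection stores every table/view as NOT_SCANNED;
#    list the table descriptions to obtain their ids.
tables = requests.get(
    f"{BASE}/table-descriptions",
    params={"db_connection_id": conn["id"]},
).json()

# 3. Trigger the (background) scan for the table descriptions you care about.
requests.post(
    f"{BASE}/table-descriptions/sync-schemas",
    json={"db_connection_id": conn["id"], "ids": [t["id"] for t in tables]},
)
```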
#### Get logs per db connection
Once a database has been scanned, you can use this endpoint to retrieve the table logs.

diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py
index d854da05..23e3976c 100644
--- a/dataherald/api/fastapi.py
+++ b/dataherald/api/fastapi.py
@@ -70,6 +70,9 @@
     SQLInjectionError,
 )
 from dataherald.sql_database.models.types import DatabaseConnection
+from dataherald.sql_database.services.database_connection import (
+    DatabaseConnectionService,
+)
 from dataherald.types import (
     BaseLLM,
     CancelFineTuningRequest,
@@ -88,17 +91,20 @@
 )
 from dataherald.utils.encrypt import FernetEncrypt
 from dataherald.utils.error_codes import error_response, stream_error_response
+from dataherald.utils.sql_utils import (
+    filter_golden_records_based_on_schema,
+    validate_finetuning_schema,
+)
 
 logger = logging.getLogger(__name__)
 
 MAX_ROWS_TO_CREATE_CSV_FILE = 50
 
 
-def async_scanning(scanner, database, scanner_request, storage):
+def async_scanning(scanner, database, table_descriptions, storage):
     scanner.scan(
         database,
-        scanner_request.db_connection_id,
-        scanner_request.table_names,
+        table_descriptions,
         TableDescriptionRepository(storage),
         QueryHistoryRepository(storage),
     )
@@ -130,42 +136,42 @@ def scan_db(
         self, scanner_request: ScannerRequest, background_tasks: BackgroundTasks
     ) -> list[TableDescriptionResponse]:
         """Takes a db_connection_id and scan all the tables columns"""
-        try:
-            db_connection_repository = DatabaseConnectionRepository(self.storage)
-
-            db_connection = db_connection_repository.find_by_id(
-                scanner_request.db_connection_id
-            )
+        scanner_repository = TableDescriptionRepository(self.storage)
+        data = {}
+        for id in scanner_request.ids:
+            table_description = scanner_repository.find_by_id(id)
+            if not table_description:
+                raise Exception("Table description not found")
+            if table_description.db_connection_id not in data.keys():
+                data[table_description.db_connection_id] = {}
+            if (
+                table_description.schema_name
+                not in data[table_description.db_connection_id].keys()
+            ):
+                data[table_description.db_connection_id][
+                    table_description.schema_name
+                ] = []
+            data[table_description.db_connection_id][
+                table_description.schema_name
+            ].append(table_description)
 
-            if not db_connection:
-                raise DatabaseConnectionNotFoundError(
-                    f"Database connection {scanner_request.db_connection_id} not found"
+        db_connection_repository = DatabaseConnectionRepository(self.storage)
+        scanner = self.system.instance(Scanner)
+        rows = scanner.synchronizing(
+            scanner_request,
+            TableDescriptionRepository(self.storage),
+        )
+        database_connection_service = DatabaseConnectionService(scanner, self.storage)
+        for db_connection_id, schemas_and_table_descriptions in data.items():
+            for schema, table_descriptions in schemas_and_table_descriptions.items():
+                db_connection = db_connection_repository.find_by_id(db_connection_id)
+                database = database_connection_service.get_sql_database(
+                    db_connection, schema
                 )
-            database = SQLDatabase.get_sql_engine(db_connection, True)
-            all_tables = database.get_tables_and_views()
-
-            if scanner_request.table_names:
-                for table in scanner_request.table_names:
-                    if table not in all_tables:
-                        raise HTTPException(
-                            status_code=404,
-                            detail=f"Table named: {table} doesn't exist",
-                        )  # noqa: B904
-            else:
-                scanner_request.table_names = all_tables
-
-            scanner = self.system.instance(Scanner)
-            rows = scanner.synchronizing(
-                scanner_request,
-                TableDescriptionRepository(self.storage),
-            )
-        except Exception as e:
-            return error_response(e, 
scanner_request.dict(), "invalid_database_sync") - - background_tasks.add_task( - async_scanning, scanner, database, scanner_request, self.storage - ) + background_tasks.add_task( + async_scanning, scanner, database, table_descriptions, self.storage + ) return [TableDescriptionResponse(**row.dict()) for row in rows] @override @@ -173,27 +179,9 @@ def create_database_connection( self, database_connection_request: DatabaseConnectionRequest ) -> DatabaseConnectionResponse: try: - db_connection = DatabaseConnection( - alias=database_connection_request.alias, - connection_uri=database_connection_request.connection_uri.strip(), - path_to_credentials_file=database_connection_request.path_to_credentials_file, - llm_api_key=database_connection_request.llm_api_key, - use_ssh=database_connection_request.use_ssh, - ssh_settings=database_connection_request.ssh_settings, - file_storage=database_connection_request.file_storage, - metadata=database_connection_request.metadata, - ) - sql_database = SQLDatabase.get_sql_engine(db_connection, True) - - # Get tables and views and create table-descriptions as NOT_SCANNED - db_connection_repository = DatabaseConnectionRepository(self.storage) - - scanner_repository = TableDescriptionRepository(self.storage) scanner = self.system.instance(Scanner) - - tables = sql_database.get_tables_and_views() - db_connection = db_connection_repository.insert(db_connection) - scanner.create_tables(tables, str(db_connection.id), scanner_repository) + db_connection_service = DatabaseConnectionService(scanner, self.storage) + db_connection = db_connection_service.create(database_connection_request) except Exception as e: # Encrypt sensible values fernet_encrypt = FernetEncrypt() @@ -209,7 +197,6 @@ def create_database_connection( return error_response( e, database_connection_request.dict(), "invalid_database_connection" ) - return DatabaseConnectionResponse(**db_connection.dict()) @override @@ -220,18 +207,30 @@ def refresh_table_description( db_connection = db_connection_repository.find_by_id( refresh_table_description.db_connection_id ) - + scanner = self.system.instance(Scanner) + database_connection_service = DatabaseConnectionService(scanner, self.storage) try: - sql_database = SQLDatabase.get_sql_engine(db_connection, True) - tables = sql_database.get_tables_and_views() + data = {} + if db_connection.schemas: + for schema in db_connection.schemas: + sql_database = database_connection_service.get_sql_database( + db_connection, schema + ) + if schema not in data.keys(): + data[schema] = [] + data[schema] = sql_database.get_tables_and_views() + else: + sql_database = database_connection_service.get_sql_database( + db_connection + ) + data[None] = sql_database.get_tables_and_views() - # Get tables and views and create missing table-descriptions as NOT_SCANNED and update DEPRECATED scanner_repository = TableDescriptionRepository(self.storage) - scanner = self.system.instance(Scanner) + return [ TableDescriptionResponse(**record.dict()) for record in scanner.refresh_tables( - tables, str(db_connection.id), scanner_repository + data, str(db_connection.id), scanner_repository ) ] except Exception as e: @@ -569,7 +568,6 @@ def create_finetuning_job( ) -> Finetuning: try: db_connection_repository = DatabaseConnectionRepository(self.storage) - db_connection = db_connection_repository.find_by_id( fine_tuning_request.db_connection_id ) @@ -577,7 +575,7 @@ def create_finetuning_job( raise DatabaseConnectionNotFoundError( f"Database connection not found, 
{fine_tuning_request.db_connection_id}" ) - + validate_finetuning_schema(fine_tuning_request, db_connection) golden_sqls_repository = GoldenSQLRepository(self.storage) golden_sqls = [] if fine_tuning_request.golden_sqls: @@ -598,6 +596,9 @@ def create_finetuning_job( raise GoldenSQLNotFoundError( f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}" ) + golden_sqls = filter_golden_records_based_on_schema( + golden_sqls, fine_tuning_request.schemas + ) default_base_llm = BaseLLM( model_provider="openai", model_name="gpt-3.5-turbo-1106", @@ -611,6 +612,7 @@ def create_finetuning_job( model = model_repository.insert( Finetuning( db_connection_id=fine_tuning_request.db_connection_id, + schemas=fine_tuning_request.schemas, alias=fine_tuning_request.alias if fine_tuning_request.alias else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}", diff --git a/dataherald/api/types/requests.py b/dataherald/api/types/requests.py index fcb90348..108c775e 100644 --- a/dataherald/api/types/requests.py +++ b/dataherald/api/types/requests.py @@ -6,6 +6,7 @@ class PromptRequest(BaseModel): text: str db_connection_id: str + schemas: list[str] | None metadata: dict | None diff --git a/dataherald/api/types/responses.py b/dataherald/api/types/responses.py index 7edd7c28..ebacf601 100644 --- a/dataherald/api/types/responses.py +++ b/dataherald/api/types/responses.py @@ -25,6 +25,7 @@ def created_at_as_string(cls, v): class PromptResponse(BaseResponse): text: str db_connection_id: str + schemas: list[str] | None class SQLGenerationResponse(BaseResponse): diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index df0f40bd..34a1d8e8 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -13,6 +13,7 @@ from dataherald.repositories.golden_sqls import GoldenSQLRepository from dataherald.repositories.instructions import InstructionRepository from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt +from dataherald.utils.sql_utils import extract_the_schemas_from_sql logger = logging.getLogger(__name__) @@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL f"Database connection not found, {record.db_connection_id}" ) + if db_connection.schemas: + schema_not_found = True + used_schemas = extract_the_schemas_from_sql(record.sql) + for schema in db_connection.schemas: + if schema in used_schemas: + schema_not_found = False + break + if schema_not_found: + raise MalformedGoldenSQLError( + f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}" + ) + prompt_text = record.prompt_text golden_sql = GoldenSQL( prompt_text=prompt_text, diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index 67ea3bff..cd62dddd 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -14,8 +14,7 @@ class Scanner(Component, ABC): def scan( self, db_engine: SQLDatabase, - db_connection_id: str, - table_names: list[str] | None, + table_descriptions: list[TableDescription], repository: TableDescriptionRepository, query_history_repository: QueryHistoryRepository, ) -> None: @@ -34,6 +33,7 @@ def create_tables( self, tables: list[str], db_connection_id: str, + schema: str, repository: TableDescriptionRepository, metadata: dict = None, ) -> None: @@ -42,7 +42,7 @@ def create_tables( @abstractmethod def refresh_tables( self, - tables: list[str], + schemas_and_tables: dict[str, list], 
db_connection_id: str, repository: TableDescriptionRepository, metadata: dict = None, diff --git a/dataherald/db_scanner/models/types.py b/dataherald/db_scanner/models/types.py index aa737495..c6b33810 100644 --- a/dataherald/db_scanner/models/types.py +++ b/dataherald/db_scanner/models/types.py @@ -31,6 +31,7 @@ class TableDescriptionStatus(Enum): class TableDescription(BaseModel): id: str | None db_connection_id: str + schema_name: str | None table_name: str description: str | None table_schema: str | None diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index 7e6cf3da..870263cb 100644 --- a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -53,13 +53,18 @@ def save_table_info(self, table_info: TableDescription) -> TableDescription: table_info_dict = { k: v for k, v in table_info_dict.items() if v is not None and v != [] } + + query = { + "db_connection_id": table_info_dict["db_connection_id"], + "table_name": table_info_dict["table_name"], + } + if "schema_name" in table_info_dict: + query["schema_name"] = table_info_dict["schema_name"] + table_info.id = str( self.storage.update_or_create( DB_COLLECTION, - { - "db_connection_id": table_info_dict["db_connection_id"], - "table_name": table_info_dict["table_name"], - }, + query, table_info_dict, ) ) diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index e1897083..2d3da7ca 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -44,6 +44,7 @@ def create_tables( self, tables: list[str], db_connection_id: str, + schema: str, repository: TableDescriptionRepository, metadata: dict = None, ) -> None: @@ -51,6 +52,7 @@ def create_tables( repository.save_table_info( TableDescription( db_connection_id=db_connection_id, + schema_name=schema, table_name=table, status=TableDescriptionStatus.NOT_SCANNED.value, metadata=metadata, @@ -60,34 +62,38 @@ def create_tables( @override def refresh_tables( self, - tables: list[str], + schemas_and_tables: dict[str, list], db_connection_id: str, repository: TableDescriptionRepository, metadata: dict = None, ) -> list[TableDescription]: - stored_tables = repository.find_by({"db_connection_id": str(db_connection_id)}) - stored_tables_list = [table.table_name for table in stored_tables] - rows = [] - for table_description in stored_tables: - if table_description.table_name not in tables: - table_description.status = TableDescriptionStatus.DEPRECATED.value - rows.append(repository.save_table_info(table_description)) - else: - rows.append(TableDescription(**table_description.dict())) + for schema, tables in schemas_and_tables.items(): + stored_tables = repository.find_by( + {"db_connection_id": str(db_connection_id), "schema_name": schema} + ) + stored_tables_list = [table.table_name for table in stored_tables] - for table in tables: - if table not in stored_tables_list: - rows.append( - repository.save_table_info( - TableDescription( - db_connection_id=db_connection_id, - table_name=table, - status=TableDescriptionStatus.NOT_SCANNED.value, - metadata=metadata, + for table_description in stored_tables: + if table_description.table_name not in tables: + table_description.status = TableDescriptionStatus.DEPRECATED.value + rows.append(repository.save_table_info(table_description)) + else: + rows.append(TableDescription(**table_description.dict())) + + for table in tables: + if table not in stored_tables_list: + rows.append( + repository.save_table_info( + 
TableDescription(
+                                db_connection_id=db_connection_id,
+                                table_name=table,
+                                status=TableDescriptionStatus.NOT_SCANNED.value,
+                                metadata=metadata,
+                                schema_name=schema,
+                            )
                         )
                     )
-            )
         return rows
 
     @override
@@ -96,16 +102,17 @@ def synchronizing(
         scanner_request: ScannerRequest,
         repository: TableDescriptionRepository,
     ) -> list[TableDescription]:
-        # persist tables to be scanned
         rows = []
-        for table in scanner_request.table_names:
+        for id in scanner_request.ids:
+            table_description = repository.find_by_id(id)
             rows.append(
                 repository.save_table_info(
                     TableDescription(
-                        db_connection_id=scanner_request.db_connection_id,
-                        table_name=table,
+                        db_connection_id=table_description.db_connection_id,
+                        table_name=table_description.table_name,
                         status=TableDescriptionStatus.SYNCHRONIZING.value,
                         metadata=scanner_request.metadata,
+                        schema_name=table_description.schema_name,
                     )
                 )
             )
@@ -235,6 +242,7 @@
         db_connection_id: str,
         repository: TableDescriptionRepository,
         scanner_service: AbstractScanner,
+        schema: str | None = None,
     ) -> TableDescription:
         print(f"Scanning table: {table}")
         inspector = inspect(db_engine.engine)
@@ -267,6 +275,7 @@
             last_schema_sync=datetime.now(),
             error_message="",
             status=TableDescriptionStatus.SCANNED.value,
+            schema_name=schema,
        )
 
         repository.save_table_info(object)
@@ -276,8 +285,7 @@
     def scan(
         self,
         db_engine: SQLDatabase,
-        db_connection_id: str,
-        table_names: list[str] | None,
+        table_descriptions: list[TableDescription],
         repository: TableDescriptionRepository,
         query_history_repository: QueryHistoryRepository,
     ) -> None:
@@ -293,41 +301,34 @@
         if db_engine.engine.dialect.name in services.keys():
             scanner_service = services[db_engine.engine.dialect.name]()
 
-        inspector = inspect(db_engine.engine)
         meta = MetaData(bind=db_engine.engine)
         MetaData.reflect(meta, views=True)
-        tables = inspector.get_table_names() + inspector.get_view_names()
-        if table_names:
-            table_names = [table.lower() for table in table_names]
-            tables = [
-                table for table in tables if table and table.lower() in table_names
-            ]
-        if len(tables) == 0:
-            raise ValueError("No table found")
-        for table in tables:
+        for table in table_descriptions:
             try:
                 self.scan_single_table(
                     meta=meta,
-                    table=table,
+                    table=table.table_name,
                     db_engine=db_engine,
-                    db_connection_id=db_connection_id,
+                    db_connection_id=table.db_connection_id,
                     repository=repository,
                     scanner_service=scanner_service,
+                    schema=table.schema_name,
                 )
             except Exception as e:
                 repository.save_table_info(
                     TableDescription(
-                        db_connection_id=db_connection_id,
-                        table_name=table,
+                        db_connection_id=table.db_connection_id,
+                        table_name=table.table_name,
                         status=TableDescriptionStatus.FAILED.value,
                         error_message=f"{e}",
+                        schema_name=table.schema_name,
                     )
                 )
             try:
-                logger.info(f"Get logs table: {table}")
+                logger.info(f"Get logs table: {table.table_name}")
                 query_history = scanner_service.get_logs(
-                    table, db_engine, db_connection_id
+                    table.table_name, db_engine, table.db_connection_id
                 )
                 if len(query_history) > 0:
                     for query in query_history:
diff --git a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py
index 5c553579..95b0f10f 100644
--- a/dataherald/finetuning/openai_finetuning.py
+++ b/dataherald/finetuning/openai_finetuning.py
@@ -69,6 +69,12 @@ def map_finetuning_status(status: str) -> str:
             return FineTuningStatus.QUEUED.value
         return mapped_statuses[status]
 
+    @staticmethod
+    def _filter_tables_by_schema(db_scan: List[TableDescription], schemas: List[str]):
+        if schemas:
+            return [table for table in db_scan 
if table.schema_name in schemas] + return db_scan + def format_columns( self, table: TableDescription, top_k: int = CATEGORICAL_COLUMNS_THRESHOLD ) -> str: @@ -197,6 +203,7 @@ def create_fintuning_dataset(self): "status": TableDescriptionStatus.SCANNED.value, } ) + db_scan = self._filter_tables_by_schema(db_scan, self.fine_tuning_model.schemas) golden_sqls_repository = GoldenSQLRepository(self.storage) finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl" model_repository = FinetuningsRepository(self.storage) diff --git a/dataherald/services/prompts.py b/dataherald/services/prompts.py index f42c272b..b7392b4a 100644 --- a/dataherald/services/prompts.py +++ b/dataherald/services/prompts.py @@ -4,6 +4,7 @@ DatabaseConnectionRepository, ) from dataherald.repositories.prompts import PromptNotFoundError, PromptRepository +from dataherald.sql_database.services.database_connection import SchemaNotSupportedError from dataherald.types import Prompt @@ -22,9 +23,16 @@ def create(self, prompt_request: PromptRequest) -> Prompt: f"Database connection {prompt_request.db_connection_id} not found" ) + if not db_connection.schemas and prompt_request.schemas: + raise SchemaNotSupportedError( + "Schema not supported for this db", + description=f"The {db_connection.dialect} dialect doesn't support schemas", + ) + prompt = Prompt( text=prompt_request.text, db_connection_id=prompt_request.db_connection_id, + schemas=prompt_request.schemas, metadata=prompt_request.metadata, ) return self.prompt_repository.insert(prompt) diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py index 2a0cc0cc..dc969fa6 100644 --- a/dataherald/sql_database/models/types.py +++ b/dataherald/sql_database/models/types.py @@ -96,6 +96,7 @@ class DatabaseConnection(BaseModel): dialect: SupportedDialects | None use_ssh: bool = False connection_uri: str | None + schemas: list[str] | None path_to_credentials_file: str | None llm_api_key: str | None = None ssh_settings: SSHSettings | None = None @@ -105,7 +106,7 @@ class DatabaseConnection(BaseModel): @classmethod def get_dialect(cls, input_string): - pattern = r"([^:/]+):/+([^/]+)/?([^/]+)" + pattern = r"([^:/]+)://" match = re.match(pattern, input_string) if not match: raise InvalidURIFormatError(f"Invalid URI format: {input_string}") diff --git a/dataherald/sql_database/services/__init__.py b/dataherald/sql_database/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dataherald/sql_database/services/database_connection.py b/dataherald/sql_database/services/database_connection.py new file mode 100644 index 00000000..0d7bf052 --- /dev/null +++ b/dataherald/sql_database/services/database_connection.py @@ -0,0 +1,155 @@ +import re + +from sqlalchemy import inspect + +from dataherald.db import DB +from dataherald.db_scanner import Scanner +from dataherald.db_scanner.repository.base import TableDescriptionRepository +from dataherald.repositories.database_connections import DatabaseConnectionRepository +from dataherald.sql_database.base import SQLDatabase +from dataherald.sql_database.models.types import DatabaseConnection +from dataherald.types import DatabaseConnectionRequest +from dataherald.utils.encrypt import FernetEncrypt +from dataherald.utils.error_codes import CustomError + + +class SchemaNotSupportedError(CustomError): + pass + + +class DatabaseConnectionService: + def __init__(self, scanner: Scanner, storage: DB): + self.scanner = scanner + self.storage = storage + + def get_sql_database( + self, 
database_connection: DatabaseConnection, schema: str = None + ) -> SQLDatabase: + fernet_encrypt = FernetEncrypt() + if schema: + database_connection.connection_uri = fernet_encrypt.encrypt( + self.add_schema_in_uri( + fernet_encrypt.decrypt(database_connection.connection_uri), + schema, + database_connection.dialect.value, + ) + ) + return SQLDatabase.get_sql_engine(database_connection, True) + + def get_current_schema( + self, database_connection: DatabaseConnection + ) -> list[str] | None: + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + inspector = inspect(sql_database.engine) + if inspector.default_schema_name and database_connection.dialect not in [ + "mssql", + "mysql", + "clickhouse", + "duckdb", + ]: + return [inspector.default_schema_name] + if database_connection.dialect == "bigquery": + pattern = r"([^:/]+)://([^/]+)/([^/]+)(\?[^/]+)" + match = re.match(pattern, str(sql_database.engine.url)) + if match: + return [match.group(3)] + elif database_connection.dialect == "databricks": + pattern = r"&schema=([^&]*)" + match = re.search(pattern, str(sql_database.engine.url)) + if match: + return [match.group(1)] + return None + + def remove_schema_in_uri(self, connection_uri: str, dialect: str) -> str: + if dialect in ["snowflake"]: + pattern = r"([^:/]+)://([^:]+):([^@]+)@([^:/]+)(?::(\d+))?/([^/]+)" + match = re.match(pattern, connection_uri) + if match: + return match.group(0) + if dialect in ["bigquery"]: + pattern = r"([^:/]+)://([^/]+)" + match = re.match(pattern, connection_uri) + if match: + return match.group(0) + elif dialect in ["databricks"]: + pattern = r"&schema=[^&]*" + return re.sub(pattern, "", connection_uri) + elif dialect in ["postgresql"]: + pattern = r"\?options=-csearch_path" r"=[^&]*" + return re.sub(pattern, "", connection_uri) + return connection_uri + + def add_schema_in_uri(self, connection_uri: str, schema: str, dialect: str) -> str: + connection_uri = self.remove_schema_in_uri(connection_uri, dialect) + if dialect in ["snowflake", "bigquery"]: + return f"{connection_uri}/{schema}" + if dialect in ["databricks"]: + return f"{connection_uri}&schema={schema}" + if dialect in ["postgresql"]: + return f"{connection_uri}?options=-csearch_path={schema}" + return connection_uri + + def create( + self, database_connection_request: DatabaseConnectionRequest + ) -> DatabaseConnection: + database_connection = DatabaseConnection( + alias=database_connection_request.alias, + connection_uri=database_connection_request.connection_uri.strip(), + schemas=database_connection_request.schemas, + path_to_credentials_file=database_connection_request.path_to_credentials_file, + llm_api_key=database_connection_request.llm_api_key, + use_ssh=database_connection_request.use_ssh, + ssh_settings=database_connection_request.ssh_settings, + file_storage=database_connection_request.file_storage, + metadata=database_connection_request.metadata, + ) + if database_connection.schemas and database_connection.dialect in [ + "redshift", + "awsathena", + "mssql", + "mysql", + "clickhouse", + "duckdb", + ]: + raise SchemaNotSupportedError( + "Schema not supported for this db", + description=f"The {database_connection.dialect} dialect doesn't support schemas", + ) + if not database_connection.schemas: + database_connection.schemas = self.get_current_schema(database_connection) + + schemas_and_tables = {} + fernet_encrypt = FernetEncrypt() + + if database_connection.schemas: + for schema in database_connection.schemas: + database_connection.connection_uri = 
fernet_encrypt.encrypt( + self.add_schema_in_uri( + database_connection_request.connection_uri.strip(), + schema, + str(database_connection.dialect), + ) + ) + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + schemas_and_tables[schema] = sql_database.get_tables_and_views() + else: + sql_database = SQLDatabase.get_sql_engine(database_connection, True) + schemas_and_tables[None] = sql_database.get_tables_and_views() + + # Connect db + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection.connection_uri = fernet_encrypt.encrypt( + self.remove_schema_in_uri( + database_connection_request.connection_uri.strip(), + str(database_connection.dialect), + ) + ) + db_connection = db_connection_repository.insert(database_connection) + + scanner_repository = TableDescriptionRepository(self.storage) + # Add created tables + for schema, tables in schemas_and_tables.items(): + self.scanner.create_tables( + tables, str(db_connection.id), schema, scanner_repository + ) + return DatabaseConnection(**db_connection.dict()) diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index 16fc94d4..e997920e 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -11,10 +11,10 @@ from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, LLMResult -from langchain.schema.messages import BaseMessage from langchain_community.callbacks import get_openai_callback from dataherald.config import Component, System +from dataherald.db_scanner.models.types import TableDescription from dataherald.model.chat_model import ChatModel from dataherald.repositories.sql_generations import ( SQLGenerationRepository, @@ -62,6 +62,21 @@ def remove_markdown(self, query: str) -> str: return matches[0].strip() return query + @staticmethod + def get_table_schema(table_name: str, db_scan: List[TableDescription]) -> str: + for table in db_scan: + if table.table_name == table_name: + return table.schema_name + return "" + + @staticmethod + def filter_tables_by_schema( + db_scan: List[TableDescription], prompt: Prompt + ) -> List[TableDescription]: + if prompt.schemas: + return [table for table in db_scan if table.schema_name in prompt.schemas] + return db_scan + def format_sql_query_intermediate_steps(self, step: str) -> str: pattern = r"```sql(.*?)```" diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index 5fe0ed4d..5e2b423c 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -190,12 +190,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s tables = Parser(example["sql"]).tables except Exception as e: logger.error(f"Error parsing SQL: {str(e)}") - most_similar_tables.update(tables) - df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + for table in tables: + found_tables = df[df.table_name == table] + for _, row in found_tables.iterrows(): + most_similar_tables.add((row["schema_name"], row["table_name"])) + df.drop( + df[ + df.table_name.isin([table[1] for table in most_similar_tables]) + ].index, + inplace=True, + ) return most_similar_tables @catch_exceptions() - def _run( + def _run( # noqa: PLR0912 self, user_question: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -214,9 +222,12 @@ 
def _run( table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}" else: table_rep = f"Table {table.table_name} contain columns: [{col_rep}]" - table_representations.append([table.table_name, table_rep]) + table_representations.append( + [table.schema_name, table.table_name, table_rep] + ) df = pd.DataFrame( - table_representations, columns=["table_name", "table_representation"] + table_representations, + columns=["schema_name", "table_name", "table_representation"], ) df["table_embedding"] = self.get_docs_embedding(df.table_representation) df["similarities"] = df.table_embedding.apply( @@ -227,12 +238,20 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' + if row["schema_name"] is not None: + table_name = row["schema_name"] + "." + row["table_name"] + else: + table_name = row["table_name"] + table_relevance += ( + f'Table: `{table_name}`, relevance score: {row["similarities"]}\n' + ) if len(most_similar_tables) > 0: for table in most_similar_tables: - table_relevance += ( - f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" - ) + if table[0] is not None: + table_name = table[0] + "." + table[1] + else: + table_name = table[1] + table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n" return table_relevance async def _arun( @@ -358,27 +377,32 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( + def _run( # noqa: C901 self, table_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Get the schema for tables in a comma-separated list.""" table_names_list = table_names.split(", ") - table_names_list = [ - replace_unprocessable_characters(table_name) - for table_name in table_names_list - ] + processed_table_names = [] + for table in table_names_list: + formatted_table = replace_unprocessable_characters(table) + if "." 
in formatted_table: + processed_table_names.append(formatted_table.split(".")[1]) + else: + processed_table_names.append(formatted_table) tables_schema = "" for table in self.db_scan: - if table.table_name in table_names_list: + if table.table_name in processed_table_names: tables_schema += "```sql\n" tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: - descriptions.append( - f"Table `{table.table_name}`: {table.description}\n" - ) + if table.schema_name: + table_name = f"{table.schema_name}.{table.table_name}" + else: + table_name = table.table_name + descriptions.append(f"Table `{table_name}`: {table.description}\n") for column in table.columns: if column.description is not None: descriptions.append( @@ -555,6 +579,9 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = SQLGenerator.filter_tables_by_schema( + db_scan=db_scan, prompt=user_prompt + ) few_shot_examples, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=5 ) @@ -658,6 +685,9 @@ def stream_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = SQLGenerator.filter_tables_by_schema( + db_scan=db_scan, prompt=user_prompt + ) _, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=1 ) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 899e6e87..b686bf2f 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -10,7 +10,6 @@ import numpy as np import openai import pandas as pd -import sqlalchemy from google.api_core.exceptions import GoogleAPIError from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit @@ -27,9 +26,7 @@ from overrides import override from pydantic import BaseModel, Field from sql_metadata import Parser -from sqlalchemy import MetaData from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.sql import func from dataherald.context_store import ContextStore from dataherald.db import DB @@ -254,12 +251,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s tables = Parser(example["sql"]).tables except Exception as e: logger.error(f"Error parsing SQL: {str(e)}") - most_similar_tables.update(tables) - df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + for table in tables: + found_tables = df[df.table_name == table] + for _, row in found_tables.iterrows(): + most_similar_tables.add((row["schema_name"], row["table_name"])) + df.drop( + df[ + df.table_name.isin([table[1] for table in most_similar_tables]) + ].index, + inplace=True, + ) return most_similar_tables @catch_exceptions() - def _run( + def _run( # noqa: PLR0912 self, user_question: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -278,9 +283,12 @@ def _run( table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}" else: table_rep = f"Table {table.table_name} contain columns: [{col_rep}]" - table_representations.append([table.table_name, table_rep]) + table_representations.append( + [table.schema_name, table.table_name, table_rep] + ) df = pd.DataFrame( - table_representations, columns=["table_name", "table_representation"] + table_representations, + columns=["schema_name", "table_name", "table_representation"], ) 
df["table_embedding"] = self.get_docs_embedding(df.table_representation) df["similarities"] = df.table_embedding.apply( @@ -291,12 +299,20 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' + if row["schema_name"] is not None: + table_name = row["schema_name"] + "." + row["table_name"] + else: + table_name = row["table_name"] + table_relevance += ( + f'Table: `{table_name}`, relevance score: {row["similarities"]}\n' + ) if len(most_similar_tables) > 0: for table in most_similar_tables: - table_relevance += ( - f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" - ) + if table[0] is not None: + table_name = table[0] + "." + table[1] + else: + table_name = table[1] + table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n" return table_relevance async def _arun( @@ -318,6 +334,8 @@ class ColumnEntityChecker(BaseSQLDatabaseTool, BaseTool): Example Input: table1 -> column2, entity """ + db_scan: List[TableDescription] + is_multiple_schema: bool def find_similar_strings( self, input_list: List[tuple], target_string: str, threshold=0.4 @@ -341,21 +359,25 @@ def _run( try: schema, entity = tool_input.split(",") table_name, column_name = schema.split("->") + table_name = replace_unprocessable_characters(table_name) + column_name = replace_unprocessable_characters(column_name).strip() + if "." not in table_name and self.is_multiple_schema: + raise Exception( + "Table name should be in the format schema_name.table_name" + ) except ValueError: return "Invalid input format, use following format: table_name -> column_name, entity (entity should be a string without ',')" search_pattern = f"%{entity.strip().lower()}%" - meta = MetaData(bind=self.db.engine) - table = sqlalchemy.Table(table_name.strip(), meta, autoload=True) + search_query = f"SELECT DISTINCT {column_name} FROM {table_name} WHERE {column_name} ILIKE :search_pattern" # noqa: S608 try: - search_query = sqlalchemy.select( - [func.distinct(table.c[column_name.strip()])] - ).where(func.lower(table.c[column_name.strip()]).like(search_pattern)) - search_results = self.db.engine.execute(search_query).fetchall() + search_results = self.db.engine.execute( + search_query, {"search_pattern": search_pattern} + ).fetchall() search_results = search_results[:25] except SQLAlchemyError: search_results = [] - distinct_query = sqlalchemy.select( - [func.distinct(table.c[column_name.strip()])] + distinct_query = ( + f"SELECT DISTINCT {column_name} FROM {table_name}" # noqa: S608 ) results = self.db.engine.execute(distinct_query).fetchall() results = self.find_similar_strings(results, entity) @@ -392,26 +414,31 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( + def _run( # noqa: C901 self, table_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Get the schema for tables in a comma-separated list.""" table_names_list = table_names.split(", ") - table_names_list = [ - replace_unprocessable_characters(table_name) - for table_name in table_names_list - ] + processed_table_names = [] + for table in table_names_list: + formatted_table = replace_unprocessable_characters(table) + if "." 
in formatted_table: + processed_table_names.append(formatted_table.split(".")[1]) + else: + processed_table_names.append(formatted_table) tables_schema = "```sql\n" for table in self.db_scan: - if table.table_name in table_names_list: + if table.table_name in processed_table_names: tables_schema += table.table_schema + "\n" descriptions = [] if table.description is not None: - descriptions.append( - f"Table `{table.table_name}`: {table.description}\n" - ) + if table.schema_name: + table_name = f"{table.schema_name}.{table.table_name}" + else: + table_name = table.table_name + descriptions.append(f"Table `{table_name}`: {table.description}\n") for column in table.columns: if column.description is not None: descriptions.append( @@ -446,7 +473,7 @@ class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( # noqa: C901 + def _run( # noqa: C901, PLR0912 self, column_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -457,6 +484,8 @@ def _run( # noqa: C901 for item in items_list: if " -> " in item: table_name, column_name = item.split(" -> ") + if "." in table_name: + table_name = table_name.split(".")[1] table_name = replace_unprocessable_characters(table_name) column_name = replace_unprocessable_characters(column_name) found = False @@ -474,7 +503,11 @@ def _run( # noqa: C901 for row in table.examples: col_info += row[column_name] + ", " col_info = col_info[:-2] - column_full_info += f"Table: {table_name}, column: {column_name}, additional info: {col_info}\n" + if table.schema_name: + schema_table = f"{table.schema_name}.{table.table_name}" + else: + schema_table = table.table_name + column_full_info += f"Table: {schema_table}, column: {column_name}, additional info: {col_info}\n" else: return "Malformed input, input should be in the following format Example Input: table1 -> column1, table1 -> column2, table2 -> column1" # noqa: E501 if not found: @@ -539,6 +572,7 @@ class SQLDatabaseToolkit(BaseToolkit): instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableDescription] = Field(exclude=True) embedding: OpenAIEmbeddings = Field(exclude=True) + is_multiple_schema: bool = False @property def dialect(self) -> str: @@ -579,7 +613,12 @@ def get_tools(self) -> List[BaseTool]: db=self.db, context=self.context, db_scan=self.db_scan ) tools.append(info_relevant_tool) - column_sample_tool = ColumnEntityChecker(db=self.db, context=self.context) + column_sample_tool = ColumnEntityChecker( + db=self.db, + context=self.context, + db_scan=self.db_scan, + is_multiple_schema=self.is_multiple_schema, + ) tools.append(column_sample_tool) if self.few_shot_examples is not None: get_fewshot_examples_tool = GetFewShotExamples( @@ -700,6 +739,9 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = SQLGenerator.filter_tables_by_schema( + db_scan=db_scan, prompt=user_prompt + ) few_shot_examples, instructions = context_store.retrieve_context_for_question( user_prompt, number_of_samples=self.max_number_of_examples ) @@ -716,6 +758,7 @@ def generate_response( context=context, few_shot_examples=new_fewshot_examples, instructions=instructions, + is_multiple_schema=True if user_prompt.schemas else False, db_scan=db_scan, embedding=OpenAIEmbeddings( openai_api_key=database_connection.decrypt_api_key(), @@ -802,6 +845,9 @@ def stream_response( ) if not db_scan: raise ValueError("No scanned tables found for database") + db_scan = 
SQLGenerator.filter_tables_by_schema(
+            db_scan=db_scan, prompt=user_prompt
+        )
         few_shot_examples, instructions = context_store.retrieve_context_for_question(
             user_prompt, number_of_samples=self.max_number_of_examples
         )
@@ -818,6 +864,7 @@
             context=[{}],
             few_shot_examples=new_fewshot_examples,
             instructions=instructions,
+            is_multiple_schema=True if user_prompt.schemas else False,
             db_scan=db_scan,
             embedding=OpenAIEmbeddings(
                 openai_api_key=database_connection.decrypt_api_key(),
diff --git a/dataherald/tests/test_api.py b/dataherald/tests/test_api.py
index 5d086bc9..87869812 100644
--- a/dataherald/tests/test_api.py
+++ b/dataherald/tests/test_api.py
@@ -12,24 +12,3 @@ def test_heartbeat():
     response = client.get("/api/v1/heartbeat")
     assert response.status_code == HTTP_200_CODE
-
-
-def test_scan_all_tables():
-    response = client.post(
-        "/api/v1/table-descriptions/sync-schemas",
-        json={"db_connection_id": "64dfa0e103f5134086f7090c"},
-    )
-    assert response.status_code == HTTP_201_CODE
-
-
-def test_scan_one_table():
-    try:
-        client.post(
-            "/api/v1/table-descriptions/sync-schemas",
-            json={
-                "db_connection_id": "64dfa0e103f5134086f7090c",
-                "table_names": ["foo"],
-            },
-        )
-    except ValueError as e:
-        assert str(e) == "No table found"
diff --git a/dataherald/types.py b/dataherald/types.py
index 44566756..de113e72 100644
--- a/dataherald/types.py
+++ b/dataherald/types.py
@@ -78,15 +78,25 @@ class SupportedDatabase(Enum):
     BIGQUERY = "BIGQUERY"
 
 
-class ScannerRequest(DBConnectionValidation):
-    table_names: list[str] | None
+class ScannerRequest(BaseModel):
+    ids: list[str] | None
     metadata: dict | None
 
+    @validator("ids")
+    def ids_validation(cls, ids: list = None):
+        try:
+            for id in ids or []:  # ids may be null in the request body
+                ObjectId(id)
+        except InvalidId:
+            raise ValueError("Must be a valid ObjectId")  # noqa: B904
+        return ids
+
 
 class DatabaseConnectionRequest(BaseModel):
     alias: str
     use_ssh: bool = False
     connection_uri: str
+    schemas: list[str] | None
     path_to_credentials_file: str | None
     llm_api_key: str | None
     ssh_settings: SSHSettings | None
@@ -140,6 +150,7 @@ class Finetuning(BaseModel):
     id: str | None = None
     alias: str | None = None
     db_connection_id: str | None = None
+    schemas: list[str] | None
     status: str = "QUEUED"
     error: str | None = None
     base_llm: BaseLLM | None = None
@@ -153,6 +164,7 @@ class Finetuning(BaseModel):
 
 class FineTuningRequest(BaseModel):
     db_connection_id: str
+    schemas: list[str] | None
     alias: str | None = None
     base_llm: BaseLLM | None = None
     golden_sqls: list[str] | None = None
@@ -168,6 +180,7 @@ class Prompt(BaseModel):
     id: str | None = None
     text: str
     db_connection_id: str
+    schemas: list[str] | None
     created_at: datetime = Field(default_factory=datetime.now)
     metadata: dict | None
 
diff --git a/dataherald/utils/error_codes.py b/dataherald/utils/error_codes.py
index 6680feb5..3b508f72 100644
--- a/dataherald/utils/error_codes.py
+++ b/dataherald/utils/error_codes.py
@@ -19,6 +19,7 @@
     "SQLGenerationNotFoundError": "sql_generation_not_found",
     "NLGenerationError": "nl_generation_not_created",
     "MalformedGoldenSQLError": "invalid_golden_sql",
+    "SchemaNotSupportedError": "schema_not_supported",
 }
 
diff --git a/dataherald/utils/sql_utils.py b/dataherald/utils/sql_utils.py
new file mode 100644
index 00000000..a1ab3680
--- /dev/null
+++ b/dataherald/utils/sql_utils.py
@@ -0,0 +1,47 @@
+from sql_metadata import Parser
+
+from dataherald.sql_database.models.types import DatabaseConnection
+from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
+from 
dataherald.types import FineTuningRequest, GoldenSQL
+
+
+def extract_the_schemas_from_sql(sql: str) -> list[str]:
+    table_names = Parser(sql).tables
+    schemas = []
+    for table_name in table_names:
+        if "." in table_name:
+            schema = table_name.split(".")[0]
+            schemas.append(schema.strip())
+    return schemas
+
+
+def filter_golden_records_based_on_schema(
+    golden_sqls: list[GoldenSQL], schemas: list[str]
+) -> list[GoldenSQL]:
+    filtered_records = []
+    if not schemas:
+        return golden_sqls
+    for record in golden_sqls:
+        used_schemas = extract_the_schemas_from_sql(record.sql)
+        for schema in schemas:
+            if schema in used_schemas:
+                filtered_records.append(record)
+                break
+    return filtered_records
+
+
+def validate_finetuning_schema(
+    finetuning_request: FineTuningRequest, db_connection: DatabaseConnection
+):
+    if finetuning_request.schemas:
+        if not db_connection.schemas:
+            raise SchemaNotSupportedError(
+                "Schema not supported for this db",
+                description=f"The {db_connection.id} db doesn't have schemas",
+            )
+        for schema in finetuning_request.schemas:
+            if schema not in db_connection.schemas:
+                raise SchemaNotSupportedError(
+                    f"Schema {schema} not supported for this db",
+                    description=f"The {db_connection.dialect} dialect doesn't support schema {schema}",
+                )
diff --git a/docs/api.create_database_connection.rst b/docs/api.create_database_connection.rst
index 60e7ba77..c2fbf724 100644
--- a/docs/api.create_database_connection.rst
+++ b/docs/api.create_database_connection.rst
@@ -26,6 +26,9 @@ Once the database connection is established, it retrieves the table names and cr
       "alias": "string",
       "use_ssh": false,
       "connection_uri": "string",
+      "schemas": [
+        "string"
+      ],
       "path_to_credentials_file": "string",
       "llm_api_key": "string",
       "ssh_settings": {
@@ -189,7 +192,7 @@ Connections to supported Data warehouses
 -----------------------------------------
 The format of the ``connection_uri`` parameter in the API call will depend on the data warehouse type you are connecting to.
-You can find samples and how to generate them :ref:.
+You can find samples and how to generate them below.
 
 Postgres
 ^^^^^^^^^^^^
@@ -324,3 +327,23 @@ Example::
 
         "connection_uri": bigquery://v2-real-estate/K2
 
+**Connecting multiple schemas**
+
+You can connect multiple schemas through a single database connection if you want to create SQL joins across schemas.
+Currently only `BigQuery`, `Snowflake`, `Databricks` and `Postgres` support this feature.
+To use multiple schemas, set them in the `schemas` param instead of sending the `schema` in the `connection_uri`, like this:
+
+**Example**
+
+.. 
code-block:: rst
+
+   curl -X 'POST' \
+     '/api/v1/database-connections' \
+     -H 'accept: application/json' \
+     -H 'Content-Type: application/json' \
+     -d '{
+       "alias": "my_db_alias_identifier",
+       "use_ssh": false,
+       "connection_uri": "snowflake://:@-/",
+       "schemas": ["foo", "bar"]
+     }'
diff --git a/docs/api.create_prompt.rst b/docs/api.create_prompt.rst
index e0bf7982..71867776 100644
--- a/docs/api.create_prompt.rst
+++ b/docs/api.create_prompt.rst
@@ -17,6 +17,9 @@ Request this ``POST`` endpoint to create a finetuning job::
     {
       "text": "string",
       "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ],
       "metadata": {}
     }
 
@@ -31,7 +34,10 @@ HTTP 201 code response
       "metadata": {},
       "created_at": "string",
       "text": "string",
-      "db_connection_id": "string"
+      "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ]
     }
 
 **Request example**
diff --git a/docs/api.create_prompt_sql_generation.rst b/docs/api.create_prompt_sql_generation.rst
index 3226e7d9..5b486e17 100644
--- a/docs/api.create_prompt_sql_generation.rst
+++ b/docs/api.create_prompt_sql_generation.rst
@@ -45,6 +45,9 @@ Request this ``POST`` endpoint to create a SQL query and a NL response for a giv
     "prompt": {
       "text": "string",
       "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ],
       "metadata": {}
     }
   }
diff --git a/docs/api.create_prompt_sql_generation_nl_generation.rst b/docs/api.create_prompt_sql_generation_nl_generation.rst
index 8b2f62af..7851b9c7 100644
--- a/docs/api.create_prompt_sql_generation_nl_generation.rst
+++ b/docs/api.create_prompt_sql_generation_nl_generation.rst
@@ -61,6 +61,9 @@ Request this ``POST`` endpoint to create a SQL query and a NL response for a giv
     "prompt": {
       "text": "string",
       "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ],
       "metadata": {}
     }
   }
diff --git a/docs/api.get_prompt.rst b/docs/api.get_prompt.rst
index fa02f0be..f86a9043 100644
--- a/docs/api.get_prompt.rst
+++ b/docs/api.get_prompt.rst
@@ -16,7 +16,10 @@ HTTP 200 code response
       "metadata": {},
       "created_at": "string",
       "text": "string",
-      "db_connection_id": "string"
+      "db_connection_id": "string",
+      "schemas": [
+        "string"
+      ]
     }
 
 **Request example**
@@ -37,4 +40,7 @@ HTTP 200 code response
       "created_at": "2024-01-03 15:40:27.624000+00:00",
       "text": "What is the expected total full year sales in 2023?",
-      "db_connection_id": "659435a5453d359345834"
+      "db_connection_id": "659435a5453d359345834",
+      "schemas": [
+        "public"
+      ]
     }
\ No newline at end of file
diff --git a/docs/api.get_table_description.rst b/docs/api.get_table_description.rst
index 330e89c7..76b0f8dc 100644
--- a/docs/api.get_table_description.rst
+++ b/docs/api.get_table_description.rst
@@ -24,6 +24,7 @@ HTTP 200 code response
       "table_schema": "string",
       "status": "NOT_SCANNED | SYNCHRONIZING | DEPRECATED | SCANNED | FAILED"
       "error_message": "string",
+      "schema_name": "string",
       "columns": [
         {
           "name": "string",
diff --git a/docs/api.list_database_connections.rst b/docs/api.list_database_connections.rst
index 1396ee23..d688d912 100644
--- a/docs/api.list_database_connections.rst
+++ b/docs/api.list_database_connections.rst
@@ -21,6 +21,7 @@ HTTP 200 code response
       "dialect": "databricks",
       "use_ssh": false,
       "connection_uri": "foooAABk91Q4wjoR2h07GR7_72BdQnxi8Rm6i_EjyS-mzz_o2c3RAWaEqnlUvkK5eGD5kUfE5xheyivl1Wfbk_EM7CgV4SvdLmOOt7FJV-3kG4zAbar=",
+      "schemas": null,
       "path_to_credentials_file": null,
       "llm_api_key": null,
       "ssh_settings": null
@@ -31,6 +32,7 @@
       "dialect": "postgres",
       "use_ssh": true,
       "connection_uri": null,
+      "schemas": null,
       "path_to_credentials_file": 
"bar-LWxPdFcjQw9lU7CeK_2ELR3jGBq0G_uQ7E2rfPLk2RcFR4aDO9e2HmeAQtVpdvtrsQ_0zjsy9q7asdsadXExYJ0g==", "llm_api_key": "gAAAAABlCz5TeU0ym4hW3bf9u21dz7B9tlnttOGLRDt8gq2ykkblNvpp70ZjT9FeFcoyMv-Csvp3GNQfw66eYvQBrcBEPsLokkLO2Jc2DD-Q8Aw6g_8UahdOTxJdT4izA6MsiQrf7GGmYBGZqbqsjTdNmcq661wF9Q==", "ssh_settings": { diff --git a/docs/api.list_prompts.rst b/docs/api.list_prompts.rst index 9f434a21..2122f117 100644 --- a/docs/api.list_prompts.rst +++ b/docs/api.list_prompts.rst @@ -29,7 +29,10 @@ HTTP 201 code response "metadata": {}, "created_at": "string", "text": "string", - "db_connection_id": "string" + "db_connection_id": "string", + "schemas": [ + "string" + ] } ] diff --git a/docs/api.list_table_description.rst b/docs/api.list_table_description.rst index e257805d..9bff1318 100644 --- a/docs/api.list_table_description.rst +++ b/docs/api.list_table_description.rst @@ -33,6 +33,7 @@ HTTP 200 code response "table_schema": "string", "status": "NOT_SCANNED | SYNCHRONIZING | DEPRECATED | SCANNED | FAILED" "error_message": "string", + "table_schema": "string", "columns": [ { "name": "string", diff --git a/docs/api.refresh_table_description.rst b/docs/api.refresh_table_description.rst index 74c600c1..6e392e79 100644 --- a/docs/api.refresh_table_description.rst +++ b/docs/api.refresh_table_description.rst @@ -34,6 +34,7 @@ HTTP 201 code response "table_schema": "string", "status": "NOT_SCANNED | SYNCHRONIZING | DEPRECATED | SCANNED | FAILED" "error_message": "string", + "table_schema": "string", "columns": [ { "name": "string", diff --git a/docs/api.scan_table_description.rst b/docs/api.scan_table_description.rst index 1ccc9f8e..55488cc1 100644 --- a/docs/api.scan_table_description.rst +++ b/docs/api.scan_table_description.rst @@ -9,7 +9,7 @@ which consist of historical queries associated with each database table. These r query_history collection. The historical queries retrieved encompass data from the past three months and are grouped based on query and user. -It can scan all db tables or if you specify a `table_names` then It will only scan those tables. +The `ids` param is used to set the table description ids that you want to scan. The process is carried out through Background Tasks, ensuring that even if it operates slowly, taking several minutes, the HTTP response remains swift. @@ -23,7 +23,7 @@ Request this ``POST`` endpoint:: { "db_connection_id": "string", - "table_names": ["string"] # Optional + "ids": ["string"] } **Responses** @@ -36,7 +36,6 @@ HTTP 201 code response **Request example** -To scan all the tables in a db don't specify a `table_names` .. code-block:: rst @@ -45,5 +44,6 @@ To scan all the tables in a db don't specify a `table_names` -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_connection_id": "db_connection_id" + "db_connection_id": "db_connection_id", + "ids": ["14e52c5f7d6dc4bc510d6d27", "15e52c5f7d6dc4bc510d6d34"] }'