Skip to content

Commit

Permalink
[DH-5733] Support schemas column to add a db connection (#466)
Browse files Browse the repository at this point in the history
* [DH-5733] Support schemas column to add a db connection

* DBs without schema should store None

* Add ids in sync-schemas endpoint

* Support multi-schemas for refresh endpoint

* Add schema not support error exception

* Add documentation for multi-schemas

* Fix sync schema method

* Sync-schemas endpoint let adding ids from different db connection

* Fix refresh endpoint

* Fix table-description storage

* Fix schema_name filter in table-description repository

* DH-5735/add support for multiple schemas for agents

* DH-5766/adding the validation to raise exception for queries without schema in multiple schema setting

* DH-5765/add support multiple schema for finetuning

---------

Co-authored-by: mohammadrezapourreza <[email protected]>
  • Loading branch information
2 people authored and DishenWang2023 committed May 7, 2024
1 parent 817581f commit 31b88df
Show file tree
Hide file tree
Showing 32 changed files with 594 additions and 198 deletions.
25 changes: 22 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,24 @@ curl -X 'POST' \
}'
```

##### Connecting multi-schemas
You can connect many schemas using one db connection if you want to create SQL joins between schemas.
Currently only `BigQuery`, `Snowflake`, `Databricks` and `Postgres` support this feature.
To use multi-schemas instead of sending the `schema` in the `connection_uri` set it in the `schemas` param, like this:

```
curl -X 'POST' \
'<host>/api/v1/database-connections' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"alias": "my_db_alias",
"use_ssh": false,
"connection_uri": snowflake://<user>:<password>@<organization>-<account-name>/<database>",
"schemas": ["schema_1", "schema_2", ...]
}'
```

##### Connecting to supported Data warehouses and using SSH
You can find the details on how to connect to the supported data warehouses in the [docs](https://dataherald.readthedocs.io/en/latest/api.create_database_connection.html)

Expand All @@ -235,7 +253,8 @@ While only the Database scan part is required to start generating SQL, adding ve
#### Scanning the Database
The database scan is used to gather information about the database including table and column names and identifying low cardinality columns and their values to be stored in the context store and used in the prompts to the LLM.
In addition, it retrieves logs, which consist of historical queries associated with each database table. These records are then stored within the query_history collection. The historical queries retrieved encompass data from the past three months and are grouped based on query and user.
db_connection_id is the id of the database connection you want to scan, which is returned when you create a database connection.
The db_connection_id param is the id of the database connection you want to scan, which is returned when you create a database connection.
The ids param is the table_description_id that you want to scan.
You can trigger a scan of a database from the `POST /api/v1/table-descriptions/sync-schemas` endpoint. Example below


Expand All @@ -246,11 +265,11 @@ curl -X 'POST' \
-H 'Content-Type: application/json' \
-d '{
"db_connection_id": "db_connection_id",
"table_names": ["table_name"]
"ids": ["<table_description_id_1>", "<table_description_id_2>", ...]
}'
```

Since the endpoint identifies low cardinality columns (and their values) it can take time to complete. Therefore while it is possible to trigger a scan on the entire DB by not specifying the `table_names`, we recommend against it for large databases.
Since the endpoint identifies low cardinality columns (and their values) it can take time to complete.

#### Get logs per db connection
Once a database was scanned you can use this endpoint to retrieve the tables logs
Expand Down
132 changes: 67 additions & 65 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
SQLInjectionError,
)
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_database.services.database_connection import (
DatabaseConnectionService,
)
from dataherald.types import (
BaseLLM,
CancelFineTuningRequest,
Expand All @@ -88,17 +91,20 @@
)
from dataherald.utils.encrypt import FernetEncrypt
from dataherald.utils.error_codes import error_response, stream_error_response
from dataherald.utils.sql_utils import (
filter_golden_records_based_on_schema,
validate_finetuning_schema,
)

logger = logging.getLogger(__name__)

MAX_ROWS_TO_CREATE_CSV_FILE = 50


def async_scanning(scanner, database, scanner_request, storage):
def async_scanning(scanner, database, table_descriptions, storage):
scanner.scan(
database,
scanner_request.db_connection_id,
scanner_request.table_names,
table_descriptions,
TableDescriptionRepository(storage),
QueryHistoryRepository(storage),
)
Expand Down Expand Up @@ -130,70 +136,52 @@ def scan_db(
self, scanner_request: ScannerRequest, background_tasks: BackgroundTasks
) -> list[TableDescriptionResponse]:
"""Takes a db_connection_id and scan all the tables columns"""
try:
db_connection_repository = DatabaseConnectionRepository(self.storage)

db_connection = db_connection_repository.find_by_id(
scanner_request.db_connection_id
)
scanner_repository = TableDescriptionRepository(self.storage)
data = {}
for id in scanner_request.ids:
table_description = scanner_repository.find_by_id(id)
if not table_description:
raise Exception("Table description not found")
if table_description.db_connection_id not in data.keys():
data[table_description.db_connection_id] = {}
if (
table_description.schema_name
not in data[table_description.db_connection_id].keys()
):
data[table_description.db_connection_id][
table_description.schema_name
] = []
data[table_description.db_connection_id][
table_description.schema_name
].append(table_description)

if not db_connection:
raise DatabaseConnectionNotFoundError(
f"Database connection {scanner_request.db_connection_id} not found"
db_connection_repository = DatabaseConnectionRepository(self.storage)
scanner = self.system.instance(Scanner)
rows = scanner.synchronizing(
scanner_request,
TableDescriptionRepository(self.storage),
)
database_connection_service = DatabaseConnectionService(scanner, self.storage)
for db_connection_id, schemas_and_table_descriptions in data.items():
for schema, table_descriptions in schemas_and_table_descriptions.items():
db_connection = db_connection_repository.find_by_id(db_connection_id)
database = database_connection_service.get_sql_database(
db_connection, schema
)

database = SQLDatabase.get_sql_engine(db_connection, True)
all_tables = database.get_tables_and_views()

if scanner_request.table_names:
for table in scanner_request.table_names:
if table not in all_tables:
raise HTTPException(
status_code=404,
detail=f"Table named: {table} doesn't exist",
) # noqa: B904
else:
scanner_request.table_names = all_tables

scanner = self.system.instance(Scanner)
rows = scanner.synchronizing(
scanner_request,
TableDescriptionRepository(self.storage),
)
except Exception as e:
return error_response(e, scanner_request.dict(), "invalid_database_sync")

background_tasks.add_task(
async_scanning, scanner, database, scanner_request, self.storage
)
background_tasks.add_task(
async_scanning, scanner, database, table_descriptions, self.storage
)
return [TableDescriptionResponse(**row.dict()) for row in rows]

@override
def create_database_connection(
self, database_connection_request: DatabaseConnectionRequest
) -> DatabaseConnectionResponse:
try:
db_connection = DatabaseConnection(
alias=database_connection_request.alias,
connection_uri=database_connection_request.connection_uri.strip(),
path_to_credentials_file=database_connection_request.path_to_credentials_file,
llm_api_key=database_connection_request.llm_api_key,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
file_storage=database_connection_request.file_storage,
metadata=database_connection_request.metadata,
)
sql_database = SQLDatabase.get_sql_engine(db_connection, True)

# Get tables and views and create table-descriptions as NOT_SCANNED
db_connection_repository = DatabaseConnectionRepository(self.storage)

scanner_repository = TableDescriptionRepository(self.storage)
scanner = self.system.instance(Scanner)

tables = sql_database.get_tables_and_views()
db_connection = db_connection_repository.insert(db_connection)
scanner.create_tables(tables, str(db_connection.id), scanner_repository)
db_connection_service = DatabaseConnectionService(scanner, self.storage)
db_connection = db_connection_service.create(database_connection_request)
except Exception as e:
# Encrypt sensible values
fernet_encrypt = FernetEncrypt()
Expand All @@ -209,7 +197,6 @@ def create_database_connection(
return error_response(
e, database_connection_request.dict(), "invalid_database_connection"
)

return DatabaseConnectionResponse(**db_connection.dict())

@override
Expand All @@ -220,18 +207,30 @@ def refresh_table_description(
db_connection = db_connection_repository.find_by_id(
refresh_table_description.db_connection_id
)

scanner = self.system.instance(Scanner)
database_connection_service = DatabaseConnectionService(scanner, self.storage)
try:
sql_database = SQLDatabase.get_sql_engine(db_connection, True)
tables = sql_database.get_tables_and_views()
data = {}
if db_connection.schemas:
for schema in db_connection.schemas:
sql_database = database_connection_service.get_sql_database(
db_connection, schema
)
if schema not in data.keys():
data[schema] = []
data[schema] = sql_database.get_tables_and_views()
else:
sql_database = database_connection_service.get_sql_database(
db_connection
)
data[None] = sql_database.get_tables_and_views()

# Get tables and views and create missing table-descriptions as NOT_SCANNED and update DEPRECATED
scanner_repository = TableDescriptionRepository(self.storage)
scanner = self.system.instance(Scanner)

return [
TableDescriptionResponse(**record.dict())
for record in scanner.refresh_tables(
tables, str(db_connection.id), scanner_repository
data, str(db_connection.id), scanner_repository
)
]
except Exception as e:
Expand Down Expand Up @@ -569,15 +568,14 @@ def create_finetuning_job(
) -> Finetuning:
try:
db_connection_repository = DatabaseConnectionRepository(self.storage)

db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
if not db_connection:
raise DatabaseConnectionNotFoundError(
f"Database connection not found, {fine_tuning_request.db_connection_id}"
)

validate_finetuning_schema(fine_tuning_request, db_connection)
golden_sqls_repository = GoldenSQLRepository(self.storage)
golden_sqls = []
if fine_tuning_request.golden_sqls:
Expand All @@ -598,6 +596,9 @@ def create_finetuning_job(
raise GoldenSQLNotFoundError(
f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}"
)
golden_sqls = filter_golden_records_based_on_schema(
golden_sqls, fine_tuning_request.schemas
)
default_base_llm = BaseLLM(
model_provider="openai",
model_name="gpt-3.5-turbo-1106",
Expand All @@ -611,6 +612,7 @@ def create_finetuning_job(
model = model_repository.insert(
Finetuning(
db_connection_id=fine_tuning_request.db_connection_id,
schemas=fine_tuning_request.schemas,
alias=fine_tuning_request.alias
if fine_tuning_request.alias
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",
Expand Down
1 change: 1 addition & 0 deletions dataherald/api/types/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class PromptRequest(BaseModel):
text: str
db_connection_id: str
schemas: list[str] | None
metadata: dict | None


Expand Down
1 change: 1 addition & 0 deletions dataherald/api/types/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def created_at_as_string(cls, v):
class PromptResponse(BaseResponse):
text: str
db_connection_id: str
schemas: list[str] | None


class SQLGenerationResponse(BaseResponse):
Expand Down
13 changes: 13 additions & 0 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataherald.repositories.golden_sqls import GoldenSQLRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt
from dataherald.utils.sql_utils import extract_the_schemas_from_sql

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
f"Database connection not found, {record.db_connection_id}"
)

if db_connection.schemas:
schema_not_found = True
used_schemas = extract_the_schemas_from_sql(record.sql)
for schema in db_connection.schemas:
if schema in used_schemas:
schema_not_found = False
break
if schema_not_found:
raise MalformedGoldenSQLError(
f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}"
)

prompt_text = record.prompt_text
golden_sql = GoldenSQL(
prompt_text=prompt_text,
Expand Down
6 changes: 3 additions & 3 deletions dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ class Scanner(Component, ABC):
def scan(
self,
db_engine: SQLDatabase,
db_connection_id: str,
table_names: list[str] | None,
table_descriptions: list[TableDescription],
repository: TableDescriptionRepository,
query_history_repository: QueryHistoryRepository,
) -> None:
Expand All @@ -34,6 +33,7 @@ def create_tables(
self,
tables: list[str],
db_connection_id: str,
schema: str,
repository: TableDescriptionRepository,
metadata: dict = None,
) -> None:
Expand All @@ -42,7 +42,7 @@ def create_tables(
@abstractmethod
def refresh_tables(
self,
tables: list[str],
schemas_and_tables: dict[str, list],
db_connection_id: str,
repository: TableDescriptionRepository,
metadata: dict = None,
Expand Down
1 change: 1 addition & 0 deletions dataherald/db_scanner/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TableDescriptionStatus(Enum):
class TableDescription(BaseModel):
id: str | None
db_connection_id: str
schema_name: str | None
table_name: str
description: str | None
table_schema: str | None
Expand Down
13 changes: 9 additions & 4 deletions dataherald/db_scanner/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,18 @@ def save_table_info(self, table_info: TableDescription) -> TableDescription:
table_info_dict = {
k: v for k, v in table_info_dict.items() if v is not None and v != []
}

query = {
"db_connection_id": table_info_dict["db_connection_id"],
"table_name": table_info_dict["table_name"],
}
if "schema_name" in table_info_dict:
query["schema_name"] = table_info_dict["schema_name"]

table_info.id = str(
self.storage.update_or_create(
DB_COLLECTION,
{
"db_connection_id": table_info_dict["db_connection_id"],
"table_name": table_info_dict["table_name"],
},
query,
table_info_dict,
)
)
Expand Down
Loading

0 comments on commit 31b88df

Please sign in to comment.