diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py
index 54316c6d..47662150 100644
--- a/dataherald/api/__init__.py
+++ b/dataherald/api/__init__.py
@@ -2,7 +2,7 @@
 from typing import List
 
 from fastapi import BackgroundTasks
-from fastapi.responses import FileResponse
+from fastapi.responses import JSONResponse
 
 from dataherald.api.types import Query
 from dataherald.config import Component
@@ -39,7 +39,7 @@ def scan_db(
     @abstractmethod
     def answer_question(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         question_request: QuestionRequest = None,
     ) -> Response:
         pass
@@ -47,7 +47,7 @@ def answer_question(
     @abstractmethod
     def answer_question_with_timeout(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         question_request: QuestionRequest = None,
     ) -> Response:
         pass
@@ -109,7 +109,7 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
     @abstractmethod
     def create_response(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         query_request: CreateResponseRequest = None,
     ) -> Response:
         pass
@@ -123,9 +123,7 @@ def get_response(self, response_id: str) -> Response:
         pass
 
     @abstractmethod
-    def get_response_file(
-        self, response_id: str, background_tasks: BackgroundTasks
-    ) -> FileResponse:
+    def get_response_file(self, response_id: str) -> JSONResponse:
         pass
 
     @abstractmethod
diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py
index d3cb7d4a..55e372f4 100644
--- a/dataherald/api/fastapi.py
+++ b/dataherald/api/fastapi.py
@@ -8,7 +8,7 @@
 from bson import json_util
 from bson.objectid import InvalidId, ObjectId
 from fastapi import BackgroundTasks, HTTPException
-from fastapi.responses import FileResponse, JSONResponse
+from fastapi.responses import JSONResponse
 from overrides import override
 
 from dataherald.api import API
@@ -64,10 +64,6 @@ def async_scanning(scanner, database, scanner_request, storage):
     )
 
 
-def async_removing_file(file_path: str):
-    os.remove(file_path)
-
-
 class FastAPI(API):
     def __init__(self, system: System):
         super().__init__(system)
@@ -125,7 +121,7 @@ def scan_db(
     @override
     def answer_question(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         question_request: QuestionRequest = None,
     ) -> Response:
         """Takes in an English question and answers it based on content from the registered databases"""
@@ -161,7 +157,7 @@ def answer_question(
                 user_question,
                 database_connection,
                 context[0],
-                store_substantial_query_result_in_csv,
+                large_query_result_in_csv,
             )
             logger.info("Starts evaluator...")
             confidence_score = evaluator.get_confidence_score(
@@ -172,6 +168,8 @@ def answer_question(
                 status_code=400,
                 content={"question_id": user_question.id, "error_message": str(e)},
             )
+        if generated_answer.csv_download_url:
+            generated_answer.sql_query_result = None
         generated_answer.confidence_score = confidence_score
         generated_answer.exec_time = time.time() - start_generated_answer
         response_repository = ResponseRepository(self.storage)
@@ -180,7 +178,7 @@ def answer_question(
     @override
     def answer_question_with_timeout(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         question_request: QuestionRequest = None,
     ) -> Response:
         result = None
@@ -197,7 +195,7 @@ def run_and_catch_exceptions():
             nonlocal result, exception
             if not stop_event.is_set():
                 result = self.answer_question(
-                    store_substantial_query_result_in_csv, question_request
+                    large_query_result_in_csv, question_request
                 )
 
         thread = threading.Thread(target=run_and_catch_exceptions)
@@ -226,6 +224,7 @@ def create_database_connection(
                 llm_api_key=database_connection_request.llm_api_key,
                 use_ssh=database_connection_request.use_ssh,
                 ssh_settings=database_connection_request.ssh_settings,
+                file_storage=database_connection_request.file_storage,
             )
 
             SQLDatabase.get_sql_engine(db_connection, True)
@@ -260,6 +259,7 @@ def update_database_connection(
                 llm_api_key=database_connection_request.llm_api_key,
                 use_ssh=database_connection_request.use_ssh,
                 ssh_settings=database_connection_request.ssh_settings,
+                file_storage=database_connection_request.file_storage,
             )
 
             SQLDatabase.get_sql_engine(db_connection, True)
@@ -365,9 +365,7 @@ def get_response(self, response_id: str) -> Response:
         return result
 
     @override
-    def get_response_file(
-        self, response_id: str, background_tasks: BackgroundTasks
-    ) -> FileResponse:
+    def get_response_file(self, response_id: str) -> JSONResponse:
         response_repository = ResponseRepository(self.storage)
 
         try:
@@ -378,15 +376,11 @@ def get_response_file(
         if not result:
             raise HTTPException(status_code=404, detail="Question not found")
 
-        # todo download
         s3 = S3()
-        file_path = s3.download(result.csv_file_path)
-        background_tasks.add_task(async_removing_file, file_path)
-        return FileResponse(
-            file_path,
-            media_type="text/csv",
-            headers={
-                "Content-Disposition": f"attachment; filename={file_path.split('/')[-1]}"
+        return JSONResponse(
+            status_code=201,
+            content={
+                "csv_download_url": s3.download_url(result.csv_file_path),
             },
         )
 
@@ -440,7 +434,7 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
     @override
     def create_response(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         query_request: CreateResponseRequest = None,  # noqa: ARG002
     ) -> Response:
         evaluator = self.system.instance(Evaluator)
@@ -461,13 +455,13 @@ def create_response(
         start_generated_answer = time.time()
         try:
             generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
-            response = generates_nl_answer.execute(
-                response, store_substantial_query_result_in_csv
-            )
+            response = generates_nl_answer.execute(response, large_query_result_in_csv)
             confidence_score = evaluator.get_confidence_score(
                 user_question, response, database_connection
             )
             response.confidence_score = confidence_score
+            if response.csv_download_url:
+                response.sql_query_result = None
             response.exec_time = time.time() - start_generated_answer
             response_repository.update(response)
         except ValueError as e:
diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py
index eccac896..f0b5db34 100644
--- a/dataherald/server/fastapi/__init__.py
+++ b/dataherald/server/fastapi/__init__.py
@@ -165,9 +165,9 @@ def __init__(self, settings: Settings):
         )
 
         self.router.add_api_route(
-            "/api/v1/responses/{response_id}/file",
+            "/api/v1/responses/{response_id}/generate-csv-download-url",
             self.get_response_file,
-            methods=["GET"],
+            methods=["POST"],
             tags=["Responses"],
         )
 
@@ -225,16 +225,14 @@ def scan_db(
 
     def answer_question(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         question_request: QuestionRequest = None,
     ) -> Response:
         if os.getenv("DH_ENGINE_TIMEOUT", None):
             return self._api.answer_question_with_timeout(
-                store_substantial_query_result_in_csv, question_request
+                large_query_result_in_csv, question_request
             )
-        return self._api.answer_question(
-            store_substantial_query_result_in_csv, question_request
-        )
+        return self._api.answer_question(large_query_result_in_csv, question_request)
 
     def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
         return self._api.get_questions(db_connection_id)
@@ -297,11 +295,9 @@ def get_response(self, response_id: str) -> Response:
         """Get a response"""
         return self._api.get_response(response_id)
 
-    def get_response_file(
-        self, response_id: str, background_tasks: BackgroundTasks
-    ) -> FileResponse:
+    def get_response_file(self, response_id: str) -> JSONResponse:
         """Get a response file"""
-        return self._api.get_response_file(response_id, background_tasks)
+        return self._api.get_response_file(response_id)
 
     def execute_sql_query(self, query: Query) -> tuple[str, dict]:
         """Executes a query on the given db_connection_id"""
         return self._api.execute_sql_query(query)
 
     def create_response(
         self,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
         query_request: CreateResponseRequest = None,
     ) -> Response:
         """Executes a query on the given db_connection_id"""
-        return self._api.create_response(
-            store_substantial_query_result_in_csv, query_request
-        )
+        return self._api.create_response(large_query_result_in_csv, query_request)
 
     def delete_golden_record(self, golden_record_id: str) -> dict:
         """Deletes a golden record"""
diff --git a/dataherald/sql_database/models/types.py b/dataherald/sql_database/models/types.py
index 43973baf..0e59da9b 100644
--- a/dataherald/sql_database/models/types.py
+++ b/dataherald/sql_database/models/types.py
@@ -22,6 +22,29 @@ def __getitem__(self, key: str) -> Any:
         return getattr(self, key)
 
 
+class FileStorage(BaseModel):
+    name: str
+    access_key_id: str
+    secret_access_key: str
+    region: str | None
+    bucket: str
+
+    class Config:
+        extra = Extra.ignore
+
+    @validator("access_key_id", "secret_access_key", pre=True, always=True)
+    def encrypt(cls, value: str):
+        fernet_encrypt = FernetEncrypt()
+        try:
+            fernet_encrypt.decrypt(value)
+            return value
+        except Exception:
+            return fernet_encrypt.encrypt(value)
+
+    def __getitem__(self, key: str) -> Any:
+        return getattr(self, key)
+
+
 class SSHSettings(BaseSettings):
     db_name: str | None
     host: str | None
@@ -59,6 +82,7 @@ class DatabaseConnection(BaseModel):
     path_to_credentials_file: str | None
     llm_api_key: str | None = None
     ssh_settings: SSHSettings | None = None
+    file_storage: FileStorage | None = None
 
     @validator("uri", pre=True, always=True)
     def set_uri_without_ssh(cls, v, values):
diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py
index dbc77b44..b18816d3 100644
--- a/dataherald/sql_generator/__init__.py
+++ b/dataherald/sql_generator/__init__.py
@@ -44,10 +44,16 @@ def create_sql_query_status(
         query: str,
         response: Response,
         top_k: int = None,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
+        database_connection: DatabaseConnection | None = None,
     ) -> Response:
         return create_sql_query_status(
-            db, query, response, top_k, store_substantial_query_result_in_csv
+            db,
+            query,
+            response,
+            top_k,
+            large_query_result_in_csv,
+            database_connection=database_connection,
         )
 
     def format_intermediate_representations(
@@ -83,7 +89,7 @@ def generate_response(
         user_question: Question,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
     ) -> Response:
         """Generates a response to a user question."""
         pass
diff --git a/dataherald/sql_generator/create_sql_query_status.py b/dataherald/sql_generator/create_sql_query_status.py
index 11a8a3df..925c90a5 100644
--- a/dataherald/sql_generator/create_sql_query_status.py
+++ b/dataherald/sql_generator/create_sql_query_status.py
@@ -6,6 +6,7 @@
 from sqlalchemy import text
 
 from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
+from dataherald.sql_database.models.types import DatabaseConnection
 from dataherald.types import Response, SQLQueryResult
 from dataherald.utils.s3 import S3
 
@@ -28,12 +29,13 @@ def format_error_message(response: Response, error_message: str) -> Response:
 
 
 def create_csv_file(
-    store_substantial_query_result_in_csv: bool,
+    large_query_result_in_csv: bool,
     columns: list,
     rows: list,
     response: Response,
+    database_connection: DatabaseConnection | None = None,
 ):
-    if store_substantial_query_result_in_csv and (
+    if large_query_result_in_csv and (
         len(rows) >= MAX_ROWS_TO_CREATE_CSV_FILE
         or len(str(rows)) > MAX_CHARACTERS_TO_CREATE_CSV_FILE
     ):
@@ -45,7 +47,9 @@ def create_csv_file(
             for row in rows:
                 writer.writerow(row.values())
         s3 = S3()
-        s3.upload(file_location)
+        response.csv_download_url = s3.upload(
+            file_location, database_connection.file_storage
+        )
         response.csv_file_path = f's3://k2-core/{file_location.split("/")[-1]}'
     response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)
 
@@ -55,7 +59,8 @@ def create_sql_query_status(
     query: str,
     response: Response,
     top_k: int = None,
-    store_substantial_query_result_in_csv: bool = False,
+    large_query_result_in_csv: bool = False,
+    database_connection: DatabaseConnection | None = None,
 ) -> Response:
     """Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message"""
     if query == "":
@@ -95,7 +100,11 @@ def create_sql_query_status(
                 rows.append(modified_row)
 
             create_csv_file(
-                store_substantial_query_result_in_csv, columns, rows, response
+                large_query_result_in_csv,
+                columns,
+                rows,
+                response,
+                database_connection,
             )
 
             response.sql_generation_status = "VALID"
diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py
index 411fb273..6fb8fafa 100644
--- a/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/dataherald/sql_generator/dataherald_sqlagent.py
@@ -595,7 +595,7 @@ def generate_response(
         user_question: Question,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
     ) -> Response:
         start_time = time.time()
         context_store = self.system.instance(ContextStore)
@@ -685,5 +685,6 @@ def generate_response(
             response.sql_query,
             response,
             top_k=TOP_K,
-            store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
+            large_query_result_in_csv=large_query_result_in_csv,
+            database_connection=database_connection,
         )
diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py
index c4a375c8..96eccaad 100644
--- a/dataherald/sql_generator/generates_nl_answer.py
+++ b/dataherald/sql_generator/generates_nl_answer.py
@@ -34,7 +34,7 @@ def __init__(self, system, storage):
     def execute(
         self,
         query_response: Response,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
     ) -> Response:
         question_repository = QuestionRepository(self.storage)
         question = question_repository.find_by_id(query_response.question_id)
@@ -54,7 +54,7 @@ def execute(
                 query_response.sql_query,
                 query_response,
                 top_k=50,
-                store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
+                large_query_result_in_csv=large_query_result_in_csv,
             )
         system_message_prompt = SystemMessagePromptTemplate.from_template(
             SYSTEM_TEMPLATE
diff --git a/dataherald/sql_generator/langchain_sqlagent.py b/dataherald/sql_generator/langchain_sqlagent.py
index 328041e1..e40fc376 100644
--- a/dataherald/sql_generator/langchain_sqlagent.py
+++ b/dataherald/sql_generator/langchain_sqlagent.py
@@ -29,7 +29,7 @@ def generate_response(
         user_question: Question,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
     ) -> Response:  # type: ignore
         logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
         self.llm = self.model.get_model(
@@ -90,5 +90,5 @@ def generate_response(
             self.database,
             response.sql_query,
             response,
-            store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
+            large_query_result_in_csv=large_query_result_in_csv,
         )
diff --git a/dataherald/sql_generator/langchain_sqlchain.py b/dataherald/sql_generator/langchain_sqlchain.py
index 30c3d2c9..f4ab1dda 100644
--- a/dataherald/sql_generator/langchain_sqlchain.py
+++ b/dataherald/sql_generator/langchain_sqlchain.py
@@ -47,7 +47,7 @@ def generate_response(
         user_question: Question,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
     ) -> Response:
         start_time = time.time()
         self.llm = self.model.get_model(
@@ -99,5 +99,5 @@ def generate_response(
             self.database,
             response.sql_query,
             response,
-            store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
+            large_query_result_in_csv=large_query_result_in_csv,
         )
diff --git a/dataherald/sql_generator/llamaindex.py b/dataherald/sql_generator/llamaindex.py
index 224de63d..3032f2b4 100644
--- a/dataherald/sql_generator/llamaindex.py
+++ b/dataherald/sql_generator/llamaindex.py
@@ -35,7 +35,7 @@ def generate_response(
         user_question: Question,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
-        store_substantial_query_result_in_csv: bool = False,
+        large_query_result_in_csv: bool = False,
     ) -> Response:
         start_time = time.time()
         logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
@@ -114,5 +114,5 @@ def generate_response(
             self.database,
             response.sql_query,
             response,
-            store_substantial_query_result_in_csv=store_substantial_query_result_in_csv,
+            large_query_result_in_csv=large_query_result_in_csv,
         )
diff --git a/dataherald/types.py b/dataherald/types.py
index 70524e29..e7ba8c1f 100644
--- a/dataherald/types.py
+++ b/dataherald/types.py
@@ -5,7 +5,7 @@
 from bson.objectid import ObjectId
 from pydantic import BaseModel, Field, validator
 
-from dataherald.sql_database.models.types import SSHSettings
+from dataherald.sql_database.models.types import FileStorage, SSHSettings
 
 
 class DBConnectionValidation(BaseModel):
@@ -76,6 +76,7 @@ class Response(BaseModel):
     sql_query: str
     sql_query_result: SQLQueryResult | None
     csv_file_path: str | None
+    csv_download_url: str | None
     sql_generation_status: str = "INVALID"
     error_message: str | None
     exec_time: float | None = None
@@ -118,6 +119,7 @@ class DatabaseConnectionRequest(BaseModel):
     path_to_credentials_file: str | None
     llm_api_key: str | None
     ssh_settings: SSHSettings | None
+    file_storage: FileStorage | None
 
 
 class ForeignKeyDetail(BaseModel):
diff --git a/dataherald/utils/s3.py b/dataherald/utils/s3.py
index dd575529..f73dcd1f 100644
--- a/dataherald/utils/s3.py
+++ b/dataherald/utils/s3.py
@@ -4,6 +4,7 @@
 from cryptography.fernet import InvalidToken
 
 from dataherald.config import Settings
+from dataherald.sql_database.models.types import FileStorage
 from dataherald.utils.encrypt import FernetEncrypt
 
 
@@ -11,21 +12,64 @@ class S3:
     def __init__(self):
         self.settings = Settings()
 
-    def upload(self, file_location) -> str:
+    def upload(self, file_location, file_storage: FileStorage | None = None) -> str:
         file_name = file_location.split("/")[-1]
         bucket_name = "k2-core"
         # Upload the file
-        s3_client = boto3.client(
-            "s3",
-            aws_access_key_id=self.settings.s3_aws_access_key_id,
-            aws_secret_access_key=self.settings.s3_aws_secret_access_key,
-        )
+        if file_storage:
+            fernet_encrypt = FernetEncrypt()
+            bucket_name = file_storage.bucket
+            s3_client = boto3.client(
+                "s3",
+                aws_access_key_id=fernet_encrypt.decrypt(file_storage.access_key_id),
+                aws_secret_access_key=fernet_encrypt.decrypt(
+                    file_storage.secret_access_key
+                ),
+                region_name=file_storage.region,
+            )
+        else:
+            s3_client = boto3.client(
+                "s3",
+                aws_access_key_id=self.settings.s3_aws_access_key_id,
+                aws_secret_access_key=self.settings.s3_aws_secret_access_key,
+            )
         s3_client.upload_file(
             file_location, bucket_name, os.path.basename(file_location)
         )
         os.remove(file_location)
-        return f"s3://{bucket_name}/{file_name}"
+
+        return s3_client.generate_presigned_url(
+            "get_object",
+            Params={"Bucket": bucket_name, "Key": file_name},
+            ExpiresIn=3600,  # The URL will expire in 3600 seconds (1 hour)
+        )
+
+    def download_url(self, path: str, file_storage: FileStorage | None = None) -> str:
+        path = path.split("/")
+
+        if file_storage:
+            fernet_encrypt = FernetEncrypt()
+            s3_client = boto3.client(
+                "s3",
+                aws_access_key_id=fernet_encrypt.decrypt(file_storage.access_key_id),
+                aws_secret_access_key=fernet_encrypt.decrypt(
+                    file_storage.secret_access_key
+                ),
+                region_name=file_storage.region,
+            )
+        else:
+            s3_client = boto3.client(
+                "s3",
+                aws_access_key_id=self.settings.s3_aws_access_key_id,
+                aws_secret_access_key=self.settings.s3_aws_secret_access_key,
+            )
+
+        return s3_client.generate_presigned_url(
+            "get_object",
+            Params={"Bucket": path[2], "Key": path[-1]},
+            ExpiresIn=3600,  # The URL will expire in 3600 seconds (1 hour)
+        )
 
     def download(self, path: str) -> str:
         fernet_encrypt = FernetEncrypt()
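
Note (not part of the diff): a worked example of the path handling that the new S3.download_url above relies on. The stored csv_file_path has the form s3://<bucket>/<key> (the diff writes it as f's3://k2-core/{...}'), so splitting on "/" leaves the bucket at index 2 and the object key at the end; the file name used below is illustrative only.

    # Minimal sketch of how download_url derives Bucket/Key from csv_file_path.
    path = "s3://k2-core/result-3f2a.csv".split("/")  # illustrative key name
    # path == ["s3:", "", "k2-core", "result-3f2a.csv"]
    bucket, key = path[2], path[-1]
    # download_url then passes Params={"Bucket": bucket, "Key": key} to
    # generate_presigned_url with a one-hour expiry.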
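
Note (not part of the diff): a minimal client-side sketch of the flow these changes enable, under stated assumptions. The engine host, the /api/v1/database-connections and /api/v1/question paths, the payload fields other than file_storage, and the use of the requests library are assumptions for illustration; the file_storage block, the large_query_result_in_csv flag, the POST /api/v1/responses/{response_id}/generate-csv-download-url route, and the csv_download_url field come from this diff.

    import requests  # assumed HTTP client for this sketch

    BASE_URL = "http://localhost"  # assumed engine address

    # 1. Register a database connection with the optional file_storage block
    #    (fields mirror the FileStorage model added in sql_database/models/types.py).
    requests.post(
        f"{BASE_URL}/api/v1/database-connections",  # assumed path
        json={
            "alias": "my_db",  # assumed payload shape
            "connection_uri": "postgresql://user:pass@host/db",
            "file_storage": {
                "name": "my-storage",
                "access_key_id": "AKIA...",
                "secret_access_key": "...",
                "region": "us-east-1",
                "bucket": "my-csv-bucket",
            },
        },
    )

    # 2. Ask a question, requesting that large results be written to CSV instead of
    #    returned inline (the renamed large_query_result_in_csv flag).
    answer = requests.post(
        f"{BASE_URL}/api/v1/question",  # assumed path
        params={"large_query_result_in_csv": True},
        json={
            "question": "How many orders were placed last month?",
            "db_connection_id": "<connection id>",
        },
    ).json()

    # 3. When a CSV was produced, sql_query_result is nulled out and a fresh presigned
    #    download URL can be requested through the new POST route.
    download = requests.post(
        f"{BASE_URL}/api/v1/responses/{answer['id']}/generate-csv-download-url"
    ).json()  # assumes the response body exposes its id
    print(download["csv_download_url"])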