Skip to content

Commit

Permalink
DH-4951 Resubmit: regenerate response without query (#239)
Browse files Browse the repository at this point in the history
* DH-4951 Resubmit: regenerate response without query

* Fix code
  • Loading branch information
jcjc712 authored Nov 6, 2023
1 parent 45e77c7 commit caa14b3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
41 changes: 26 additions & 15 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def create_response(
query_request: CreateResponseRequest = None, # noqa: ARG002
) -> Response:
question_repository = QuestionRepository(self.storage)
response_repository = ResponseRepository(self.storage)
user_question = question_repository.find_by_id(query_request.question_id)
db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
Expand All @@ -440,29 +441,39 @@ def create_response(
if not database_connection:
raise HTTPException(status_code=404, detail="Database connection not found")

response = Response(
question_id=query_request.question_id, sql_query=query_request.sql_query
)
response_repository = ResponseRepository(self.storage)
response_repository.insert(response)
start_generated_answer = time.time()
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(response, sql_response_only)
if run_evaluator:
evaluator = self.system.instance(Evaluator)
confidence_score = evaluator.get_confidence_score(
user_question, response, database_connection
if not query_request.sql_query:
sql_generation = self.system.instance(SQLGenerator)
context_store = self.system.instance(ContextStore)
context = context_store.retrieve_context_for_question(user_question)
start_generated_answer = time.time()
response = sql_generation.generate_response(
user_question, database_connection, context[0]
)
response.confidence_score = confidence_score
response.exec_time = time.time() - start_generated_answer
response_repository.update(response)
else:
response = Response(
question_id=query_request.question_id,
sql_query=query_request.sql_query,
)
start_generated_answer = time.time()

generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
response = generates_nl_answer.execute(response, sql_response_only)
except openai.error.AuthenticationError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except SQLInjectionError as e:
raise HTTPException(status_code=404, detail=str(e)) from e

if run_evaluator:
evaluator = self.system.instance(Evaluator)
confidence_score = evaluator.get_confidence_score(
user_question, response, database_connection
)
response.confidence_score = confidence_score
response.exec_time = time.time() - start_generated_answer
response_repository.insert(response)
return response

@override
Expand Down
4 changes: 2 additions & 2 deletions dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_table_examples(
examples_dict.append(temp_dict)
return examples_dict

def get_processed_column(
def get_processed_column( # noqa: PLR0911
self, meta: MetaData, table: str, column: dict, db_engine: SQLDatabase
) -> ColumnDetail:
dynamic_meta_table = meta.tables[table]
Expand All @@ -95,7 +95,7 @@ def get_processed_column(
if db_engine.engine.driver == "psycopg2":
# TODO escape table and column names
rs = db_engine.engine.execute(
f"SELECT n_distinct, most_common_vals::TEXT::TEXT[] FROM pg_catalog.pg_stats WHERE tablename = '{table}' AND attname = '{column['name']}'"
f"SELECT n_distinct, most_common_vals::TEXT::TEXT[] FROM pg_catalog.pg_stats WHERE tablename = '{table}' AND attname = '{column['name']}'" # noqa: S608 E501
).fetchall()
if MIN_CATEGORY_VALUE < rs[0]["n_distinct"] <= MAX_CATEGORY_VALUE:
return ColumnDetail(
Expand Down
2 changes: 1 addition & 1 deletion dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def object_id_validation(cls, v: str):

class CreateResponseRequest(BaseModel):
question_id: str
sql_query: str = Field(None, min_length=3)
sql_query: str | None = Field(None, min_length=3)


class SQLQueryResult(BaseModel):
Expand Down

0 comments on commit caa14b3

Please sign in to comment.