diff --git a/dataherald/tests/sql_generator/test_generator.py b/dataherald/tests/sql_generator/test_generator.py index 9610cb56..c463e38e 100644 --- a/dataherald/tests/sql_generator/test_generator.py +++ b/dataherald/tests/sql_generator/test_generator.py @@ -5,7 +5,7 @@ from dataherald.config import System from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator -from dataherald.types import Question, Response +from dataherald.types import Prompt, SQLGeneration class TestGenerator(SQLGenerator): @@ -15,14 +15,12 @@ def __init__(self, system: System): @override def generate_response( self, - user_question: Question, + user_question: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, # noqa: ARG002 - generate_csv: bool = None, - ) -> Response: - return Response( + ) -> SQLGeneration: + return SQLGeneration( question_id="651f2d76275132d5b65175eb", - response="Foo response", - sql_query="bar", - generate_csv=None, + sql="Foo response", + status="bar", )