diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index 5f939c2e..06d84287 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -22,10 +22,10 @@ import dbally from dbally.collection import Collection -from dbally.data_models.prompts.iql_prompt_template import default_iql_template -from dbally.data_models.prompts.view_selector_prompt_template import default_view_selector_template +from dbally.iql_generator.iql_prompt_template import default_iql_template from dbally.llm_client.openai_client import OpenAIClient from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError +from dbally.view_selection.view_selector_prompt_template import default_view_selector_template async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult: diff --git a/benchmark/dbally_benchmark/iql_benchmark.py b/benchmark/dbally_benchmark/iql_benchmark.py index d725fb72..f99411e5 100644 --- a/benchmark/dbally_benchmark/iql_benchmark.py +++ b/benchmark/dbally_benchmark/iql_benchmark.py @@ -20,8 +20,8 @@ from sqlalchemy import create_engine from dbally.audit.event_tracker import EventTracker -from dbally.data_models.prompts.iql_prompt_template import default_iql_template from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql_generator.iql_prompt_template import default_iql_template from dbally.llm_client.openai_client import OpenAIClient from dbally.utils.errors import UnsupportedQueryError from dbally.views.structured import BaseStructuredView diff --git a/benchmark/dbally_benchmark/text2sql/prompt_template.py b/benchmark/dbally_benchmark/text2sql/prompt_template.py index ac18ae21..abee9659 100644 --- a/benchmark/dbally_benchmark/text2sql/prompt_template.py +++ b/benchmark/dbally_benchmark/text2sql/prompt_template.py @@ -1,4 +1,4 @@ -from dbally.prompts.prompt_builder import PromptTemplate +from dbally.prompts import PromptTemplate TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate( ( diff --git a/examples/recruiting.py b/examples/recruiting.py index 9b89fbeb..b5290344 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -9,7 +9,7 @@ from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.audit.event_tracker import EventTracker from dbally.llm_client.openai_client import OpenAIClient -from dbally.prompts.prompt_builder import PromptTemplate +from dbally.prompts import PromptTemplate TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate( ( diff --git a/pyproject.toml b/pyproject.toml index 460247c6..28f8957f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ testpaths = ['tests'] [tool.pytest.ini_options] asyncio_mode = "auto" +testpaths = ["tests"] +pythonpath = ["."] [tool.mypy] warn_unused_configs = true diff --git a/src/dbally/__version__.py b/src/dbally/__version__.py index d7860100..e58a6ce1 100644 --- a/src/dbally/__version__.py +++ b/src/dbally/__version__.py @@ -1,2 +1,3 @@ """Version information.""" + __version__ = "0.1.0" diff --git a/src/dbally/data_models/audit.py b/src/dbally/data_models/audit.py index 577bf769..4dc466ad 100644 --- a/src/dbally/data_models/audit.py +++ b/src/dbally/data_models/audit.py @@ -3,7 +3,7 @@ from typing import Optional, Union from dbally.data_models.execution_result import ExecutionResult -from dbally.data_models.prompts.prompt_template import ChatFormat +from dbally.prompts import ChatFormat class EventType(Enum): diff --git a/src/dbally/data_models/prompts/__init__.py b/src/dbally/data_models/prompts/__init__.py deleted file mode 100644 index f9dffb3c..00000000 --- a/src/dbally/data_models/prompts/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .common_validation_utils import ChatFormat, PromptTemplateError -from .iql_prompt_template import IQLPromptTemplate, default_iql_template -from .prompt_template import PromptTemplate -from .view_selector_prompt_template import ViewSelectorPromptTemplate, default_view_selector_template - -__all__ = [ - "IQLPromptTemplate", - "default_iql_template", - "PromptTemplateError", - "ChatFormat", - "PromptTemplateError", - "PromptTemplate", - "ViewSelectorPromptTemplate", - "default_view_selector_template", -] diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 07795158..ccc54268 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Tuple, TypeVar from dbally.audit.event_tracker import EventTracker -from dbally.data_models.prompts.iql_prompt_template import IQLPromptTemplate, default_iql_template +from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template from dbally.llm_client.base import LLMClient from dbally.prompts.prompt_builder import PromptBuilder from dbally.views.exposed_functions import ExposedFunction diff --git a/src/dbally/data_models/prompts/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py similarity index 88% rename from src/dbally/data_models/prompts/iql_prompt_template.py rename to src/dbally/iql_generator/iql_prompt_template.py index c7c60a68..1bb0a71c 100644 --- a/src/dbally/data_models/prompts/iql_prompt_template.py +++ b/src/dbally/iql_generator/iql_prompt_template.py @@ -1,7 +1,6 @@ from typing import Callable, Dict, Optional -from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables -from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate +from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables from dbally.utils.errors import UnsupportedQueryError @@ -17,7 +16,7 @@ def __init__( llm_response_parser: Callable = lambda x: x, ): super().__init__(chat, response_format, llm_response_parser) - self.chat = _check_prompt_variables(chat, {"filters", "question"}) + self.chat = check_prompt_variables(chat, {"filters", "question"}) def _validate_iql_response(llm_response: str) -> str: diff --git a/src/dbally/llm_client/base.py b/src/dbally/llm_client/base.py index dfd064da..55cf320a 100644 --- a/src/dbally/llm_client/base.py +++ b/src/dbally/llm_client/base.py @@ -7,7 +7,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.audit import LLMEvent from dbally.data_models.llm_options import LLMOptions -from dbally.prompts.prompt_builder import ChatFormat, PromptBuilder, PromptTemplate +from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate class LLMClient(abc.ABC): diff --git a/src/dbally/llm_client/openai_client.py b/src/dbally/llm_client/openai_client.py index 3b730d6d..dad76fe3 100644 --- a/src/dbally/llm_client/openai_client.py +++ b/src/dbally/llm_client/openai_client.py @@ -3,7 +3,7 @@ from dbally.data_models.audit import LLMEvent from dbally.data_models.llm_options import LLMOptions from dbally.llm_client.base import LLMClient -from dbally.prompts.prompt_builder import ChatFormat +from dbally.prompts import ChatFormat class OpenAIClient(LLMClient): diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index 90524bb0..eabc33e4 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -5,15 +5,12 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult -from dbally.data_models.prompts.nl_responder_prompt_template import ( - NLResponderPromptTemplate, - default_nl_responder_template, -) -from dbally.data_models.prompts.query_explainer_prompt_template import ( +from dbally.llm_client.base import LLMClient +from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template +from dbally.nl_responder.query_explainer_prompt_template import ( QueryExplainerPromptTemplate, default_query_explainer_template, ) -from dbally.llm_client.base import LLMClient from dbally.nl_responder.token_counters import count_tokens_for_huggingface, count_tokens_for_openai diff --git a/src/dbally/data_models/prompts/nl_responder_prompt_template.py b/src/dbally/nl_responder/nl_responder_prompt_template.py similarity index 85% rename from src/dbally/data_models/prompts/nl_responder_prompt_template.py rename to src/dbally/nl_responder/nl_responder_prompt_template.py index 904153d2..9e6e687e 100644 --- a/src/dbally/data_models/prompts/nl_responder_prompt_template.py +++ b/src/dbally/nl_responder/nl_responder_prompt_template.py @@ -1,7 +1,6 @@ from typing import Callable, Dict, Optional -from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables -from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate +from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables class NLResponderPromptTemplate(PromptTemplate): @@ -25,7 +24,7 @@ def __init__( """ super().__init__(chat, response_format, llm_response_parser) - self.chat = _check_prompt_variables(chat, {"rows", "question"}) + self.chat = check_prompt_variables(chat, {"rows", "question"}) default_nl_responder_template = NLResponderPromptTemplate( diff --git a/src/dbally/data_models/prompts/query_explainer_prompt_template.py b/src/dbally/nl_responder/query_explainer_prompt_template.py similarity index 87% rename from src/dbally/data_models/prompts/query_explainer_prompt_template.py rename to src/dbally/nl_responder/query_explainer_prompt_template.py index e0767318..00a3e6a6 100644 --- a/src/dbally/data_models/prompts/query_explainer_prompt_template.py +++ b/src/dbally/nl_responder/query_explainer_prompt_template.py @@ -1,7 +1,6 @@ from typing import Callable, Dict, Optional -from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables -from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate +from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables class QueryExplainerPromptTemplate(PromptTemplate): @@ -22,7 +21,7 @@ def __init__( llm_response_parser: Callable = lambda x: x, ) -> None: super().__init__(chat, response_format, llm_response_parser) - self.chat = _check_prompt_variables(chat, {"question", "query", "number_of_results"}) + self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"}) default_query_explainer_template = QueryExplainerPromptTemplate( diff --git a/src/dbally/nl_responder/token_counters.py b/src/dbally/nl_responder/token_counters.py index 6529d0b4..2b242eba 100644 --- a/src/dbally/nl_responder/token_counters.py +++ b/src/dbally/nl_responder/token_counters.py @@ -1,6 +1,6 @@ from typing import Dict -from dbally.data_models.prompts.common_validation_utils import ChatFormat +from dbally.prompts import ChatFormat def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int: diff --git a/src/dbally/prompts/__init__.py b/src/dbally/prompts/__init__.py index 1cb764c5..279ac384 100644 --- a/src/dbally/prompts/__init__.py +++ b/src/dbally/prompts/__init__.py @@ -1,3 +1,5 @@ +from .common_validation_utils import ChatFormat, PromptTemplateError, check_prompt_variables from .prompt_builder import PromptBuilder +from .prompt_template import PromptTemplate -__all__ = ["PromptBuilder"] +__all__ = ["PromptBuilder", "PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"] diff --git a/src/dbally/data_models/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py similarity index 93% rename from src/dbally/data_models/prompts/common_validation_utils.py rename to src/dbally/prompts/common_validation_utils.py index 2ce9fd0b..f62d72b1 100644 --- a/src/dbally/data_models/prompts/common_validation_utils.py +++ b/src/dbally/prompts/common_validation_utils.py @@ -22,7 +22,7 @@ def _extract_variables(text: str) -> List[str]: return re.findall(pattern, text) -def _check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat: +def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat: """ Function validates a given chat to make sure it contains variables required. diff --git a/src/dbally/prompts/prompt_builder.py b/src/dbally/prompts/prompt_builder.py index 72ca97f2..a6e5596c 100644 --- a/src/dbally/prompts/prompt_builder.py +++ b/src/dbally/prompts/prompt_builder.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Union -from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate +from .common_validation_utils import ChatFormat +from .prompt_template import PromptTemplate if TYPE_CHECKING: from transformers.tokenization_utils import PreTrainedTokenizer diff --git a/src/dbally/data_models/prompts/prompt_template.py b/src/dbally/prompts/prompt_template.py similarity index 96% rename from src/dbally/data_models/prompts/prompt_template.py rename to src/dbally/prompts/prompt_template.py index 923110e0..2bd382f6 100644 --- a/src/dbally/data_models/prompts/prompt_template.py +++ b/src/dbally/prompts/prompt_template.py @@ -2,7 +2,7 @@ from typing_extensions import Self -from dbally.data_models.prompts.common_validation_utils import ChatFormat, PromptTemplateError +from .common_validation_utils import ChatFormat, PromptTemplateError def _check_chat_order(chat: ChatFormat) -> ChatFormat: diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index d0ec59c0..e53e5a74 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -2,10 +2,11 @@ from typing import Callable, Dict, Optional from dbally.audit.event_tracker import EventTracker -from dbally.data_models.prompts import IQLPromptTemplate, default_view_selector_template +from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate from dbally.llm_client.base import LLMClient from dbally.prompts import PromptBuilder from dbally.view_selection.base import ViewSelector +from dbally.view_selection.view_selector_prompt_template import default_view_selector_template class LLMViewSelector(ViewSelector): diff --git a/src/dbally/data_models/prompts/view_selector_prompt_template.py b/src/dbally/view_selection/view_selector_prompt_template.py similarity index 87% rename from src/dbally/data_models/prompts/view_selector_prompt_template.py rename to src/dbally/view_selection/view_selector_prompt_template.py index 5cbfad6f..60440c84 100644 --- a/src/dbally/data_models/prompts/view_selector_prompt_template.py +++ b/src/dbally/view_selection/view_selector_prompt_template.py @@ -1,8 +1,7 @@ import json from typing import Callable, Dict, Optional -from dbally.data_models.prompts.common_validation_utils import _check_prompt_variables -from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate +from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables class ViewSelectorPromptTemplate(PromptTemplate): @@ -17,7 +16,7 @@ def __init__( llm_response_parser: Callable = lambda x: x, ): super().__init__(chat, response_format, llm_response_parser) - self.chat = _check_prompt_variables(chat, {"views"}) + self.chat = check_prompt_variables(chat, {"views"}) def _convert_llm_json_response_to_selected_view(llm_response_json: str) -> str: diff --git a/src/dbally/views/freeform/text2sql/_autodiscovery.py b/src/dbally/views/freeform/text2sql/_autodiscovery.py index 12601e88..b579876d 100644 --- a/src/dbally/views/freeform/text2sql/_autodiscovery.py +++ b/src/dbally/views/freeform/text2sql/_autodiscovery.py @@ -4,8 +4,8 @@ from sqlalchemy.sql.ddl import CreateTable from typing_extensions import Self -from dbally.data_models.prompts import PromptTemplate from dbally.llm_client.base import LLMClient +from dbally.prompts import PromptTemplate from ._config import Text2SQLConfig, Text2SQLTableConfig diff --git a/src/dbally/views/freeform/text2sql/_view.py b/src/dbally/views/freeform/text2sql/_view.py index 1ba6bd94..55ab2ba2 100644 --- a/src/dbally/views/freeform/text2sql/_view.py +++ b/src/dbally/views/freeform/text2sql/_view.py @@ -5,8 +5,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult -from dbally.data_models.prompts import PromptTemplate from dbally.llm_client.base import LLMClient +from dbally.prompts import PromptTemplate from dbally.views.base import BaseView from ._config import Text2SQLConfig diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index d45d99e8..b7a22913 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -7,9 +7,9 @@ from typing import List, Tuple from unittest.mock import create_autospec -from dbally.data_models.prompts.iql_prompt_template import IQLPromptTemplate, default_iql_template from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template from dbally.llm_client.base import LLMClient from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index fabd2f51..cdedeb06 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -7,9 +7,9 @@ from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.data_models.prompts.iql_prompt_template import default_iql_template from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql_generator.iql_prompt_template import default_iql_template from dbally.views.methods_base import MethodsBaseView diff --git a/tests/unit/test_prompt_builder.py b/tests/unit/test_prompt_builder.py index f846f5d6..a4a2fb59 100644 --- a/tests/unit/test_prompt_builder.py +++ b/tests/unit/test_prompt_builder.py @@ -1,8 +1,7 @@ import pytest -from dbally.data_models.prompts.iql_prompt_template import IQLPromptTemplate -from dbally.data_models.prompts.prompt_template import ChatFormat, PromptTemplate, PromptTemplateError -from dbally.prompts.prompt_builder import PromptBuilder +from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate +from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate, PromptTemplateError @pytest.fixture()