diff --git a/docs/how-to/visualize-collection.md b/docs/how-to/visualize-collection.md new file mode 100644 index 00000000..764b8128 --- /dev/null +++ b/docs/how-to/visualize-collection.md @@ -0,0 +1,30 @@ +# How-To: Visualize Collection + +db-ally provides simple way to visualize your collection with [Gradio](https://gradio.app){target="_blank"}. The app allows you to debug your views and query data using different LLMs. + +## Installation + +Install `dbally` with `gradio` extra. + +```bash +pip install dbally["gradio"] +``` + +## Run the app + +Pick the collection created using [`create_collection`][dbally.create_collection] and lunch the gradio interface. + +```python +from dbally.gradio import create_gradio_interface + +gradio_interface = create_gradio_interface(collection) +gradio_interface.launch() +``` +Visit {target="_blank"} to test the collection. + +!!! note + By default, the app will use LLM API key defined in environment variable depending on the LLM provider used. You can override the key in the app. + +## Full Example + +Access the full example on [GitHub](https://github.com/deepsense-ai/db-ally/tree/main/examples/visualize_collection.py){target="_blank"}. diff --git a/docs/how-to/visualize_views.md b/docs/how-to/visualize_views.md deleted file mode 100644 index c9b79791..00000000 --- a/docs/how-to/visualize_views.md +++ /dev/null @@ -1,44 +0,0 @@ -# How-To: Visualize Views - -To create simple UI interface use [create_gradio_interface function](https://github.com/deepsense-ai/db-ally/tree/main/src/dbally/gradio/gradio_interface.py) It allows to display Data Preview related to Views -and execute user queries. - -## Installation -```bash -pip install dbally["gradio"] -``` -When You plan to use some other feature like faiss similarity store install them as well. - -```bash -pip install dbally["faiss"] -``` - -## Create own gradio interface -Define collection with implemented views - -```python -llm = LiteLLM(model_name="gpt-3.5-turbo") -await country_similarity.update() -collection = dbally.create_collection("recruitment", llm) -collection.add(CandidateView, lambda: CandidateView(engine)) -collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) -``` - ->_**NOTE**_: The following code requires environment variables to proceed with LLM queries. For the example below, set the -> ```OPENAI_API_KEY``` environment variable. - -Create gradio interface -```python -gradio_interface = await create_gradio_interface(user_collection=collection) -``` - -Launch the gradio interface. To publish public interface pass argument `share=True` -```python -gradio_interface.launch() -``` - -The endpoint is set by triggering python module with Gradio Adapter launch command. -Private endpoint is set to http://127.0.0.1:7860/ by default. - -## Links -* [Example Gradio Interface](https://github.com/deepsense-ai/db-ally/tree/main/examples/visualize_views_code.py) \ No newline at end of file diff --git a/examples/visualize_views_code.py b/examples/visualize_collection.py similarity index 90% rename from examples/visualize_views_code.py rename to examples/visualize_collection.py index 504f2ddc..ff5ac71e 100644 --- a/examples/visualize_views_code.py +++ b/examples/visualize_collection.py @@ -11,14 +11,17 @@ from dbally.llms.litellm import LiteLLM -async def main(): +async def main() -> None: await country_similarity.update() + llm = LiteLLM(model_name="gpt-3.5-turbo") dbally.event_handlers = [CLIEventHandler(), BufferEventHandler()] + collection = dbally.create_collection("candidates", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine())) - gradio_interface = await create_gradio_interface(user_collection=collection) + + gradio_interface = create_gradio_interface(collection) gradio_interface.launch() diff --git a/examples/visualize_fallback_code.py b/examples/visualize_fallback_code.py index f70f1eed..8fa9eb72 100644 --- a/examples/visualize_fallback_code.py +++ b/examples/visualize_fallback_code.py @@ -28,7 +28,7 @@ async def main(): user_collection.set_fallback(fallback_collection).set_fallback(second_fallback_collection) - gradio_interface = await create_gradio_interface(user_collection=user_collection) + gradio_interface = create_gradio_interface(user_collection) gradio_interface.launch() diff --git a/mkdocs.yml b/mkdocs.yml index f92b8932..e89e5d24 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,7 +36,7 @@ nav: - how-to/use_elastic_store.md - how-to/use_custom_similarity_store.md - how-to/update_similarity_indexes.md - - how-to/visualize_views.md + - how-to/visualize-collection.md - how-to/log_runs_to_langsmith.md - how-to/trace_runs_with_otel.md - how-to/create_custom_event_handler.md diff --git a/src/dbally/gradio/__init__.py b/src/dbally/gradio/__init__.py index 41d84c3c..64b916a0 100644 --- a/src/dbally/gradio/__init__.py +++ b/src/dbally/gradio/__init__.py @@ -1,3 +1,3 @@ -from dbally.gradio.gradio_interface import create_gradio_interface +from dbally.gradio.interface import create_gradio_interface __all__ = ["create_gradio_interface"] diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py deleted file mode 100644 index 4a8de2b4..00000000 --- a/src/dbally/gradio/gradio_interface.py +++ /dev/null @@ -1,301 +0,0 @@ -import json -from typing import Any, Dict, List, Optional, Tuple - -import gradio -import pandas as pd - -import dbally -from dbally import BaseStructuredView -from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler -from dbally.collection import Collection -from dbally.collection.exceptions import NoViewFoundError -from dbally.iql_generator.prompt import UnsupportedQueryError -from dbally.prompt.template import PromptTemplateError - - -async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface: - """Adapt and integrate data collection and query execution with Gradio interface components. - - Args: - user_collection: The user's collection to interact with. - preview_limit: The maximum number of preview data records to display. Default is 10. - - Returns: - The created Gradio interface. - """ - adapter = GradioAdapter() - gradio_interface = await adapter.create_interface(user_collection, preview_limit) - return gradio_interface - - -def find_event_buffer() -> Optional[BufferEventHandler]: - """ - Searches through global event handlers to find an instance of BufferEventHandler. - - This function iterates over the list of global event handlers stored in `dbally.event_handlers`. - It checks the type of each handler, and if it finds one that is an instance of `BufferEventHandler`, it - returns that handler. If no such handler is found, the function returns `None`. - - Returns: - The first instance of `BufferEventHandler` found in the list, or `None` if no such handler is found. - """ - for handler in dbally.event_handlers: - if isinstance(handler, BufferEventHandler): - return handler - return None - - -class GradioAdapter: - """ - A class to adapt and integrate data collection and query execution with Gradio interface components. - """ - - def __init__(self): - """ - Initializes the GradioAdapter with a preview limit. - - """ - self.preview_limit = None - self.selected_view_name = None - self.collection = None - - buffer_event_handler = find_event_buffer() - if not buffer_event_handler: - buffer_event_handler = BufferEventHandler() - dbally.event_handlers.append(buffer_event_handler) - - self.log: BufferEventHandler = buffer_event_handler.buffer # pylint: disable=no-member - - def _load_gradio_data(self, preview_dataframe, label) -> Tuple[gradio.DataFrame, gradio.Label]: - """ - Load data into Gradio components for preview. - - This function takes a DataFrame and a label, and returns a tuple containing a Gradio DataFrame - and a Gradio Label. The visibility of these components is determined by whether the input - DataFrame is empty. - - Args: - preview_dataframe: The DataFrame to be loaded into the Gradio DataFrame component. - label: The label to be associated with the Gradio components. - - Returns: - A tuple containing the Gradio DataFrame component with the provided data and label and A Gradio Label - indicating the availability of data. - """ - if preview_dataframe.empty: - gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=False) - empty_frame_label = gradio.Label(value=f"{label} not available", visible=True, show_label=False) - else: - gradio_preview_dataframe = gradio.DataFrame(label=label, value=preview_dataframe, visible=True) - empty_frame_label = gradio.Label(value=f"{label} not available", visible=False, show_label=False) - return gradio_preview_dataframe, empty_frame_label - - async def _ui_load_preview_data( - self, selected_view_name: str - ) -> Tuple[gradio.DataFrame, gradio.Label, None, None, None]: - """ - Asynchronously loads preview data for a selected view name. - - Args: - selected_view_name: The name of the selected view to load preview data for. - - Returns: - A tuple containing the preview dataframe, load status text, and four None values to clean gradio fields. - """ - self.selected_view_name = selected_view_name - preview_dataframe = self._load_preview_data(selected_view_name) - gradio_preview_dataframe, empty_frame_label = self._load_gradio_data(preview_dataframe, "Preview") - - return gradio_preview_dataframe, empty_frame_label, None, None, None - - def _load_preview_data(self, selected_view_name: str) -> pd.DataFrame: - """ - Loads preview data for a selected view name. - - Args: - selected_view_name: The name of the selected view to load preview data for. - - Returns: - A tuple containing the preview dataframe - """ - selected_view = self.collection.get(selected_view_name) - if issubclass(type(selected_view), BaseStructuredView): - selected_view_results = selected_view.execute() - preview_dataframe = self._load_results_into_dataframe(selected_view_results.results).head( - self.preview_limit - ) - else: - preview_dataframe = pd.DataFrame() - - return preview_dataframe - - async def _ui_ask_query( - self, question_query: str, natural_language_flag: bool - ) -> Tuple[gradio.DataFrame, gradio.Label, gradio.Text, gradio.Text, str]: - """ - Asynchronously processes a query and returns the results. - - Args: - question_query: The query to process. - natural_language_flag: Flag to indicate if the natural language shall be returned - - Returns: - A tuple containing the generated query context, the query results as a dataframe, and the log output. - """ - self.log.seek(0) - self.log.truncate(0) - textual_response = "" - try: - execution_result = await self.collection.ask( - question=question_query, return_natural_response=natural_language_flag - ) - generated_query = str(execution_result.context) - data = self._load_results_into_dataframe(execution_result.results) - textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response - - except UnsupportedQueryError: - generated_query = {"Query": "unsupported"} - data = pd.DataFrame() - except NoViewFoundError: - generated_query = {"Query": "No view matched to query"} - data = pd.DataFrame() - except PromptTemplateError: - generated_query = {"Query": "No view matched to query"} - data = pd.DataFrame() - finally: - self.log.seek(0) - log_content = self.log.read() - - gradio_dataframe, empty_dataframe_warning = self._load_gradio_data(data, "Results") - return ( - gradio_dataframe, - empty_dataframe_warning, - gradio.Text(value=generated_query, visible=True), - gradio.Text(value=textual_response, visible=natural_language_flag), - log_content, - ) - - def _clear_results(self) -> Tuple[gradio.DataFrame, gradio.Label, gradio.Text, gradio.Text]: - preview_dataframe = self._load_preview_data(self.selected_view_name) - gradio_preview_dataframe, empty_frame_label = self._load_gradio_data(preview_dataframe, "Preview") - - return ( - gradio_preview_dataframe, - empty_frame_label, - gradio.Text(visible=False), - gradio.Text(visible=False), - ) - - @staticmethod - def _load_results_into_dataframe(results: List[Dict[str, Any]]) -> pd.DataFrame: - """ - Load the results into a pandas DataFrame. Makes sure that the results are json serializable. - - Args: - results: The results to load into the DataFrame. - - Returns: - The loaded DataFrame. - """ - return pd.DataFrame(json.loads(json.dumps(results, default=str))) - - async def create_interface(self, user_collection: Collection, preview_limit: int) -> gradio.Interface: - """ - Creates a Gradio interface for interacting with the user collection and similarity stores. - - Args: - user_collection: The user's collection to interact with. - preview_limit: The maximum number of preview data records to display. - - Returns: - The created Gradio interface. - """ - - self.preview_limit = preview_limit - self.collection = user_collection - - data_preview_frame = pd.DataFrame() - question_interactive = False - - view_list = [*user_collection.list()] - if view_list: - self.selected_view_name = view_list[0] - data_preview_frame = self._load_preview_data(self.selected_view_name) - question_interactive = True - - with gradio.Blocks() as demo: - with gradio.Row(): - with gradio.Column(): - view_dropdown = gradio.Dropdown( - label="Data View preview", choices=view_list, value=self.selected_view_name - ) - query = gradio.Text(label="Ask question", interactive=question_interactive) - query_button = gradio.Button("Ask db-ally", interactive=question_interactive) - clear_button = gradio.ClearButton(components=[query], interactive=question_interactive) - natural_language_response_checkbox = gradio.Checkbox( - label="Return natural language answer", interactive=question_interactive - ) - - with gradio.Column(): - if not data_preview_frame.empty: - loaded_data_frame = gradio.Dataframe( - label="Preview", value=data_preview_frame, interactive=False - ) - empty_frame_label = gradio.Label(value="Preview not available", visible=False) - else: - loaded_data_frame = gradio.Dataframe(interactive=False, visible=False) - empty_frame_label = gradio.Label(value="Preview not available", visible=True) - - query_sql_result = gradio.Text(label="Generated query context", visible=False) - generated_natural_language_answer = gradio.Text( - label="Generated answer in natural language:", visible=False - ) - - with gradio.Row(): - log_console = gradio.Code(label="Logs", language="shell") - - clear_button.add( - [ - natural_language_response_checkbox, - loaded_data_frame, - query_sql_result, - generated_natural_language_answer, - log_console, - ] - ) - - clear_button.click( - fn=self._clear_results, - inputs=[], - outputs=[ - loaded_data_frame, - empty_frame_label, - query_sql_result, - generated_natural_language_answer, - ], - ) - - view_dropdown.change( - fn=self._ui_load_preview_data, - inputs=view_dropdown, - outputs=[ - loaded_data_frame, - empty_frame_label, - query, - query_sql_result, - log_console, - ], - ) - query_button.click( - fn=self._ui_ask_query, - inputs=[query, natural_language_response_checkbox], - outputs=[ - loaded_data_frame, - empty_frame_label, - query_sql_result, - generated_natural_language_answer, - log_console, - ], - ) - - return demo diff --git a/src/dbally/gradio/interface.py b/src/dbally/gradio/interface.py new file mode 100644 index 00000000..8535d2d8 --- /dev/null +++ b/src/dbally/gradio/interface.py @@ -0,0 +1,362 @@ +import json +from io import StringIO +from typing import Any, Dict, List, Optional, Tuple + +import gradio as gr +import pandas as pd + +import dbally +from dbally import BaseStructuredView +from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler +from dbally.collection import Collection +from dbally.collection.exceptions import NoViewFoundError +from dbally.views.exceptions import ViewExecutionError + + +def create_gradio_interface(collection: Collection, *, preview_limit: Optional[int] = None) -> gr.Interface: + """ + Creates a Gradio interface for interacting with the user collection and similarity stores. + + Args: + collection: The collection to interact with. + preview_limit: The maximum number of preview data records to display. Default is None. + + Returns: + The created Gradio interface. + """ + adapter = GradioAdapter(collection=collection, preview_limit=preview_limit) + return adapter.create_interface() + + +class GradioAdapter: + """ + Gradio adapter for the db-ally lab. + """ + + def __init__(self, collection: Collection, *, preview_limit: Optional[int] = None) -> None: + """ + Creates the gradio adapter. + + Args: + collection: The collection to interact with. + preview_limit: The maximum number of preview data records to display. + """ + self.collection = collection + self.preview_limit = preview_limit + self.log = self._setup_event_buffer() + + def _setup_event_buffer(self) -> StringIO: + """ + Setup the event buffer for the gradio interface. + + Returns: + The buffer event handler. + """ + buffer_event_handler = None + for handler in dbally.event_handlers: + if isinstance(handler, BufferEventHandler): + buffer_event_handler = handler + + if not buffer_event_handler: + buffer_event_handler = BufferEventHandler() + dbally.event_handlers.append(buffer_event_handler) + + return buffer_event_handler.buffer + + def _render_dataframe(self, df: pd.DataFrame, message: Optional[str] = None) -> Tuple[gr.Dataframe, gr.Label]: + """ + Renders the dataframe and label for the gradio interface. + + Args: + df: The dataframe to render. + message: The message to display if the dataframe is empty. + + Returns: + A tuple containing the dataframe and label. + """ + return ( + gr.Dataframe(value=df, visible=not df.empty, height=325), + gr.Label(value=message, visible=df.empty, show_label=False), + ) + + def _render_view_preview(self, view_name: str) -> Tuple[gr.Dataframe, gr.Label]: + """ + Loads preview data for a selected view name. + + Args: + view_name: The name of the selected view to load preview data for. + + Returns: + A tuple containing the preview dataframe, load status text, and four None values to clean gradio fields. + """ + data = pd.DataFrame() + view = self.collection.get(view_name) + + if isinstance(view, BaseStructuredView): + results = view.execute().results + data = self._load_results_into_dataframe(results) + if self.preview_limit is not None: + data = data.head(self.preview_limit) + + return self._render_dataframe(data, "Preview not available") + + async def _ask_collection( + self, + question: str, + model_name: str, + api_key: str, + return_natural_response: bool, + ) -> Tuple[gr.Code, gr.Code, gr.Code, gr.Textbox, gr.Dataframe, gr.Label, str]: + """ + Processes the question and returns the results. + + Args: + question: The question to ask the collection. + return_natural_response: Flag to indicate if the natural language shall be returned. + + Returns: + A tuple containing the generated query context, the query results as a dataframe, and the log output. + """ + self.log.seek(0) + self.log.truncate(0) + + # pylint: disable=protected-access + self.collection._llm.model_name = model_name + if hasattr(self.collection._llm, "api_key"): + self.collection._llm.api_key = api_key + + try: + result = await self.collection.ask( + question=question, + return_natural_response=return_natural_response, + ) + except (NoViewFoundError, ViewExecutionError): + sql = "" + iql_filters = "" + iql_aggregation = "" + retrieved_rows = pd.DataFrame() + textual_response = "" + else: + sql = result.context.get("sql", "") + iql_filters = result.context.get("iql", {}).get("filters", "") + iql_aggregation = result.context.get("iql", {}).get("aggregation", "") + retrieved_rows = self._load_results_into_dataframe(result.results) + textual_response = result.textual_response or "" + + retrieved_rows, empty_retrieved_rows_warning = self._render_dataframe(retrieved_rows, "No rows retrieved") + + self.log.seek(0) + log_content = self.log.read() + + return ( + gr.Code(value=iql_filters, visible=bool(iql_filters)), + gr.Code(value=iql_aggregation, visible=bool(iql_aggregation)), + gr.Code(value=sql, visible=bool(sql)), + gr.Textbox(value=textual_response, visible=return_natural_response), + retrieved_rows, + empty_retrieved_rows_warning, + log_content, + ) + + def _clear_results(self) -> Tuple[gr.Textbox, gr.Code, gr.Code, gr.Code, gr.Dataframe, gr.Label]: + """ + Clears the results from the gradio interface. + + Returns: + A tuple containing the cleared results. + """ + retrieved_rows, retrieved_rows_label = self._render_dataframe(pd.DataFrame(), "No rows retrieved") + return ( + gr.Textbox(visible=False), + gr.Code(visible=False), + gr.Code(visible=False), + gr.Code(visible=False), + retrieved_rows, + retrieved_rows_label, + ) + + @staticmethod + def _load_results_into_dataframe(results: List[Dict[str, Any]]) -> pd.DataFrame: + """ + Load the results into a pandas DataFrame. Makes sure that the results are json serializable. + + Args: + results: The results to load into the DataFrame. + + Returns: + The loaded DataFrame. + """ + return pd.DataFrame(json.loads(json.dumps(results, default=str))) + + def create_interface(self) -> gr.Interface: + """ + Creates a Gradio interface for interacting with the collection. + + Returns: + The Gradio interface. + """ + views = list(self.collection.list()) + selected_view = views[0] if views else None + + with gr.Blocks(title="db-ally lab") as demo: + gr.Markdown("# 🔍 db-ally lab") + + with gr.Tab("Collection"): + with gr.Row(): + with gr.Column(): + api_key = gr.Textbox( + label="API Key", + placeholder="Enter your API Key (optional)", + type="password", + interactive=bool(views), + ) + model_name = gr.Textbox( + label="Model Name", + placeholder="Enter your model name", + value=self.collection._llm.model_name, # pylint: disable=protected-access + interactive=bool(views), + max_lines=1, + ) + question = gr.Textbox( + label="Question", + placeholder="Enter your question", + interactive=bool(views), + max_lines=1, + ) + natural_language_response_checkbox = gr.Checkbox( + label="Use Natural Language Responder", + interactive=bool(views), + ) + ask_button = gr.Button( + value="Ask", + variant="primary", + interactive=bool(views), + ) + clear_button = gr.ClearButton( + value="Reset", + components=[question], + interactive=bool(views), + ) + + with gr.Column(): + view_dropdown = gr.Dropdown( + label="View Preview", + choices=views, + value=selected_view, + interactive=bool(views), + ) + if selected_view: + view_preview, view_preview_label = self._render_view_preview(selected_view) + else: + view_preview, view_preview_label = self._render_dataframe( + pd.DataFrame(), "No view selected" + ) + + with gr.Tab("Results"): + natural_language_response = gr.Textbox( + label="Natural Language Response", + visible=False, + ) + + with gr.Row(): + iql_fitlers_result = gr.Code( + label="IQL Filters Query", + lines=1, + language="python", + visible=False, + ) + iql_aggregation_result = gr.Code( + label="IQL Aggreagation Query", + lines=1, + language="python", + visible=False, + ) + + sql_result = gr.Code( + label="SQL Query", + lines=3, + language="sql", + visible=False, + ) + retrieved_rows = gr.Dataframe( + interactive=False, + height=325, + visible=False, + ) + retrieved_rows_label = gr.Label( + value="No rows retrieved", + visible=True, + show_label=False, + ) + + with gr.Tab("Logs"): + log_console = gr.Code(label="Logs", language="shell") + + with gr.Tab("Help"): + gr.Markdown( + """ + ## How to use this app: + 1. Enter your API Key for the LLM you want to use in the provided field. + 2. Choose the [model](https://docs.litellm.ai/docs/providers) you want to use. + 3. Type your question in the textbox. + 4. Click on `Ask`. The retrieval results will appear in the `Results` tab. + + ## Learn more: + Want to learn more about db-ally? Check out our resources: + - [Website](https://deepsense.ai/db-ally) + - [GitHub](https://github.com/deepsense-ai/db-ally) + - [Documentation](https://db-ally.deepsense.ai) + """ + ) + + clear_button.add( + [ + natural_language_response_checkbox, + natural_language_response, + iql_fitlers_result, + iql_aggregation_result, + sql_result, + retrieved_rows, + retrieved_rows_label, + log_console, + ] + ) + clear_button.click( + fn=self._clear_results, + outputs=[ + natural_language_response, + iql_fitlers_result, + iql_aggregation_result, + sql_result, + retrieved_rows, + retrieved_rows_label, + ], + ) + view_dropdown.change( + fn=self._render_view_preview, + inputs=view_dropdown, + outputs=[ + view_preview, + view_preview_label, + ], + ) + ask_button.click( + fn=self._ask_collection, + inputs=[ + question, + model_name, + api_key, + natural_language_response_checkbox, + ], + outputs=[ + iql_fitlers_result, + iql_aggregation_result, + sql_result, + natural_language_response, + retrieved_rows, + retrieved_rows_label, + log_console, + ], + ) + + return demo diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py index 077474e9..06ab6c45 100644 --- a/src/dbally/llms/litellm.py +++ b/src/dbally/llms/litellm.py @@ -1,4 +1,3 @@ -from functools import cached_property from typing import Optional try: @@ -53,7 +52,7 @@ def __init__( self.api_key = api_key self.api_version = api_version - @cached_property + @property def client(self) -> LiteLLMClient: """ Client for the LLM.