Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat(gradio): better preview for structured views #93

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ langsmith=
elasticsearch =
elasticsearch~=8.13.1
gradio =
gradio~=4.31.5
gradio_client~=0.16.4
gradio~=4.42.0
gradio_client~=1.3.0
local =
accelerate~=0.31.0
torch~=2.2.1
Expand Down
211 changes: 148 additions & 63 deletions src/dbally/gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,40 @@
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.views.exceptions import ViewExecutionError


def create_gradio_interface(collection: Collection, *, preview_limit: Optional[int] = None) -> gr.Interface:
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping


def create_gradio_interface(
collection: Collection,
*,
title: str = "db-ally lab",
header: str = "🔍 db-ally lab",
examples: Optional[List[str]] = None,
examples_per_page: int = 4,
preview_limit: Optional[int] = None,
) -> gr.Interface:
"""
Creates a Gradio interface for interacting with the user collection and similarity stores.

Args:
collection: The collection to interact with.
title: The title of the gradio interface.
header: The header of the gradio interface.
examples: The example questions to display.
examples_per_page: The number of examples to display per page.
preview_limit: The maximum number of preview data records to display. Default is None.

Returns:
The created Gradio interface.
"""
adapter = GradioAdapter(collection=collection, preview_limit=preview_limit)
adapter = GradioAdapter(
collection=collection,
title=title,
header=header,
examples=examples,
examples_per_page=examples_per_page,
preview_limit=preview_limit,
)
return adapter.create_interface()


Expand All @@ -33,16 +53,33 @@ class GradioAdapter:
Gradio adapter for the db-ally lab.
"""

def __init__(self, collection: Collection, *, preview_limit: Optional[int] = None) -> None:
def __init__(
self,
collection: Collection,
*,
title: str = "db-ally lab",
header: str = "🔍 db-ally lab",
examples: Optional[List[str]] = None,
examples_per_page: int = 4,
preview_limit: Optional[int] = None,
) -> None:
"""
Creates the gradio adapter.

Args:
collection: The collection to interact with.
title: The title of the gradio interface.
header: The header of the gradio interface.
examples: The example questions to display.
examples_per_page: The number of examples to display per page.
preview_limit: The maximum number of preview data records to display.
"""
self.collection = collection
self.preview_limit = preview_limit
self.title = title
self.header = header
self.examples = examples or []
self.examples_per_page = examples_per_page
self.log = self._setup_event_buffer()

def _setup_event_buffer(self) -> StringIO:
Expand Down Expand Up @@ -79,27 +116,6 @@ def _render_dataframe(self, df: pd.DataFrame, message: Optional[str] = None) ->
gr.Label(value=message, visible=df.empty, show_label=False),
)

def _render_view_preview(self, view_name: str) -> Tuple[gr.Dataframe, gr.Label]:
"""
Loads preview data for a selected view name.

Args:
view_name: The name of the selected view to load preview data for.

Returns:
A tuple containing the preview dataframe, load status text, and four None values to clean gradio fields.
"""
data = pd.DataFrame()
view = self.collection.get(view_name)

if isinstance(view, BaseStructuredView):
results = view.execute().results
data = self._load_results_into_dataframe(results)
if self.preview_limit is not None:
data = data.head(self.preview_limit)

return self._render_dataframe(data, "Preview not available")

async def _ask_collection(
self,
question: str,
Expand Down Expand Up @@ -130,13 +146,15 @@ async def _ask_collection(
question=question,
return_natural_response=return_natural_response,
)
except (NoViewFoundError, ViewExecutionError):
except (NoViewFoundError, ViewExecutionError) as e:
view_name = e.view_name
sql = ""
iql_filters = ""
iql_aggregation = ""
retrieved_rows = pd.DataFrame()
textual_response = ""
else:
view_name = result.view_name
sql = result.context.get("sql", "")
iql_filters = result.context.get("iql", {}).get("filters", "")
iql_aggregation = result.context.get("iql", {}).get("aggregation", "")
Expand All @@ -149,10 +167,11 @@ async def _ask_collection(
log_content = self.log.read()

return (
gr.Textbox(value=textual_response, visible=return_natural_response),
gr.Textbox(value=view_name, visible=True),
gr.Code(value=iql_filters, visible=bool(iql_filters)),
gr.Code(value=iql_aggregation, visible=bool(iql_aggregation)),
gr.Code(value=sql, visible=bool(sql)),
gr.Textbox(value=textual_response, visible=return_natural_response),
retrieved_rows,
empty_retrieved_rows_warning,
log_content,
Expand Down Expand Up @@ -188,6 +207,36 @@ def _load_results_into_dataframe(results: List[Dict[str, Any]]) -> pd.DataFrame:
"""
return pd.DataFrame(json.loads(json.dumps(results, default=str)))

def _render_param(self, param: MethodParamWithTyping) -> str:
if param.similarity_index:
return f"{param.name}: {str(param.type).replace('typing.', '')}"
return str(param)

def _render_tab_data(self, data: pd.DataFrame) -> None:
    """
    Renders the "Data" tab showing a preview of the view's results.

    Args:
        data: The preview records to display; may be empty.
    """
    with gr.Tab("Data"):
        if not data.empty:
            gr.Dataframe(value=data, height=320)
        else:
            # No rows to show — display a placeholder label instead of an empty grid.
            gr.Label("No data available", show_label=False)

def _render_tab_iql(self, methods: List[ExposedFunction], label: str) -> None:
    """
    Renders a tab listing the IQL methods (filters or aggregations) of a view.

    Args:
        methods: The exposed functions to display.
        label: The tab label suffix, e.g. "Filters" or "Aggregations".
    """
    with gr.Tab(f"IQL {label}"):
        if not methods:
            # Nothing exposed by this view — show a placeholder label.
            gr.Label(f"No {label.lower()} available", show_label=False)
            return

        # One row per method: rendered call signature plus its description.
        rows = []
        for method in methods:
            params = ", ".join(self._render_param(param) for param in method.parameters)
            rows.append([f"{method.name}({params})", method.description])

        gr.Dataframe(
            value=rows,
            headers=["signature", "description"],
            interactive=False,
            height=325,
        )

def create_interface(self) -> gr.Interface:
"""
Creates a Gradio interface for interacting with the collection.
Expand All @@ -198,8 +247,8 @@ def create_interface(self) -> gr.Interface:
views = list(self.collection.list())
selected_view = views[0] if views else None

with gr.Blocks(title="db-ally lab") as demo:
gr.Markdown("# 🔍 db-ally lab")
with gr.Blocks(title=self.title) as demo:
gr.Markdown(f"# {self.header}")

with gr.Tab("Collection"):
with gr.Row():
Expand All @@ -208,55 +257,96 @@ def create_interface(self) -> gr.Interface:
label="API Key",
placeholder="Enter your API Key",
type="password",
interactive=bool(views),
interactive=bool(selected_view),
)
model_name = gr.Textbox(
label="Model Name",
placeholder="Enter your model name",
value=self.collection._llm.model_name, # pylint: disable=protected-access
interactive=bool(views),
interactive=bool(selected_view),
max_lines=1,
)
question = gr.Textbox(
label="Question",
placeholder="Enter your question",
interactive=bool(views),
interactive=bool(selected_view),
max_lines=1,
)
natural_language_response_checkbox = gr.Checkbox(
label="Use Natural Language Responder",
interactive=bool(views),
)
ask_button = gr.Button(
value="Ask",
variant="primary",
interactive=bool(views),
interactive=bool(selected_view),
)
clear_button = gr.ClearButton(
value="Reset",
components=[question],
interactive=bool(views),

if self.examples and selected_view:
gr.Examples(
label="Example questions",
examples=self.examples,
inputs=question,
examples_per_page=self.examples_per_page,
)

with gr.Row():
clear_button = gr.ClearButton(
value="Reset",
components=[question],
interactive=bool(selected_view),
)
ask_button = gr.Button(
value="Ask",
variant="primary",
interactive=bool(selected_view),
)

gr.HTML(
"""
<div style="text-align: end; font-weight: bold;">
POWERED BY <a href="https://github.com/deepsense-ai/db-ally" target="_blank">DB-ALLY</a>
</div>
"""
)

with gr.Column():
view_dropdown = gr.Dropdown(
label="View Preview",
choices=views,
value=selected_view,
interactive=bool(views),
interactive=bool(selected_view),
)
if selected_view:
view_preview, view_preview_label = self._render_view_preview(selected_view)
else:
view_preview, view_preview_label = self._render_dataframe(
pd.DataFrame(), "No view selected"
)

@gr.render(inputs=view_dropdown, triggers=[demo.load, view_dropdown.change])
def render_view_preview(view_name: Optional[str]) -> None:
if view_name is None:
gr.Label("No views", show_label=False)
return

view = self.collection.get(view_name)

if not isinstance(view, BaseStructuredView):
gr.Label(value="Preview not available", show_label=False)
return

result = view.execute()
data = self._load_results_into_dataframe(result.results)
if self.preview_limit is not None:
data = data.head(self.preview_limit)

filters = view.list_filters()
aggregations = view.list_aggregations()

self._render_tab_data(data)
self._render_tab_iql(filters, "Filters")
self._render_tab_iql(aggregations, "Aggregations")

with gr.Tab("Results"):
natural_language_response = gr.Textbox(
label="Natural Language Response",
visible=False,
)
selected_view_name = gr.Textbox(
label="Selected View",
visible=False,
max_lines=1,
)

with gr.Row():
iql_fitlers_result = gr.Code(
Expand All @@ -266,7 +356,7 @@ def create_interface(self) -> gr.Interface:
visible=False,
)
iql_aggregation_result = gr.Code(
label="IQL Aggreagation Query",
label="IQL Aggregation Query",
lines=1,
language="python",
visible=False,
Expand All @@ -290,18 +380,18 @@ def create_interface(self) -> gr.Interface:
)

with gr.Tab("Logs"):
log_console = gr.Code(label="Logs", language="shell")
log_console = gr.Code(language="shell", show_label=False)

with gr.Tab("Help"):
gr.Markdown(
"""
## How to use this app:
## How to use this app
1. Enter your API Key for the LLM you want to use in the provided field.
2. Choose the [model](https://docs.litellm.ai/docs/providers) you want to use.
3. Type your question in the textbox.
4. Click on `Ask`. The retrieval results will appear in the `Results` tab.

## Learn more:
## Learn more
Want to learn more about db-ally? Check out our resources:
- [Website](https://deepsense.ai/db-ally)
- [GitHub](https://github.com/deepsense-ai/db-ally)
Expand All @@ -313,6 +403,7 @@ def create_interface(self) -> gr.Interface:
[
natural_language_response_checkbox,
natural_language_response,
selected_view_name,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
Expand All @@ -325,21 +416,14 @@ def create_interface(self) -> gr.Interface:
fn=self._clear_results,
outputs=[
natural_language_response,
selected_view_name,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
retrieved_rows,
retrieved_rows_label,
],
)
view_dropdown.change(
fn=self._render_view_preview,
inputs=view_dropdown,
outputs=[
view_preview,
view_preview_label,
],
)
ask_button.click(
fn=self._ask_collection,
inputs=[
Expand All @@ -349,10 +433,11 @@ def create_interface(self) -> gr.Interface:
natural_language_response_checkbox,
],
outputs=[
natural_language_response,
selected_view_name,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
natural_language_response,
retrieved_rows,
retrieved_rows_label,
log_console,
Expand Down
Loading