Skip to content

Commit

Permalink
add iql preview
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Sep 10, 2024
1 parent f8d817a commit da3238f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 40 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ langsmith=
elasticsearch =
elasticsearch~=8.13.1
gradio =
gradio~=4.31.5
gradio_client~=0.16.4
gradio~=4.42.0
gradio_client~=1.3.0
local =
accelerate~=0.31.0
torch~=2.2.1
Expand Down
107 changes: 69 additions & 38 deletions src/dbally/gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.views.exceptions import ViewExecutionError
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping


def create_gradio_interface(
Expand Down Expand Up @@ -115,27 +116,6 @@ def _render_dataframe(self, df: pd.DataFrame, message: Optional[str] = None) ->
gr.Label(value=message, visible=df.empty, show_label=False),
)

def _render_view_preview(self, view_name: str) -> Tuple[gr.Dataframe, gr.Label]:
"""
Loads preview data for a selected view name.
Args:
view_name: The name of the selected view to load preview data for.
Returns:
A tuple containing the preview dataframe, load status text, and four None values to clean gradio fields.
"""
data = pd.DataFrame()
view = self.collection.get(view_name)

if isinstance(view, BaseStructuredView):
result = view.execute()
data = self._load_results_into_dataframe(result.results)
if self.preview_limit is not None:
data = data.head(self.preview_limit)

return self._render_dataframe(data, "Preview not available")

async def _ask_collection(
self,
question: str,
Expand Down Expand Up @@ -166,13 +146,15 @@ async def _ask_collection(
question=question,
return_natural_response=return_natural_response,
)
except (NoViewFoundError, ViewExecutionError):
except (NoViewFoundError, ViewExecutionError) as e:
view_name = e.view_name
sql = ""
iql_filters = ""
iql_aggregation = ""
retrieved_rows = pd.DataFrame()
textual_response = ""
else:
view_name = result.view_name
sql = result.context.get("sql", "")
iql_filters = result.context.get("iql", {}).get("filters", "")
iql_aggregation = result.context.get("iql", {}).get("aggregation", "")
Expand All @@ -185,10 +167,11 @@ async def _ask_collection(
log_content = self.log.read()

return (
gr.Textbox(value=textual_response, visible=return_natural_response),
gr.Textbox(value=view_name, visible=True),
gr.Code(value=iql_filters, visible=bool(iql_filters)),
gr.Code(value=iql_aggregation, visible=bool(iql_aggregation)),
gr.Code(value=sql, visible=bool(sql)),
gr.Textbox(value=textual_response, visible=return_natural_response),
retrieved_rows,
empty_retrieved_rows_warning,
log_content,
Expand Down Expand Up @@ -224,6 +207,36 @@ def _load_results_into_dataframe(results: List[Dict[str, Any]]) -> pd.DataFrame:
"""
return pd.DataFrame(json.loads(json.dumps(results, default=str)))

def _render_param(self, param: MethodParamWithTyping) -> str:
if param.similarity_index:
return f"{param.name}: {str(param.type).replace('typing.', '')}"
return str(param)

def _render_tab_data(self, data: pd.DataFrame) -> None:
with gr.Tab("Data"):
if data.empty:
gr.Label("No data available", show_label=False)
else:
gr.Dataframe(value=data, height=320)

def _render_tab_iql(self, methods: List[ExposedFunction], label: str) -> None:
with gr.Tab(f"IQL {label}"):
if methods:
gr.Dataframe(
value=[
[
f"{method.name}({', '.join(self._render_param(param) for param in method.parameters)})",
method.description,
]
for method in methods
],
headers=["signature", "description"],
interactive=False,
height=325,
)
else:
gr.Label(f"No {label.lower()} available", show_label=False)

def create_interface(self) -> gr.Interface:
"""
Creates a Gradio interface for interacting with the collection.
Expand Down Expand Up @@ -299,18 +312,41 @@ def create_interface(self) -> gr.Interface:
value=selected_view,
interactive=bool(selected_view),
)
if selected_view:
view_preview, view_preview_label = self._render_view_preview(selected_view)
else:
view_preview, view_preview_label = self._render_dataframe(
pd.DataFrame(), "No view selected"
)

@gr.render(inputs=view_dropdown, triggers=[demo.load, view_dropdown.change])
def render_view_preview(view_name: Optional[str]) -> None:
if view_name is None:
gr.Label("No views", show_label=False)
return

view = self.collection.get(view_name)

if not isinstance(view, BaseStructuredView):
gr.Label(value="Preview not available", show_label=False)
return

result = view.execute()
data = self._load_results_into_dataframe(result.results)
if self.preview_limit is not None:
data = data.head(self.preview_limit)

filters = view.list_filters()
aggregations = view.list_aggregations()

self._render_tab_data(data)
self._render_tab_iql(filters, "Filters")
self._render_tab_iql(aggregations, "Aggregations")

with gr.Tab("Results"):
natural_language_response = gr.Textbox(
label="Natural Language Response",
visible=False,
)
selected_view_name = gr.Textbox(
label="Selected View",
visible=False,
max_lines=1,
)

with gr.Row():
iql_fitlers_result = gr.Code(
Expand Down Expand Up @@ -367,6 +403,7 @@ def create_interface(self) -> gr.Interface:
[
natural_language_response_checkbox,
natural_language_response,
selected_view_name,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
Expand All @@ -379,21 +416,14 @@ def create_interface(self) -> gr.Interface:
fn=self._clear_results,
outputs=[
natural_language_response,
selected_view_name,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
retrieved_rows,
retrieved_rows_label,
],
)
view_dropdown.change(
fn=self._render_view_preview,
inputs=view_dropdown,
outputs=[
view_preview,
view_preview_label,
],
)
ask_button.click(
fn=self._ask_collection,
inputs=[
Expand All @@ -403,10 +433,11 @@ def create_interface(self) -> gr.Interface:
natural_language_response_checkbox,
],
outputs=[
natural_language_response,
selected_view_name,
iql_fitlers_result,
iql_aggregation_result,
sql_result,
natural_language_response,
retrieved_rows,
retrieved_rows_label,
log_console,
Expand Down

0 comments on commit da3238f

Please sign in to comment.