Skip to content

Commit

Permalink
Fix logic for detecting changes in gr.Dataframe table value (#10360)
Browse files Browse the repository at this point in the history
* changes

* changes

* changes

* add changeset

* revert

* add changeset

* change

* changes

* add changeset

* add changeset

* allow non-string headers

* add changeset

* changes

* docs

* remove

* changes

* changes

* revert

* refactoring

* changes

* changes

* revert

* revert

* add changeset

* fix

* add changeset

* clean up

* cleanup

* more cleanup

* notebook

* test

* format

* add changeset

* backend

* format'

* add changeset

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Jan 17, 2025
1 parent 40e0c48 commit 31cccc3
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 92 deletions.
6 changes: 6 additions & 0 deletions .changeset/social-jokes-rest.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/dataframe": patch
"gradio": patch
---

fix:Fix logic for detecting changes in `gr.Dataframe` table value
1 change: 1 addition & 0 deletions demo/dataframe_streaming/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: dataframe_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import pandas as pd\n", "import time\n", "\n", "def update_dataframe(df):\n", " df.iloc[:, :] = 1\n", " yield df, 1\n", " time.sleep(0.1)\n", " df.iloc[:, :] = 2\n", " yield df, 2\n", "\n", "initial_df = pd.DataFrame(0, index=range(5), columns=range(5))\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " button = gr.Button(\"Update DataFrame\")\n", " number = gr.Number(value=0, label=\"Number\")\n", " dataframe = gr.Dataframe(value=initial_df, label=\"Dataframe\")\n", " button.click(fn=update_dataframe, inputs=dataframe, outputs=[dataframe, number])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
22 changes: 22 additions & 0 deletions demo/dataframe_streaming/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import gradio as gr
import pandas as pd
import time

def update_dataframe(df):
df.iloc[:, :] = 1
yield df, 1
time.sleep(0.1)
df.iloc[:, :] = 2
yield df, 2

initial_df = pd.DataFrame(0, index=range(5), columns=range(5))

with gr.Blocks() as demo:
with gr.Row():
button = gr.Button("Update DataFrame")
number = gr.Number(value=0, label="Number")
dataframe = gr.Dataframe(value=initial_df, label="Dataframe")
button.click(fn=update_dataframe, inputs=dataframe, outputs=[dataframe, number])

if __name__ == "__main__":
demo.launch()
50 changes: 40 additions & 10 deletions gradio/components/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,35 @@ def preprocess(
+ ". Please choose from: 'pandas', 'numpy', 'array', 'polars'."
)

@staticmethod
def _is_empty(
value: pd.DataFrame
| Styler
| np.ndarray
| pl.DataFrame
| list
| list[list]
| dict
| str
| None,
) -> bool:
import pandas as pd
from pandas.io.formats.style import Styler

if isinstance(value, pd.DataFrame):
return value.empty
elif isinstance(value, Styler):
return value.data.empty # type: ignore
elif isinstance(value, np.ndarray):
return value.size == 0
elif _is_polars_available() and isinstance(value, _import_polars().DataFrame):
return value.is_empty()
elif isinstance(value, list) and len(value) and isinstance(value[0], list):
return len(value[0]) == 0
elif isinstance(value, (list, dict)):
return len(value) == 0
return False

def postprocess(
self,
value: pd.DataFrame
Expand All @@ -241,12 +270,19 @@ def postprocess(
Parameters:
value: Expects data any of these formats: `pandas.DataFrame`, `pandas.Styler`, `numpy.array`, `polars.DataFrame`, `list[list]`, `list`, or a `dict` with keys 'data' (and optionally 'headers'), or `str` path to a csv, which is rendered as the spreadsheet.
Returns:
the uploaded spreadsheet data as an object with `headers` and `data` attributes
the uploaded spreadsheet data as an object with `headers` and `data` keys and optional `metadata` key
"""
import pandas as pd
from pandas.io.formats.style import Styler

if value is None:
if isinstance(value, Styler) and semantic_version.Version(
pd.__version__
) < semantic_version.Version("1.5.0"):
raise ValueError(
"Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature."
)

if value is None or self._is_empty(value):
return self.postprocess(self.empty_input)
if isinstance(value, dict):
if len(value) == 0:
Expand All @@ -259,20 +295,14 @@ def postprocess(
value = pd.read_csv(value) # type: ignore
if len(value) == 0:
return DataframeData(
headers=list(value.columns), # type: ignore
headers=[str(col) for col in value.columns], # Convert to strings
data=[[]], # type: ignore
)
return DataframeData(
headers=list(value.columns), # type: ignore
headers=[str(col) for col in value.columns], # Convert to strings
data=value.to_dict(orient="split")["data"], # type: ignore
)
elif isinstance(value, Styler):
if semantic_version.Version(pd.__version__) < semantic_version.Version(
"1.5.0"
):
raise ValueError(
"Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature."
)
if self.interactive:
warnings.warn(
"Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead."
Expand Down
78 changes: 11 additions & 67 deletions js/dataframe/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -46,75 +46,18 @@
display: boolean;
}[];
export let max_height: number | undefined = undefined;
export let loading_status: LoadingStatus;
export let interactive: boolean;
let _headers: Headers;
let display_value: string[][] | null;
let styling: string[][] | null;
let values: (string | number)[][];
async function handle_change(data?: {
data: Data;
headers: Headers;
metadata: Metadata;
}): Promise<void> {
let _data = data || value;
_headers = [...(_data.headers || headers)];
values = _data.data ? [..._data.data] : [];
display_value = _data?.metadata?.display_value
? [..._data?.metadata?.display_value]
$: _headers = [...(value.headers || headers)];
$: cell_values = value.data ? [...value.data] : [];
$: display_value = value?.metadata?.display_value
? [...value?.metadata?.display_value]
: null;
$: styling =
!interactive && value?.metadata?.styling
? [...value?.metadata?.styling]
: null;
styling =
!interactive && _data?.metadata?.styling
? [..._data?.metadata?.styling]
: null;
await tick();
gradio.dispatch("change");
if (!value_is_output) {
gradio.dispatch("input");
}
}
handle_change();
afterUpdate(() => {
value_is_output = false;
});
$: {
if (old_value && JSON.stringify(value) !== old_value) {
old_value = JSON.stringify(value);
handle_change();
}
}
if (
(Array.isArray(value) && value?.[0]?.length === 0) ||
value.data?.[0]?.length === 0
) {
value = {
data: [Array(col_count?.[0] || 3).fill("")],
headers: Array(col_count?.[0] || 3)
.fill("")
.map((_, i) => `${i + 1}`),
metadata: null
};
}
async function handle_value_change(data: {
data: Data;
headers: Headers;
metadata: Metadata;
}): Promise<void> {
if (JSON.stringify(data) !== old_value) {
value = { ...data };
old_value = JSON.stringify(value);
handle_change(data);
}
}
</script>

<Block
Expand All @@ -139,11 +82,11 @@
{show_label}
{row_count}
{col_count}
{values}
values={cell_values}
{display_value}
{styling}
headers={_headers}
on:change={(e) => handle_value_change(e.detail)}
on:change={(e) => gradio.dispatch("change")}
on:select={(e) => gradio.dispatch("select", e.detail)}
{wrap}
{datatype}
Expand All @@ -155,5 +98,6 @@
{column_widths}
upload={(...args) => gradio.client.upload(...args)}
stream_handler={(...args) => gradio.client.stream(...args)}
bind:value_is_output
/>
</Block>
27 changes: 13 additions & 14 deletions js/dataframe/shared/Table.svelte
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import { createEventDispatcher, tick, onMount } from "svelte";
import { afterUpdate, createEventDispatcher, tick, onMount } from "svelte";
import { dsvFormat } from "d3-dsv";
import { dequal } from "dequal/lite";
import { copy } from "@gradio/utils";
Expand Down Expand Up @@ -36,6 +36,7 @@
export let column_widths: string[] = [];
export let upload: Client["upload"];
export let stream_handler: Client["stream"];
export let value_is_output = false;
let selected: false | [number, number] = false;
let clicked_cell: { row: number; col: number } | undefined = undefined;
Expand All @@ -44,11 +45,8 @@
let t_rect: DOMRectReadOnly;
const dispatch = createEventDispatcher<{
change: {
data: (string | number)[][];
headers: string[];
metadata: Metadata;
};
change: undefined;
input: undefined;
select: SelectData;
}>();
Expand Down Expand Up @@ -160,21 +158,18 @@
$: if (!dequal(values, old_val)) {
data = process_data(values as (string | number)[][]);
old_val = values as (string | number)[][];
old_val = JSON.parse(JSON.stringify(values)) as (string | number)[][];
}
let data: { id: string; value: string | number }[][] = [[]];
let old_val: undefined | (string | number)[][] = undefined;
async function trigger_change(): Promise<void> {
dispatch("change", {
data: data.map((r) => r.map(({ value }) => value)),
headers: _headers.map((h) => h.value),
metadata: editable
? null
: { display_value: display_value, styling: styling }
});
dispatch("change");
if (!value_is_output) {
dispatch("input");
}
}
function get_sort_status(
Expand Down Expand Up @@ -754,6 +749,10 @@
}
}
}
afterUpdate(() => {
value_is_output = false;
});
</script>

<svelte:window on:resize={() => set_cell_widths()} />
Expand Down
8 changes: 8 additions & 0 deletions js/spa/test/dataframe_streaming.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { test, expect } from "@self/tootils";

test("DataFrame updates correctly when button is clicked", async ({ page }) => {
await page.getByRole("button", { name: "Update DataFrame" }).click();
await expect(
page.getByRole("table", { name: "Dataframe" }).locator("td").first()
).toHaveText("2");
});
6 changes: 5 additions & 1 deletion test/components/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def test_postprocess(self):
"""
dataframe_output = gr.Dataframe()
output = dataframe_output.postprocess([]).model_dump()
assert output == {"data": [[]], "headers": ["1", "2", "3"], "metadata": None}
assert output == {
"data": [["", "", ""]],
"headers": ["1", "2", "3"],
"metadata": None,
}
output = dataframe_output.postprocess(np.zeros((2, 2))).model_dump()
assert output == {
"data": [[0, 0], [0, 0]],
Expand Down

0 comments on commit 31cccc3

Please sign in to comment.