-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from marimo-team/aka/high-dimensional-visualiz…
…ation example: explore high-dimensional data
- Loading branch information
Showing
4 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Explore high dimensional data | ||
|
||
[](https://marimo.app/github.com/marimo-team/examples/blob/main/explore_high_dimensional_data/explore_high_dimensional_data.py) | ||
|
||
**This template lets you visualize and interactively explore high dimensional | ||
data.** The starter code uses PCA to embed and plot numerical digits, seeing | ||
how they cluster together — when you select points in the plot, the notebook | ||
shows you the underlying images! | ||
|
||
To use this notebook on your own data, just replace the implementations | ||
of the following four functions: | ||
|
||
* `load_data` | ||
* `embed_data` | ||
* `scatter_data` | ||
* `show_selection` | ||
|
||
<img src="https://raw.githubusercontent.com/marimo-team/marimo/main/docs/_static/embedding.gif" width="700px" /> | ||
|
||
## Running this notebook | ||
|
||
Open this notebook in [our online | ||
playground](https://marimo.app/github.com/marimo-team/examples/blob/main/explore_high_dimensional_data/explore_high_dimensional_data.py) | ||
or run it locally. | ||
|
||
### Running locally | ||
|
||
The requirements of each notebook are serialized in them as a top-level | ||
comment. Here are the steps to run the notebook: | ||
|
||
1. [Install `uv`](https://github.com/astral-sh/uv/?tab=readme-ov-file#installation) | ||
2. Open an example with `uvx marimo edit --sandbox <notebook-url>` | ||
|
||
> [!TIP] | ||
> The [`--sandbox` | ||
> flag](https://docs.marimo.io/guides/package_reproducibility/) opens the | ||
> notebook in an isolated virtual environment, automatically installing the | ||
> notebook's dependencies 📦 | ||
You can also open notebooks without `uv`, in which case you'll need to | ||
manually [install marimo](https://docs.marimo.io/getting_started/index.html#installation) | ||
first. Then run `marimo edit <notebook-url>`; however, you'll also need to | ||
install the requirements yourself. |
226 changes: 226 additions & 0 deletions
226
explore_high_dimensional_data/explore_high_dimensional_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
# /// script | ||
# requires-python = ">=3.12" | ||
# dependencies = [ | ||
# "altair==5.5.0", | ||
# "marimo", | ||
# "matplotlib==3.10.0", | ||
# "pandas==2.2.3", | ||
# "polars==1.20.0", | ||
# "scikit-learn==1.6.1", | ||
# ] | ||
# /// | ||
|
||
import marimo | ||
|
||
__generated_with = "0.10.16" | ||
app = marimo.App(width="columns") | ||
|
||
|
||
@app.cell(column=0, hide_code=True) | ||
def _(mo): | ||
mo.md( | ||
""" | ||
**This template lets you visualize and interactively explore high dimensional data.** The starter code uses PCA to embed and plot numerical | ||
digits, seeing how they cluster together — when you select points in the plot, the notebook shows you the underlying images. | ||
The left-hand column implements the core logic; the right-hand column executes the dimensionality reduction and shows the outputs. | ||
**To customize this template to your own data, just implement the functions in | ||
this column.** | ||
""" | ||
) | ||
return | ||
|
||
|
||
@app.cell | ||
def _(): | ||
def load_data(): | ||
""" | ||
Return a tuple of your data: | ||
* The dataset, with each row a different item in the dataset | ||
* A label for each data point. | ||
If your data doesn't have labels, just return a list of all ones | ||
with the same length as the number of items in your dataset. | ||
""" | ||
import sklearn.datasets | ||
|
||
data, labels = sklearn.datasets.load_digits(return_X_y=True) | ||
return data, labels | ||
|
||
return (load_data,) | ||
|
||
|
||
@app.cell | ||
def _(): | ||
def embed_data(data): | ||
""" | ||
Embed the data into two dimensions. The default implementation | ||
uses PCA, but you can also use UMAP, tSNE, or any other dimensionality | ||
reduction algorithm you like. | ||
The starter implementation here uses PCA, and assumes the data is a NumPy | ||
array. | ||
""" | ||
import sklearn | ||
|
||
return sklearn.decomposition.PCA(n_components=2, whiten=True).fit_transform( | ||
data | ||
) | ||
|
||
return (embed_data,) | ||
|
||
|
||
@app.cell | ||
def _(pl): | ||
def scatter_data(df: pl.DataFrame) -> alt.Chart: | ||
""" | ||
Visualize the embedded data using an Altair scatterplot. | ||
- df is a Polars dataframe with the following columns: | ||
* x: the first coordinate of the embedding | ||
* y: the second coordinate of the embedding | ||
* label: a label identifying each item, for coloring | ||
Modify the starter implementation to suit your needs, but make sure | ||
to return an altair chart. | ||
""" | ||
import altair as alt | ||
|
||
return ( | ||
alt.Chart(df) | ||
.mark_circle() | ||
.encode( | ||
x=alt.X("x:Q").scale(domain=(-2.5, 2.5)), | ||
y=alt.Y("y:Q").scale(domain=(-2.5, 2.5)), | ||
color=alt.Color("label:N"), | ||
) | ||
.properties(width=500, height=500) | ||
) | ||
|
||
return (scatter_data,) | ||
|
||
|
||
@app.cell | ||
def _(): | ||
def show_selection(data, rows, max_rows=10): | ||
""" | ||
Visualize selected rows of the data. | ||
- `data` is the data returned from `load_data` | ||
- `rows` is a list or array of row indices | ||
- `max_rows` is the maximum number of rows to display | ||
""" | ||
import matplotlib.pyplot as plt | ||
|
||
# show 10 images: either the first 10 from the selection, or the first ten | ||
# selected in the table | ||
rows = rows[:max_rows] | ||
images = data.reshape((-1, 8, 8))[rows] | ||
fig, axes = plt.subplots(1, len(rows)) | ||
fig.set_size_inches(12.5, 1.5) | ||
if len(rows) > 1: | ||
for im, ax in zip(images, axes.flat): | ||
ax.imshow(im, cmap="gray") | ||
ax.set_yticks([]) | ||
ax.set_xticks([]) | ||
else: | ||
axes.imshow(images[0], cmap="gray") | ||
axes.set_yticks([]) | ||
axes.set_xticks([]) | ||
plt.tight_layout() | ||
return fig | ||
|
||
return (show_selection,) | ||
|
||
|
||
@app.cell(column=1, hide_code=True) | ||
def _(mo): | ||
mo.md("""# Explore high dimensional data""") | ||
return | ||
|
||
|
||
@app.cell(hide_code=True) | ||
def _(mo): | ||
mo.md( | ||
""" | ||
Here's an **embedding** of your data, with similar points close to each other. | ||
This notebook will automatically drill down into points you **select with | ||
your mouse**; try it! | ||
""" | ||
) | ||
return | ||
|
||
|
||
@app.cell | ||
def _(load_data): | ||
data, labels = load_data() | ||
return data, labels | ||
|
||
|
||
@app.cell | ||
def _(): | ||
import polars as pl | ||
|
||
return (pl,) | ||
|
||
|
||
@app.cell | ||
def _(data, embed_data, labels, pl): | ||
X_embedded = embed_data(data) | ||
|
||
embedding = pl.DataFrame( | ||
{ | ||
"x": X_embedded[:, 0], | ||
"y": X_embedded[:, 1], | ||
"label": labels, | ||
"index": list(range(X_embedded.shape[0])), | ||
} | ||
) | ||
return X_embedded, embedding | ||
|
||
|
||
@app.cell | ||
def _(embedding, mo, scatter_data): | ||
chart = mo.ui.altair_chart(scatter_data(embedding)) | ||
chart | ||
return (chart,) | ||
|
||
|
||
@app.cell | ||
def _(chart, mo): | ||
table = mo.ui.table(chart.value) | ||
return (table,) | ||
|
||
|
||
@app.cell | ||
def _(chart, data, mo, show_selection, table): | ||
mo.stop(not len(chart.value)) | ||
|
||
selected_rows = show_selection(data, list(chart.value["index"])) | ||
|
||
mo.md( | ||
f""" | ||
**Here's a preview of the items you've selected**: | ||
{mo.as_html(selected_rows)} | ||
Here's all the data you've selected. | ||
{table} | ||
""" | ||
) | ||
return (selected_rows,) | ||
|
||
|
||
@app.cell | ||
def _(): | ||
import marimo as mo | ||
|
||
return (mo,) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters