diff --git a/CONFIG.ts b/CONFIG.ts index 51fa4833..472374ea 100644 --- a/CONFIG.ts +++ b/CONFIG.ts @@ -11,7 +11,10 @@ const config = { // When opening the viewer, or refreshing the page, the viewer will revert to the following default dataset data:{ // Default dataset URL (must be publically accessible) - default_dataset: "https://public.czbiohub.org/royerlab/zoo/Zebrafish/tracks_zebrafish_bundle.zarr/" + // default_dataset: "https://public.czbiohub.org/royerlab/zoo/Zebrafish/tracks_zebrafish_bundle.zarr/" + // default_dataset: "https://public.czbiohub.org/royerlab/zoo/misc/tracks_ascidian_withSize_withFeatures_bundle.zarr/" + // default_dataset: "https://public.czbiohub.org/royerlab/zoo/misc/tracks_zebrafish_smooth_displ_bundle.zarr/" + default_dataset: "https://public.czbiohub.org/royerlab/zoo/misc/tracks_drosophila_smooth_displ_bundle.zarr/" }, // Default settings for certain parameters @@ -19,8 +22,15 @@ const config = { // Maximum number of cells a user can select without getting a warning max_num_selected_cells: 100, - // Choose colormap for the tracks, options: viridis-inferno, magma-inferno, inferno-inferno, plasma-inferno, cividis-inferno [default] + // Choose colormap for the tracks + // options: viridis-inferno, magma-inferno, inferno-inferno, plasma-inferno, cividis-inferno [default] colormap_tracks: "cividis-inferno", + + // Choose colormap for coloring the cells, when the attribute is continuous or categorical + // options: HSL, viridis, plasma, inferno, magma, cividis + colormap_colorby_categorical: "HSL", + colormap_colorby_continuous: "plasma", + // Point size (arbitrary units), if cell sizes not provided in zarr attributes point_size: 0.1, diff --git a/python/src/intracktive/_tests/test_convert.py b/python/src/intracktive/_tests/test_convert.py index 7a936aa4..dfe4a11d 100644 --- a/python/src/intracktive/_tests/test_convert.py +++ b/python/src/intracktive/_tests/test_convert.py @@ -43,7 +43,8 @@ def test_actual_zarr_content(tmp_path: Path, make_sample_data: pd.DataFrame) -> convert_dataframe_to_zarr( df=df, zarr_path=new_path, - extra_cols=["radius"], + add_radius=True, + extra_cols=(), ) new_data = zarr.open(new_path) diff --git a/python/src/intracktive/convert.py b/python/src/intracktive/convert.py index d17976f7..c452a5c7 100644 --- a/python/src/intracktive/convert.py +++ b/python/src/intracktive/convert.py @@ -85,7 +85,9 @@ def get_unique_zarr_path(zarr_path: Path) -> Path: def convert_dataframe_to_zarr( df: pd.DataFrame, zarr_path: Path, + add_radius: bool = False, extra_cols: Iterable[str] = (), + pre_normalized: bool = False, ) -> Path: """ Convert a DataFrame of tracks to a sparse Zarr store @@ -113,11 +115,18 @@ def convert_dataframe_to_zarr( flag_2D = True df["z"] = 0.0 + points_cols = ( + ["z", "y", "x", "radius"] if add_radius else ["z", "y", "x"] + ) # columns to store in the points array extra_cols = list(extra_cols) - columns = REQUIRED_COLUMNS + extra_cols - points_cols = ["z", "y", "x"] + extra_cols # columns to store in the points array - - for col in columns: + columns_to_check = ( + REQUIRED_COLUMNS + ["radius"] if add_radius else REQUIRED_COLUMNS + ) # columns to check for in the DataFrame + columns_to_check = columns_to_check + extra_cols + print("point_cols:", points_cols) + print("columns_to_check:", columns_to_check) + + for col in columns_to_check: if col not in df.columns: raise ValueError(f"Column '{col}' not found in the DataFrame") @@ -144,7 +153,7 @@ def convert_dataframe_to_zarr( n_tracklets = df["track_id"].nunique() # (z, y, x) + extra_cols - num_values_per_point = 3 + len(extra_cols) + num_values_per_point = 4 if add_radius else 3 # store the points in an array points_array = ( @@ -154,6 +163,14 @@ def convert_dataframe_to_zarr( ) * INF_SPACE ) + attribute_array_empty = ( + np.ones( + (n_time_points, max_values_per_time_point), + dtype=np.float32, + ) + * INF_SPACE + ) + attribute_arrays = {} points_to_tracks = lil_matrix( (n_time_points * max_values_per_time_point, n_tracklets), dtype=np.int32 @@ -165,10 +182,18 @@ def convert_dataframe_to_zarr( points_array[t, : group_size * num_values_per_point] = ( group[points_cols].to_numpy().ravel() ) + points_ids = t * max_values_per_time_point + np.arange(group_size) points_to_tracks[points_ids, group["track_id"] - 1] = 1 + for col in extra_cols: + attribute_array = attribute_array_empty.copy() + for t, group in df.groupby("t"): + group_size = len(group) + attribute_array[t, :group_size] = group[col].to_numpy().ravel() + attribute_arrays[col] = attribute_array + LOG.info(f"Munged {len(df)} points in {time.monotonic() - start} seconds") # creating mapping of tracklets parent-child relationship @@ -233,16 +258,31 @@ def convert_dataframe_to_zarr( chunks=(1, points_array.shape[1]), dtype=np.float32, ) + print("points shape:", points.shape) points.attrs["values_per_point"] = num_values_per_point + if len(extra_cols) > 0: + attributes_matrix = np.hstack( + [attribute_arrays[attr] for attr in attribute_arrays] + ) + attributes = top_level_group.create_dataset( + "attributes", + data=attributes_matrix, + chunks=(1, attribute_array.shape[1]), + dtype=np.float32, + ) + attributes.attrs["columns"] = extra_cols + attributes.attrs["pre_normalized"] = pre_normalized + mean = df[["z", "y", "x"]].mean() extent = (df[["z", "y", "x"]] - mean).abs().max() extent_xyz = extent.max() for col in ("z", "y", "x"): points.attrs[f"mean_{col}"] = mean[col] + points.attrs["extent_xyz"] = extent_xyz - points.attrs["fields"] = ["z", "y", "x"] + extra_cols + points.attrs["fields"] = points_cols points.attrs["ndim"] = 2 if flag_2D else 3 top_level_group.create_groups( @@ -355,10 +395,26 @@ def dataframe_to_browser(df: pd.DataFrame, zarr_dir: Path) -> None: default=False, type=bool, ) +@click.option( + "--add_attributes", + is_flag=True, + help="Boolean indicating whether to include extra columns of the CSV as attributes for colors the cells in the viewer", + default=False, + type=bool, +) +@click.option( + "--pre_normalized", + is_flag=True, + help="Boolean indicating whether the extra columns with attributes are prenormalized to [0,1]", + default=False, + type=bool, +) def convert_cli( csv_file: Path, out_dir: Path | None, add_radius: bool, + add_attributes: bool, + pre_normalized: bool, ) -> None: """ Convert a CSV of tracks to a sparse Zarr store @@ -372,16 +428,22 @@ def convert_cli( zarr_path = out_dir / f"{csv_file.stem}_bundle.zarr" - extra_cols = ["radius"] if add_radius else [] - tracks_df = pd.read_csv(csv_file) + extra_cols = [] + if add_attributes: + columns_standard = REQUIRED_COLUMNS + extra_cols = tracks_df.columns.difference(columns_standard).to_list() + print("extra_cols:", extra_cols) + LOG.info(f"Read {len(tracks_df)} points in {time.monotonic() - start} seconds") convert_dataframe_to_zarr( tracks_df, zarr_path, + add_radius, extra_cols=extra_cols, + pre_normalized=pre_normalized, ) LOG.info(f"Full conversion took {time.monotonic() - start} seconds") diff --git a/src/components/App.tsx b/src/components/App.tsx index 46d82a9d..f036abf3 100644 --- a/src/components/App.tsx +++ b/src/components/App.tsx @@ -16,8 +16,9 @@ import { TrackManager, loadTrackManager } from "@/lib/TrackManager"; import { PointSelectionMode } from "@/lib/PointSelector"; import LeftSidebarWrapper from "./leftSidebar/LeftSidebarWrapper"; // import { TimestampOverlay } from "./overlays/TimestampOverlay"; -import { ColorMap } from "./overlays/ColorMap"; +import { ColorMapTracks, ColorMapCells } from "./overlays/ColorMap.tsx"; import { TrackDownloadData } from "./DownloadButton"; +import { numberOfDefaultColorByOptions } from "@/components/leftSidebar/DynamicDropdown.tsx"; import config from "../../CONFIG.ts"; const brandingName = config.branding.name || undefined; @@ -222,6 +223,7 @@ export default function App() { const getPoints = async (time: number) => { console.debug("fetch points at time %d", time); const data = await trackManager.fetchPointsAtTime(time); + // console.log('data shape:', data.length, 'attributes shape:', attributes.length); console.debug("got %d points for time %d", data.length / 3, time); if (ignore) { @@ -229,10 +231,18 @@ export default function App() { return; } + let attributes; + if (canvas.colorByEvent.action === "provided" || canvas.colorByEvent.action === "provided-normalized") { + attributes = await trackManager.fetchAttributessAtTime( + time, + canvas.colorByEvent.label - numberOfDefaultColorByOptions, + ); + } + // clearing the timeout prevents the loading indicator from showing at all if the fetch is fast clearTimeout(loadingTimeout); setIsLoadingPoints(false); - dispatchCanvas({ type: ActionType.POINTS_POSITIONS, positions: data }); + dispatchCanvas({ type: ActionType.POINTS_POSITIONS, positions: data, attributes: attributes }); }; getPoints(canvas.curTime); } else { @@ -250,7 +260,7 @@ export default function App() { clearTimeout(loadingTimeout); ignore = true; }; - }, [canvas.curTime, dispatchCanvas, trackManager]); + }, [canvas.curTime, canvas.colorByEvent, dispatchCanvas, trackManager]); // This fetches track IDs based on the selected point IDs. useEffect(() => { @@ -459,6 +469,13 @@ export default function App() { toggleAxesVisible={() => { dispatchCanvas({ type: ActionType.TOGGLE_AXES }); }} + colorBy={canvas.colorBy} + toggleColorBy={(colorBy: boolean) => { + dispatchCanvas({ type: ActionType.TOGGLE_COLOR_BY, colorBy }); + }} + changeColorBy={(event: string) => { + dispatchCanvas({ type: ActionType.CHANGE_COLOR_BY, event }); + }} /> @@ -501,7 +518,8 @@ export default function App() { > 0} /> {/* */} - + {numSelectedCells > 0 && } + {canvas.colorByEvent.type !== "default" && } {/* The playback controls */} diff --git a/src/components/DataControls.tsx b/src/components/DataControls.tsx index 7a6f167a..c9b7136b 100644 --- a/src/components/DataControls.tsx +++ b/src/components/DataControls.tsx @@ -149,8 +149,8 @@ export default function DataControls(props: DataControlsProps) {