Skip to content

Commit

Permalink
feat: Implement native zarrita zarr loader
Browse files Browse the repository at this point in the history
  • Loading branch information
manzt committed Jul 15, 2024
1 parent 7a34f6b commit 4c6dcce
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 69 deletions.
131 changes: 131 additions & 0 deletions src/ZarrPixelSource.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import * as zarr from 'zarrita';
import { DTYPE_VALUES } from '@vivjs/constants';

import type * as viv from '@vivjs/types';
import type { Readable } from '@zarrita/storage';

import { assert } from './utils';
import { getImageSize } from '@hms-dbmi/viv';

// TODO: Export from top-level zarrita
type Slice = ReturnType<typeof zarr.slice>;

const xKey = 'x';
const yKey = 'y';
const rgbaChannelKey = '_c';

export class ZarrPixelSource<S extends Array<string> = Array<string>> implements viv.PixelSource<S> {
#arr: zarr.Array<zarr.DataType, Readable>;
readonly labels: viv.Labels<S>;
readonly tileSize: number;
readonly dtype: viv.SupportedDtype;

constructor(
arr: zarr.Array<zarr.DataType, Readable>,
options: {
labels: viv.Labels<S>;
tileSize: number;
}
) {
this.#arr = arr;
this.labels = options.labels;
this.tileSize = options.tileSize;
const vivDtype = capitalize(arr.dtype);
assert(isSupportedDtype(vivDtype), `Unsupported dtype: ${vivDtype}`);
this.dtype = vivDtype;
}

get shape() {
return this.#arr.shape;
}

async getRaster(options: {
selection: viv.PixelSourceSelection<S> | Array<number>;
signal?: AbortSignal;
}): Promise<viv.PixelData> {
const { selection, signal } = options;
return this.#fetchData(buildZarrQuery(this.labels, selection), { signal });
}

async getTile(options: {
x: number;
y: number;
selection: viv.PixelSourceSelection<S> | Array<number>;
signal?: AbortSignal;
}): Promise<viv.PixelData> {
const { x, y, selection, signal } = options;
const sel = buildZarrQuery(this.labels, selection);

const { height, width } = getImageSize(this);
const [xStart, xStop] = [x * this.tileSize, Math.min((x + 1) * this.tileSize, width)];
const [yStart, yStop] = [y * this.tileSize, Math.min((y + 1) * this.tileSize, height)];

// Deck.gl can sometimes request edge tiles that don't exist. We throw
// a BoundsCheckError which is picked up in `ZarrPixelSource.onTileError`
// and ignored by deck.gl.
if (xStart === xStop || yStart === yStop) {
throw new BoundsCheckError('Tile slice is zero-sized.');
}
if (xStart < 0 || yStart < 0 || xStop > width || yStop > height) {
throw new BoundsCheckError('Tile slice is out of bounds.');
}

sel[this.labels.indexOf(xKey)] = zarr.slice(xStart, xStop);
sel[this.labels.indexOf(yKey)] = zarr.slice(yStart, yStop);
return this.#fetchData(sel, { signal });
}

onTileError(err: Error): void {
if (err instanceof BoundsCheckError) {
return;
}
throw err;
}

async #fetchData(selection: Array<number | Slice>, options: { signal?: AbortSignal }): Promise<viv.PixelData> {
const {
data,
shape: [height, width],
} = await zarr.get(this.#arr, selection, {
// @ts-expect-error this is ok for now and should be supported by all backends
signal: options.signal,
});
return { data: data as viv.SupportedTypedArray, width, height };
}
}

function buildZarrQuery(labels: string[], selection: Record<string, number> | Array<number>): Array<Slice | number> {
let sel: Array<Slice | number>;
if (Array.isArray(selection)) {
// shallow copy
sel = [...selection];
} else {
// initialize with zeros
sel = Array.from({ length: labels.length }, () => 0);
// fill in the selection
for (const [key, idx] of Object.entries(selection)) {
sel[labels.indexOf(key)] = idx;
}
}
sel[labels.indexOf(xKey)] = zarr.slice(null);
sel[labels.indexOf(xKey)] = zarr.slice(null);
if (rgbaChannelKey in labels) {
sel[labels.indexOf(rgbaChannelKey)] = zarr.slice(null);
}
return sel;
}

function capitalize(s: string) {
return s[0].toUpperCase() + s.slice(1);
}

function isSupportedDtype(dtype: string): dtype is viv.SupportedDtype {
return dtype in DTYPE_VALUES;
}

class BoundsCheckError extends Error {
name = 'BoundsCheckError';
constructor(message?: string) {
super(message);
}
}
3 changes: 2 additions & 1 deletion src/gridLayer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import { SolidPolygonLayer, TextLayer } from '@deck.gl/layers';
import type { CompositeLayerProps } from '@deck.gl/core/lib/composite-layer';
import pMap from 'p-map';

import { XRLayer, ZarrPixelSource, ColorPaletteExtension } from '@hms-dbmi/viv';
import { XRLayer, ColorPaletteExtension } from '@hms-dbmi/viv';
import type { BaseLayerProps } from './state';
import type { ZarrPixelSource } from './ZarrPixelSource';
import { assert } from './utils';

export interface GridLoader {
Expand Down
11 changes: 7 additions & 4 deletions src/io.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { ImageLayer, MultiscaleImageLayer, ZarrPixelSource } from '@hms-dbmi/viv';
import { ImageLayer, MultiscaleImageLayer } from '@hms-dbmi/viv';
import * as zarr from '@zarrita/core';
import type { Readable } from '@zarrita/storage';
import GridLayer from './gridLayer';
import { loadOmeroMultiscales, loadPlate, loadWell } from './ome';
import { ZarrPixelSource } from './ZarrPixelSource';
import type { ImageLayerConfig, LayerState, MultichannelConfig, SingleChannelConfig, SourceData } from './state';
import {
COLORS,
Expand All @@ -20,12 +21,14 @@ import {
range,
calcDataRange,
calcConstrastLimits,
createZarrArrayAdapter,
resolveAttrs,
assert,
} from './utils';

async function loadSingleChannel(config: SingleChannelConfig, data: ZarrPixelSource<string[]>[]): Promise<SourceData> {
async function loadSingleChannel(
config: SingleChannelConfig,
data: Array<ZarrPixelSource<string[]>>
): Promise<SourceData> {
const { color, contrast_limits, visibility, name, colormap = '', opacity = 1 } = config;
const lowres = data[data.length - 1];
const selection = Array(data[0].shape.length).fill(0);
Expand Down Expand Up @@ -156,7 +159,7 @@ export async function createSourceData(config: ImageLayerConfig): Promise<Source
const { channel_axis, labels } = getAxisLabelsAndChannelAxis(config, axes, data[0]);

const tileSize = guessTileSize(data[0]);
const loader = data.map((d) => new ZarrPixelSource(createZarrArrayAdapter(d), labels, tileSize));
const loader = data.map((d) => new ZarrPixelSource(d, { labels, tileSize }));
const [base] = loader;

// If explicit channel axis is provided, try to load as multichannel.
Expand Down
9 changes: 4 additions & 5 deletions src/ome.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import { ZarrPixelSource } from '@hms-dbmi/viv';
import pMap from 'p-map';
import * as zarr from '@zarrita/core';
import type { Readable } from '@zarrita/storage';
import type { ImageLayerConfig, SourceData } from './state';
import {
assert,
calcConstrastLimits,
createZarrArrayAdapter,
getAttrsOnly,
getDefaultColors,
getDefaultVisibilities,
Expand All @@ -19,6 +17,7 @@ import {
range,
resolveAttrs,
} from './utils';
import { ZarrPixelSource } from './ZarrPixelSource';

export async function loadWell(
config: ImageLayerConfig,
Expand Down Expand Up @@ -79,7 +78,7 @@ export async function loadWell(
name: String(offset),
row,
col,
loader: new ZarrPixelSource(createZarrArrayAdapter(data[offset]), axis_labels, tileSize),
loader: new ZarrPixelSource(data[offset], { labels: axis_labels, tileSize }),
};
});
});
Expand Down Expand Up @@ -190,7 +189,7 @@ export async function loadPlate(
name: `${row}${col}`,
row: rows.indexOf(row),
col: columns.indexOf(col),
loader: new ZarrPixelSource(createZarrArrayAdapter(d[1]), axis_labels, tileSize),
loader: new ZarrPixelSource(d[1], { labels: axis_labels, tileSize }),
};
});
let meta;
Expand Down Expand Up @@ -250,7 +249,7 @@ export async function loadOmeroMultiscales(
const meta = parseOmeroMeta(attrs.omero, axes);
const tileSize = guessTileSize(data[0]);

const loader = data.map((arr) => new ZarrPixelSource(createZarrArrayAdapter(arr), axis_labels, tileSize));
const loader = data.map((arr) => new ZarrPixelSource(arr, { labels: axis_labels, tileSize }));
return {
loader: loader,
axis_labels,
Expand Down
5 changes: 3 additions & 2 deletions src/state.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import type { ImageLayer, MultiscaleImageLayer, ZarrPixelSource } from '@hms-dbmi/viv';
import type { ImageLayer, MultiscaleImageLayer } from '@hms-dbmi/viv';
import type { Matrix4 } from 'math.gl';
import type { PrimitiveAtom, WritableAtom } from 'jotai';
import { atom } from 'jotai';
import { atomFamily, splitAtom, waitForAll } from 'jotai/utils';

import type { Readable } from '@zarrita/storage';
import type { default as GridLayer, GridLayerProps, GridLoader } from './gridLayer';
import type { ZarrPixelSource } from './ZarrPixelSource';
import { initLayerStateFromSource } from './io';

export interface ViewState {
Expand Down Expand Up @@ -57,7 +58,7 @@ export interface SingleChannelConfig extends BaseConfig {
export type ImageLayerConfig = MultichannelConfig | SingleChannelConfig;

export type SourceData = {
loader: ZarrPixelSource<string[]>[];
loader: ZarrPixelSource[];
loaders?: GridLoader[]; // for OME plates
rows?: number;
columns?: number;
Expand Down
61 changes: 4 additions & 57 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import type { ZarrPixelSource } from '@hms-dbmi/viv';
import type { Slice } from '@zarrita/indexing';
import * as zarr from '@zarrita/core';
import { slice, get } from '@zarrita/indexing';
import { FetchStore, Readable } from '@zarrita/storage';
import { Matrix4 } from 'math.gl';

import type { ZarrPixelSource } from './ZarrPixelSource';
import { lru } from './lru-store';
import type { ViewState } from './state';

Expand Down Expand Up @@ -268,9 +266,9 @@ export function parseMatrix(model_matrix?: string | number[]): Matrix4 {
return matrix;
}

export async function calcDataRange<S extends string[]>(
source: ZarrPixelSource<S>,
selection: number[]
export async function calcDataRange(
source: ZarrPixelSource,
selection: Array<number>
): Promise<[min: number, max: number]> {
if (source.dtype === 'Uint8') return [0, 255];
const { data } = await source.getRaster({ selection });
Expand Down Expand Up @@ -353,57 +351,6 @@ export function typedEmitter<T>() {
};
}

function getV2DataType(dtype: string) {
const mapping: Record<string, string> = {
int8: '|i1',
uint8: '|u1',
int16: '<i2',
uint16: '<u2',
int32: '<i4',
uint32: '<u4',
int64: '<i8',
uint64: '<u8',
float32: '<f4',
float64: '<f8',
};
assert(dtype in mapping, `Unsupported dtype ${dtype}`);
return mapping[dtype];
}

type Selection = (number | Omit<Slice, 'indices'> | null)[];

/**
* This is needed by @hms-dbmi/viv to get raw data from a Zarr.js style interface.
*/
export function createZarrArrayAdapter(arr: zarr.Array<zarr.DataType>): any {
return new Proxy(arr, {
get(target, prop) {
if (prop === 'getRaw') {
return (selection: Selection) => {
return get(
target,
selection.map((s) => {
if (typeof s === 'object' && s !== null) {
return slice(s.start, s.stop, s.step);
}
return s;
})
);
};
}
if (prop === 'getRawChunk') {
return (selection: number[], options: { storeOptions: RequestInit }) => {
return target.getChunk(selection, options.storeOptions);
};
}
if (prop === 'dtype') {
return getV2DataType(target.dtype);
}
return Reflect.get(target, prop);
},
});
}

/**
* Extracts the OME metadata from the zarr attributes
*
Expand Down

0 comments on commit 4c6dcce

Please sign in to comment.