diff --git a/app/packages/looker/src/lookers/abstract.ts b/app/packages/looker/src/lookers/abstract.ts index 7d18e8bddb..3eca774de8 100644 --- a/app/packages/looker/src/lookers/abstract.ts +++ b/app/packages/looker/src/lookers/abstract.ts @@ -293,6 +293,11 @@ export abstract class AbstractLooker< return; } + if (this.state.destroyed && this.sampleOverlays) { + // close all current overlays + this.pluckedOverlays.forEach((overlay) => overlay.cleanup?.()); + } + if ( !this.state.windowBBox || this.state.destroyed || diff --git a/app/packages/looker/src/overlays/base.ts b/app/packages/looker/src/overlays/base.ts index fd817ecf9d..a3ec867766 100644 --- a/app/packages/looker/src/overlays/base.ts +++ b/app/packages/looker/src/overlays/base.ts @@ -3,6 +3,7 @@ */ import { getCls, sizeBytesEstimate } from "@fiftyone/utilities"; +import { OverlayMask } from "../numpy"; import type { BaseState, Coordinates, NONFINITE } from "../state"; import { getLabelColor, shouldShowLabelTag } from "./util"; @@ -39,6 +40,11 @@ export interface SelectData { frameNumber?: number; } +export type LabelMask = { + bitmap?: ImageBitmap; + data?: OverlayMask; +}; + export interface RegularLabel extends BaseLabel { _id?: string; label?: string; @@ -67,6 +73,7 @@ export interface Overlay> { getPoints(state: Readonly): Coordinates[]; getSelectData(state: Readonly): SelectData; getSizeBytes(): number; + cleanup?(): void; } export abstract class CoordinateOverlay< diff --git a/app/packages/looker/src/overlays/detection.ts b/app/packages/looker/src/overlays/detection.ts index 4930771692..d5c62f3f1b 100644 --- a/app/packages/looker/src/overlays/detection.ts +++ b/app/packages/looker/src/overlays/detection.ts @@ -4,17 +4,19 @@ import { NONFINITES } from "@fiftyone/utilities"; import { INFO_COLOR } from "../constants"; -import { OverlayMask } from "../numpy"; import { BaseState, BoundingBox, Coordinates, NONFINITE } from "../state"; import { distanceFromLineSegment } from "../util"; -import { CONTAINS, CoordinateOverlay, PointInfo, RegularLabel } from "./base"; +import { + CONTAINS, + CoordinateOverlay, + LabelMask, + PointInfo, + RegularLabel, +} from "./base"; import { t } from "./util"; export interface DetectionLabel extends RegularLabel { - mask?: { - data: OverlayMask; - image: ArrayBuffer; - }; + mask?: LabelMask; bounding_box: BoundingBox; // valid for 3D bounding boxes @@ -27,10 +29,8 @@ export interface DetectionLabel extends RegularLabel { export default class DetectionOverlay< State extends BaseState > extends CoordinateOverlay { - private imageData: ImageData; private is3D: boolean; private labelBoundingBox: BoundingBox; - private canvas: HTMLCanvasElement; constructor(field, label) { super(field, label); @@ -40,32 +40,6 @@ export default class DetectionOverlay< } else { this.is3D = false; } - - if (this.label.mask) { - const [height, width] = this.label.mask.data.shape; - - if (!height || !width) { - return; - } - - this.canvas = document.createElement("canvas"); - this.canvas.width = width; - this.canvas.height = height; - this.imageData = new ImageData( - new Uint8ClampedArray(this.label.mask.image), - width, - height - ); - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.mask.data.shape[1], - this.label.mask.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); - } } containsPoint(state: Readonly): CONTAINS { @@ -169,7 +143,7 @@ export default class DetectionOverlay< } private drawMask(ctx: CanvasRenderingContext2D, state: Readonly) { - if (!this.canvas) { + if (!this.label.mask?.bitmap) { return; } @@ -177,8 +151,9 @@ export default class DetectionOverlay< const [x, y] = t(state, tlx, tly); const tmp = ctx.globalAlpha; ctx.globalAlpha = state.options.alpha; + ctx.imageSmoothingEnabled = false; ctx.drawImage( - this.canvas, + this.label.mask.bitmap, x, y, w * state.canvasBBox[2], @@ -285,6 +260,13 @@ export default class DetectionOverlay< const oh = state.strokeWidth / state.canvasBBox[3]; return [(bx - ow) * w, (by - oh) * h, (bw + ow * 2) * w, (bh + oh * 2) * h]; } + + public cleanup(): void { + if (this.label.mask?.bitmap) { + this.label.mask?.bitmap.close(); + console.log(">>>cleanup"); + } + } } export const getDetectionPoints = (labels: DetectionLabel[]): Coordinates[] => { diff --git a/app/packages/looker/src/overlays/heatmap.ts b/app/packages/looker/src/overlays/heatmap.ts index c53a3ad971..e8e8817643 100644 --- a/app/packages/looker/src/overlays/heatmap.ts +++ b/app/packages/looker/src/overlays/heatmap.ts @@ -6,14 +6,16 @@ import { getColor, getRGBA, getRGBAColor, + sizeBytesEstimate, } from "@fiftyone/utilities"; -import { ARRAY_TYPES, OverlayMask, TypedArray } from "../numpy"; +import { ARRAY_TYPES, TypedArray } from "../numpy"; import { BaseState, Coordinates } from "../state"; import { isFloatArray } from "../util"; import { clampedIndex } from "../worker/painter"; import { BaseLabel, CONTAINS, + LabelMask, Overlay, PointInfo, SelectData, @@ -21,13 +23,8 @@ import { } from "./base"; import { strokeCanvasRect, t } from "./util"; -interface HeatMap { - data: OverlayMask; - image: ArrayBuffer; -} - interface HeatmapLabel extends BaseLabel { - map?: HeatMap; + map?: LabelMask; range?: [number, number]; } @@ -45,8 +42,6 @@ export default class HeatmapOverlay private label: HeatmapLabel; private targets?: TypedArray; private readonly range: [number, number]; - private canvas: HTMLCanvasElement; - private imageData: ImageData; constructor(field: string, label: HeatmapLabel) { this.field = field; @@ -68,25 +63,6 @@ export default class HeatmapOverlay if (!width || !height) { return; } - - this.canvas = document.createElement("canvas"); - this.canvas.width = width; - this.canvas.height = height; - - this.imageData = new ImageData( - new Uint8ClampedArray(this.label.map.image), - width, - height - ); - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.map.data.shape[1], - this.label.map.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); } containsPoint(state: Readonly): CONTAINS { @@ -101,22 +77,12 @@ export default class HeatmapOverlay } draw(ctx: CanvasRenderingContext2D, state: Readonly): void { - if (this.imageData) { - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.map.data.shape[1], - this.label.map.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); - + if (this.label.map?.bitmap) { const [tlx, tly] = t(state, 0, 0); const [brx, bry] = t(state, 1, 1); const tmp = ctx.globalAlpha; ctx.globalAlpha = state.options.alpha; - ctx.drawImage(this.canvas, tlx, tly, brx - tlx, bry - tly); + ctx.drawImage(this.label.map.bitmap, tlx, tly, brx - tlx, bry - tly); ctx.globalAlpha = tmp; } @@ -235,6 +201,16 @@ export default class HeatmapOverlay return this.targets[index]; } + + getSizeBytes(): number { + return sizeBytesEstimate(this.label); + } + + public cleanup(): void { + if (this.label.map?.bitmap) { + this.label.map?.bitmap.close(); + } + } } export const getHeatmapPoints = (labels: HeatmapLabel[]): Coordinates[] => { diff --git a/app/packages/looker/src/overlays/segmentation.ts b/app/packages/looker/src/overlays/segmentation.ts index c55a8b5ef5..a4cb098254 100644 --- a/app/packages/looker/src/overlays/segmentation.ts +++ b/app/packages/looker/src/overlays/segmentation.ts @@ -2,12 +2,13 @@ * Copyright 2017-2024, Voxel51, Inc. */ -import { getColor } from "@fiftyone/utilities"; -import { ARRAY_TYPES, OverlayMask, TypedArray } from "../numpy"; +import { getColor, sizeBytesEstimate } from "@fiftyone/utilities"; +import { ARRAY_TYPES, TypedArray } from "../numpy"; import { BaseState, Coordinates, MaskTargets } from "../state"; import { BaseLabel, CONTAINS, + LabelMask, Overlay, PointInfo, SelectData, @@ -16,10 +17,7 @@ import { import { isRgbMaskTargets, strokeCanvasRect, t } from "./util"; interface SegmentationLabel extends BaseLabel { - mask?: { - data: OverlayMask; - image: ArrayBuffer; - }; + mask?: LabelMask; } interface SegmentationInfo extends BaseLabel { @@ -34,8 +32,6 @@ export default class SegmentationOverlay readonly field: string; private label: SegmentationLabel; private targets?: TypedArray; - private canvas: HTMLCanvasElement; - private imageData: ImageData; private isRgbMaskTargets = false; @@ -53,6 +49,7 @@ export default class SegmentationOverlay if (!this.label.mask) { return; } + const [height, width] = this.label.mask.data.shape; if (!height || !width) { @@ -62,25 +59,6 @@ export default class SegmentationOverlay this.targets = new ARRAY_TYPES[this.label.mask.data.arrayType]( this.label.mask.data.buffer ); - - this.canvas = document.createElement("canvas"); - this.canvas.width = width; - this.canvas.height = height; - - this.imageData = new ImageData( - new Uint8ClampedArray(this.label.mask.image), - width, - height - ); - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.mask.data.shape[1], - this.label.mask.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); } containsPoint(state: Readonly): CONTAINS { @@ -99,12 +77,12 @@ export default class SegmentationOverlay return; } - if (this.imageData) { + if (this.label.mask?.bitmap) { const [tlx, tly] = t(state, 0, 0); const [brx, bry] = t(state, 1, 1); const tmp = ctx.globalAlpha; ctx.globalAlpha = state.options.alpha; - ctx.drawImage(this.canvas, tlx, tly, brx - tlx, bry - tly); + ctx.drawImage(this.label.mask.bitmap, tlx, tly, brx - tlx, bry - tly); ctx.globalAlpha = tmp; } @@ -278,6 +256,16 @@ export default class SegmentationOverlay } return this.targets[index]; } + + getSizeBytes(): number { + return sizeBytesEstimate(this.label); + } + + public cleanup(): void { + if (this.label.mask?.bitmap) { + this.label.mask?.bitmap.close(); + } + } } export const getSegmentationPoints = ( diff --git a/app/packages/looker/src/worker/decorated-fetch.test.ts b/app/packages/looker/src/worker/decorated-fetch.test.ts index 67ed853200..3a9a15e1e7 100644 --- a/app/packages/looker/src/worker/decorated-fetch.test.ts +++ b/app/packages/looker/src/worker/decorated-fetch.test.ts @@ -15,7 +15,7 @@ describe("fetchWithLinearBackoff", () => { expect(response).toBe(mockResponse); expect(global.fetch).toHaveBeenCalledTimes(1); - expect(global.fetch).toHaveBeenCalledWith("http://fiftyone.ai"); + expect(global.fetch).toHaveBeenCalledWith("http://fiftyone.ai", {}); }); it("should retry when fetch fails and eventually succeed", async () => { @@ -35,7 +35,7 @@ describe("fetchWithLinearBackoff", () => { global.fetch = vi.fn().mockRejectedValue(new Error("Network Error")); await expect( - fetchWithLinearBackoff("http://fiftyone.ai", 3, 10) + fetchWithLinearBackoff("http://fiftyone.ai", {}, 3, 10) ).rejects.toThrowError(new RegExp("Max retries for fetch reached")); expect(global.fetch).toHaveBeenCalledTimes(3); @@ -46,7 +46,7 @@ describe("fetchWithLinearBackoff", () => { global.fetch = vi.fn().mockResolvedValue(mockResponse); await expect( - fetchWithLinearBackoff("http://fiftyone.ai", 5, 10) + fetchWithLinearBackoff("http://fiftyone.ai", {}, 5, 10) ).rejects.toThrow("HTTP error: 500"); expect(global.fetch).toHaveBeenCalledTimes(5); @@ -57,7 +57,7 @@ describe("fetchWithLinearBackoff", () => { global.fetch = vi.fn().mockResolvedValue(mockResponse); await expect( - fetchWithLinearBackoff("http://fiftyone.ai", 5, 10) + fetchWithLinearBackoff("http://fiftyone.ai", {}, 5, 10) ).rejects.toThrow("Non-retryable HTTP error: 404"); expect(global.fetch).toHaveBeenCalledTimes(1); @@ -73,7 +73,12 @@ describe("fetchWithLinearBackoff", () => { vi.useFakeTimers(); - const fetchPromise = fetchWithLinearBackoff("http://fiftyone.ai", 5, 10); + const fetchPromise = fetchWithLinearBackoff( + "http://fiftyone.ai", + {}, + 5, + 10 + ); // advance timers to simulate delays // after first delay diff --git a/app/packages/looker/src/worker/decorated-fetch.ts b/app/packages/looker/src/worker/decorated-fetch.ts index c77059d551..9f0a910ea2 100644 --- a/app/packages/looker/src/worker/decorated-fetch.ts +++ b/app/packages/looker/src/worker/decorated-fetch.ts @@ -12,12 +12,13 @@ class NonRetryableError extends Error { export const fetchWithLinearBackoff = async ( url: string, + opts: RequestInit = {}, retries = DEFAULT_MAX_RETRIES, delay = DEFAULT_BASE_DELAY ) => { for (let i = 0; i < retries; i++) { try { - const response = await fetch(url); + const response = await fetch(url, opts); if (response.ok) { return response; } else { diff --git a/app/packages/looker/src/worker/deserializer.ts b/app/packages/looker/src/worker/deserializer.ts index 02a7b03867..363522b01f 100644 --- a/app/packages/looker/src/worker/deserializer.ts +++ b/app/packages/looker/src/worker/deserializer.ts @@ -25,7 +25,6 @@ export const DeserializerFactory = { image: new ArrayBuffer(width * height * 4), }; buffers.push(data.buffer); - buffers.push(label.mask.image); } }, Detections: (labels, buffers) => { @@ -47,7 +46,6 @@ export const DeserializerFactory = { }; buffers.push(data.buffer); - buffers.push(label.map.image); } }, Segmentation: (label, buffers) => { @@ -63,7 +61,6 @@ export const DeserializerFactory = { }; buffers.push(data.buffer); - buffers.push(label.mask.image); } }, }; diff --git a/app/packages/looker/src/worker/disk-overlay-decoder.test.ts b/app/packages/looker/src/worker/disk-overlay-decoder.test.ts new file mode 100644 index 0000000000..5d20f454dd --- /dev/null +++ b/app/packages/looker/src/worker/disk-overlay-decoder.test.ts @@ -0,0 +1,216 @@ +import { getSampleSrc } from "@fiftyone/state"; +import { DETECTIONS, HEATMAP } from "@fiftyone/utilities"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Coloring, CustomizeColor } from ".."; +import { LabelMask } from "../overlays/base"; +import type { Colorscale } from "../state"; +import { decodeWithCanvas } from "./canvas-decoder"; +import { fetchWithLinearBackoff } from "./decorated-fetch"; +import { decodeOverlayOnDisk, IntermediateMask } from "./disk-overlay-decoder"; + +vi.mock("@fiftyone/state", () => ({ + getSampleSrc: vi.fn(), +})); + +vi.mock("@fiftyone/utilities", () => ({ + DETECTION: "Detection", + DETECTIONS: "Detections", + HEATMAP: "Heatmap", +})); + +vi.mock("./canvas-decoder", () => ({ + decodeWithCanvas: vi.fn(), +})); + +vi.mock("./decorated-fetch", () => ({ + fetchWithLinearBackoff: vi.fn(), +})); + +const COLORING = {} as Coloring; +const COLOR_SCALE = {} as Colorscale; +const CUSTOMIZE_COLOR_SETTING: CustomizeColor[] = []; +const SOURCES = {}; + +type MaskUnion = (IntermediateMask & LabelMask) | null; + +describe("decodeOverlayOnDisk", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should return early if label already has overlay field (not on disk)", async () => { + const field = "testField"; + const label = { mask: {}, mask_path: "shouldBeIgnored" }; + const cls = "Segmentation"; + const maskPathDecodingPromises: Promise[] = []; + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(label.mask).toBeDefined(); + expect(fetchWithLinearBackoff).not.toHaveBeenCalled(); + }); + + it("should fetch and decode overlay when label has overlay path field", async () => { + const field = "testField"; + const label = { mask_path: "/path/to/mask", mask: null as MaskUnion }; + const cls = "Segmentation"; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl = "http://example.com/path/to/mask"; + const mockBlob = new Blob(["mock data"], { type: "image/png" }); + const overlayMask = { shape: [100, 200] }; + + vi.mocked(getSampleSrc).mockReturnValue(sampleSrcUrl); + vi.mocked(fetchWithLinearBackoff).mockResolvedValue({ + blob: () => Promise.resolve(mockBlob), + } as Response); + vi.mocked(decodeWithCanvas).mockResolvedValue(overlayMask); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(getSampleSrc).toHaveBeenCalledWith("/path/to/mask"); + expect(fetchWithLinearBackoff).toHaveBeenCalledWith(sampleSrcUrl); + expect(decodeWithCanvas).toHaveBeenCalledWith(mockBlob); + expect(label.mask).toBeDefined(); + expect(label.mask.data).toBe(overlayMask); + expect(label.mask.image).toBeInstanceOf(ArrayBuffer); + expect(label.mask.image.byteLength).toBe(100 * 200 * 4); + }); + + it("should handle HEATMAP class", async () => { + const field = "testField"; + const label = { map_path: "/path/to/map", map: null as MaskUnion }; + const cls = HEATMAP; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl = "http://example.com/path/to/map"; + const mockBlob = new Blob(["mock data"], { type: "image/png" }); + const overlayMask = { shape: [100, 200] }; + + vi.mocked(getSampleSrc).mockReturnValue(sampleSrcUrl); + vi.mocked(fetchWithLinearBackoff).mockResolvedValue({ + blob: () => Promise.resolve(mockBlob), + } as Response); + vi.mocked(decodeWithCanvas).mockResolvedValue(overlayMask); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(getSampleSrc).toHaveBeenCalledWith("/path/to/map"); + expect(fetchWithLinearBackoff).toHaveBeenCalledWith(sampleSrcUrl); + expect(decodeWithCanvas).toHaveBeenCalledWith(mockBlob); + expect(label.map).toBeDefined(); + expect(label.map.data).toBe(overlayMask); + expect(label.map.image).toBeInstanceOf(ArrayBuffer); + expect(label.map.image.byteLength).toBe(100 * 200 * 4); + }); + + it("should handle DETECTIONS class and process detections recursively", async () => { + const field = "testField"; + const label = { + detections: [ + { mask_path: "/path/to/mask1", mask: null as MaskUnion }, + { mask_path: "/path/to/mask2", mask: null as MaskUnion }, + ], + }; + const cls = DETECTIONS; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl1 = "http://example.com/path/to/mask1"; + const sampleSrcUrl2 = "http://example.com/path/to/mask2"; + const mockBlob1 = new Blob(["mock data 1"], { type: "image/png" }); + const mockBlob2 = new Blob(["mock data 2"], { type: "image/png" }); + const overlayMask1 = { shape: [50, 50] }; + const overlayMask2 = { shape: [60, 60] }; + + vi.mocked(getSampleSrc) + .mockReturnValueOnce(sampleSrcUrl1) + .mockReturnValueOnce(sampleSrcUrl2); + vi.mocked(fetchWithLinearBackoff) + .mockResolvedValueOnce({ + blob: () => Promise.resolve(mockBlob1), + } as Response) + .mockResolvedValueOnce({ + blob: () => Promise.resolve(mockBlob2), + } as Response); + vi.mocked(decodeWithCanvas) + .mockResolvedValueOnce(overlayMask1) + .mockResolvedValueOnce(overlayMask2); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + await Promise.all(maskPathDecodingPromises); + + expect(getSampleSrc).toHaveBeenNthCalledWith(1, "/path/to/mask1"); + expect(getSampleSrc).toHaveBeenNthCalledWith(2, "/path/to/mask2"); + expect(label.detections[0].mask).toBeDefined(); + expect(label.detections[0].mask.data).toBe(overlayMask1); + expect(label.detections[1].mask).toBeDefined(); + expect(label.detections[1].mask.data).toBe(overlayMask2); + }); + + it("should return early if fetch (with retry) fails", async () => { + const field = "testField"; + const label = { mask_path: "/path/to/mask", mask: null as MaskUnion }; + const cls = "Segmentation"; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl = "http://example.com/path/to/mask"; + + vi.mocked(getSampleSrc).mockReturnValue(sampleSrcUrl); + vi.mocked(fetchWithLinearBackoff).mockRejectedValue( + new Error("Fetch failed") + ); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(getSampleSrc).toHaveBeenCalledWith("/path/to/mask"); + expect(fetchWithLinearBackoff).toHaveBeenCalledWith(sampleSrcUrl); + expect(decodeWithCanvas).not.toHaveBeenCalled(); + expect(label.mask).toBeNull(); + }); +}); diff --git a/app/packages/looker/src/worker/disk-overlay-decoder.ts b/app/packages/looker/src/worker/disk-overlay-decoder.ts new file mode 100644 index 0000000000..9573d4f49a --- /dev/null +++ b/app/packages/looker/src/worker/disk-overlay-decoder.ts @@ -0,0 +1,109 @@ +import { getSampleSrc } from "@fiftyone/state/src/recoil/utils"; +import { DETECTION, DETECTIONS } from "@fiftyone/utilities"; +import { Coloring, CustomizeColor } from ".."; +import { OverlayMask } from "../numpy"; +import { Colorscale } from "../state"; +import { decodeWithCanvas } from "./canvas-decoder"; +import { enqueueFetch } from "./pooled-fetch"; +import { getOverlayFieldFromCls } from "./shared"; + +export type IntermediateMask = { + data: OverlayMask; + image: ArrayBuffer; +}; + +/** + * Some label types (example: segmentation, heatmap) can have their overlay data stored on-disk, + * we want to impute the relevant mask property of these labels from what's stored in the disk + */ +export const decodeOverlayOnDisk = async ( + field: string, + label: Record, + coloring: Coloring, + customizeColorSetting: CustomizeColor[], + colorscale: Colorscale, + sources: { [path: string]: string }, + cls: string, + maskPathDecodingPromises: Promise[] = [], + maskTargetsBuffers: ArrayBuffer[] = [] +) => { + // handle all list types here + if (cls === DETECTIONS) { + const promises: Promise[] = []; + for (const detection of label.detections) { + promises.push( + decodeOverlayOnDisk( + field, + detection, + coloring, + customizeColorSetting, + colorscale, + {}, + DETECTION, + maskPathDecodingPromises, + maskTargetsBuffers + ) + ); + } + maskPathDecodingPromises.push(...promises); + } + + const overlayFields = getOverlayFieldFromCls(cls); + const overlayPathField = overlayFields.disk; + const overlayField = overlayFields.canonical; + + if (Boolean(label[overlayField]) || !Object.hasOwn(label, overlayPathField)) { + // nothing to be done + return; + } + + // convert absolute file path to a URL that we can "fetch" from + const overlayImageUrl = getSampleSrc( + sources[`${field}.${overlayPathField}`] || label[overlayPathField] + ); + const urlTokens = overlayImageUrl.split("?"); + + let baseUrl = overlayImageUrl; + + // remove query params if not local URL + if (!urlTokens.at(1)?.startsWith("filepath=")) { + baseUrl = overlayImageUrl.split("?")[0]; + } + + let overlayImageBlob: Blob; + try { + const overlayImageFetchResponse = await enqueueFetch({ + url: baseUrl, + options: { priority: "low" }, + }); + overlayImageBlob = await overlayImageFetchResponse.blob(); + } catch (e) { + console.error(e); + // skip decoding if fetch fails altogether + return; + } + + let overlayMask: OverlayMask; + + try { + overlayMask = await decodeWithCanvas(overlayImageBlob); + } catch (e) { + console.error(e); + return; + } + + const [overlayHeight, overlayWidth] = overlayMask.shape; + + // set the `mask` property for this label + // we need to do this because we need raw image pixel data + // to iterate through and paint it with the color + // defined by the user for this particular label + label[overlayField] = { + data: overlayMask, + image: new ArrayBuffer(overlayWidth * overlayHeight * 4), + } as IntermediateMask; + + // no need to transfer image's buffer + //since we'll be constructing ImageBitmap and transfering that + maskTargetsBuffers.push(overlayMask.buffer); +}; diff --git a/app/packages/looker/src/worker/index.ts b/app/packages/looker/src/worker/index.ts index 21859407e2..e11906b51a 100644 --- a/app/packages/looker/src/worker/index.ts +++ b/app/packages/looker/src/worker/index.ts @@ -2,14 +2,12 @@ * Copyright 2017-2024, Voxel51, Inc. */ -import { getSampleSrc } from "@fiftyone/state/src/recoil/utils"; import { DENSE_LABELS, DETECTION, DETECTIONS, DYNAMIC_EMBEDDED_DOCUMENT, EMBEDDED_DOCUMENT, - HEATMAP, LABEL_LIST, Schema, Stage, @@ -29,11 +27,10 @@ import { LabelTagColor, Sample, } from "../state"; -import { decodeWithCanvas } from "./canvas-decoder"; -import { fetchWithLinearBackoff } from "./decorated-fetch"; import { DeserializerFactory } from "./deserializer"; +import { decodeOverlayOnDisk } from "./disk-overlay-decoder"; import { PainterFactory } from "./painter"; -import { mapId } from "./shared"; +import { getOverlayFieldFromCls, mapId } from "./shared"; import { process3DLabels } from "./threed-label-processor"; interface ResolveColor { @@ -97,89 +94,15 @@ const painterFactory = PainterFactory(requestColor); const ALL_VALID_LABELS = new Set(VALID_LABEL_TYPES); /** - * Some label types (example: segmentation, heatmap) can have their overlay data stored on-disk, - * we want to impute the relevant mask property of these labels from what's stored in the disk + * This function processes labels in a recursive manner. It follows the following steps: + * 1. Deserialize masks. Accumulate promises. + * 2. Await mask path decoding to finish. + * 3. Start painting overlays. Accumulate promises. + * 4. Await overlay painting to finish. + * 5. Start bitmap generation. Accumulate promises. + * 6. Await bitmap generation to finish. + * 7. Transfer bitmaps and mask targets array buffers back to the main thread. */ -const imputeOverlayFromPath = async ( - field: string, - label: Record, - coloring: Coloring, - customizeColorSetting: CustomizeColor[], - colorscale: Colorscale, - buffers: ArrayBuffer[], - sources: { [path: string]: string }, - cls: string, - maskPathDecodingPromises: Promise[] = [] -) => { - // handle all list types here - if (cls === DETECTIONS) { - const promises: Promise[] = []; - for (const detection of label.detections) { - promises.push( - imputeOverlayFromPath( - field, - detection, - coloring, - customizeColorSetting, - colorscale, - buffers, - {}, - DETECTION - ) - ); - } - maskPathDecodingPromises.push(...promises); - } - - // overlay path is in `map_path` property for heatmap, or else, it's in `mask_path` property (for segmentation or detection) - const overlayPathField = cls === HEATMAP ? "map_path" : "mask_path"; - const overlayField = overlayPathField === "map_path" ? "map" : "mask"; - - if ( - Object.hasOwn(label, overlayField) || - !Object.hasOwn(label, overlayPathField) - ) { - // nothing to be done - return; - } - - // convert absolute file path to a URL that we can "fetch" from - const overlayImageUrl = getSampleSrc( - sources[`${field}.${overlayPathField}`] || label[overlayPathField] - ); - const urlTokens = overlayImageUrl.split("?"); - - let baseUrl = overlayImageUrl; - - // remove query params if not local URL - if (!urlTokens.at(1)?.startsWith("filepath=")) { - baseUrl = overlayImageUrl.split("?")[0]; - } - - let overlayImageBlob: Blob; - try { - const overlayImageFetchResponse = await fetchWithLinearBackoff(baseUrl); - overlayImageBlob = await overlayImageFetchResponse.blob(); - } catch (e) { - console.error(e); - // skip decoding if fetch fails altogether - return; - } - - const overlayMask = await decodeWithCanvas(overlayImageBlob); - const [overlayHeight, overlayWidth] = overlayMask.shape; - - // set the `mask` property for this label - label[overlayField] = { - data: overlayMask, - image: new ArrayBuffer(overlayWidth * overlayHeight * 4), - }; - - // transfer buffers - buffers.push(overlayMask.buffer); - buffers.push(label[overlayField].image); -}; - const processLabels = async ( sample: ProcessSample["sample"], coloring: ProcessSample["coloring"], @@ -190,13 +113,13 @@ const processLabels = async ( labelTagColors: ProcessSample["labelTagColors"], selectedLabelTags: ProcessSample["selectedLabelTags"], schema: Schema -): Promise => { - const buffers: ArrayBuffer[] = []; - const painterPromises = []; - - const maskPathDecodingPromises = []; +): Promise<[Promise[], ArrayBuffer[]]> => { + const maskPathDecodingPromises: Promise[] = []; + const painterPromises: Promise[] = []; + const bitmapPromises: Promise[] = []; + const maskTargetsBuffers: ArrayBuffer[] = []; - // mask deserialization / mask_path decoding loop + // mask deserialization / on-disk overlay decoding loop for (const field in sample) { let labels = sample[field]; if (!Array.isArray(labels)) { @@ -211,37 +134,39 @@ const processLabels = async ( if (DENSE_LABELS.has(cls)) { maskPathDecodingPromises.push( - imputeOverlayFromPath( + decodeOverlayOnDisk( `${prefix || ""}${field}`, label, coloring, customizeColorSetting, colorscale, - buffers, sources, cls, - maskPathDecodingPromises + maskPathDecodingPromises, + maskTargetsBuffers ) ); } if (cls in DeserializerFactory) { - DeserializerFactory[cls](label, buffers); + DeserializerFactory[cls](label, maskTargetsBuffers); } if ([EMBEDDED_DOCUMENT, DYNAMIC_EMBEDDED_DOCUMENT].includes(cls)) { - const moreBuffers = await processLabels( - label, - coloring, - `${prefix ? prefix : ""}${field}.`, - sources, - customizeColorSetting, - colorscale, - labelTagColors, - selectedLabelTags, - schema - ); - buffers.push(...moreBuffers); + const [moreBitmapPromises, moreMaskTargetsBuffers] = + await processLabels( + label, + coloring, + `${prefix ? prefix : ""}${field}.`, + sources, + customizeColorSetting, + colorscale, + labelTagColors, + selectedLabelTags, + schema + ); + bitmapPromises.push(...moreBitmapPromises); + maskTargetsBuffers.push(...moreMaskTargetsBuffers); } if (ALL_VALID_LABELS.has(cls)) { @@ -286,7 +211,61 @@ const processLabels = async ( } } - return Promise.all(painterPromises).then(() => buffers); + await Promise.allSettled(painterPromises); + + // bitmap generation loop + for (const field in sample) { + let labels = sample[field]; + if (!Array.isArray(labels)) { + labels = [labels]; + } + const cls = getCls(`${prefix ? prefix : ""}${field}`, schema); + + for (const label of labels) { + if (!label) { + continue; + } + + collectBitmapPromises(label, cls, bitmapPromises); + } + } + + return [bitmapPromises, maskTargetsBuffers]; +}; + +const collectBitmapPromises = (label, cls, bitmapPromises) => { + if (cls === DETECTIONS) { + label?.detections?.forEach((detection) => + collectBitmapPromises(detection, DETECTION, bitmapPromises) + ); + return; + } + + const overlayFields = getOverlayFieldFromCls(cls); + const overlayField = overlayFields.canonical; + + if (label[overlayField]) { + const [height, width] = label[overlayField].data.shape; + + const imageData = new ImageData( + new Uint8ClampedArray(label[overlayField].image), + width, + height + ); + + // set raw image to null - will be garbage collected + // we don't need it anymore since we copied to ImageData + label[overlayField].image = null; + + bitmapPromises.push( + new Promise((resolve) => { + createImageBitmap(imageData).then((imageBitmap) => { + label[overlayField].bitmap = imageBitmap; + resolve(imageBitmap); + }); + }) + ); + } }; /** GLOBALS */ @@ -316,7 +295,7 @@ export interface ProcessSample { type ProcessSampleMethod = ReaderMethod & ProcessSample; -const processSample = ({ +const processSample = async ({ sample, uuid, coloring, @@ -329,48 +308,46 @@ const processSample = ({ }: ProcessSample) => { mapId(sample); - let bufferPromises = []; + const imageBitmapPromises: Promise[] = []; + const maskTargetsBuffers: ArrayBuffer[] = []; if (sample?._media_type === "point-cloud" || sample?._media_type === "3d") { process3DLabels(schema, sample); } else { - bufferPromises = [ - processLabels( - sample, + const [bitmapPromises, moreMaskTargetsBuffers] = await processLabels( + sample, + coloring, + null, + sources, + customizeColorSetting, + colorscale, + labelTagColors, + selectedLabelTags, + schema + ); + imageBitmapPromises.push(...bitmapPromises); + maskTargetsBuffers.push(...moreMaskTargetsBuffers); + } + + if (sample.frames && sample.frames.length) { + for (const frame of sample.frames) { + const [moreBitmapPromises, moreMaskTargetsBuffers] = await processLabels( + frame, coloring, - null, + "frames.", sources, customizeColorSetting, colorscale, labelTagColors, selectedLabelTags, schema - ), - ]; - } - - if (sample.frames && sample.frames.length) { - bufferPromises = [ - ...bufferPromises, - ...sample.frames - .map((frame) => - processLabels( - frame, - coloring, - "frames.", - sources, - customizeColorSetting, - colorscale, - labelTagColors, - selectedLabelTags, - schema - ) - ) - .flat(), - ]; + ); + imageBitmapPromises.push(...moreBitmapPromises); + maskTargetsBuffers.push(...moreMaskTargetsBuffers); + } } - Promise.all(bufferPromises).then((buffers) => { + Promise.all(imageBitmapPromises).then((bitmaps) => { postMessage( { method: "processSample", @@ -383,7 +360,7 @@ const processSample = ({ selectedLabelTags, }, // @ts-ignore - buffers.flat() + bitmaps.flat().concat(maskTargetsBuffers.flat()) ); }); }; diff --git a/app/packages/looker/src/worker/pooled-fetch.ts b/app/packages/looker/src/worker/pooled-fetch.ts new file mode 100644 index 0000000000..3a61c8b0ce --- /dev/null +++ b/app/packages/looker/src/worker/pooled-fetch.ts @@ -0,0 +1,46 @@ +import { fetchWithLinearBackoff } from "./decorated-fetch"; + +interface QueueItem { + request: { + url: string; + options?: RequestInit; + }; + resolve: (value: Response | PromiseLike) => void; + reject: (reason?: any) => void; +} + +// note: arbitrary number that seems to work well +const MAX_CONCURRENT_REQUESTS = 100; + +let activeRequests = 0; +const requestQueue: QueueItem[] = []; + +export const enqueueFetch = (request: { + url: string; + options?: RequestInit; +}): Promise => { + return new Promise((resolve, reject) => { + requestQueue.push({ request, resolve, reject }); + processFetchQueue(); + }); +}; + +const processFetchQueue = () => { + if (activeRequests >= MAX_CONCURRENT_REQUESTS || requestQueue.length === 0) { + return; + } + + const { request, resolve, reject } = requestQueue.shift(); + activeRequests++; + + fetchWithLinearBackoff(request.url, request.options) + .then((response) => { + activeRequests--; + resolve(response); + processFetchQueue(); + }) + .catch((error) => { + activeRequests--; + reject(error); + }); +}; diff --git a/app/packages/looker/src/worker/shared.ts b/app/packages/looker/src/worker/shared.ts index adfda58d29..ec383b7536 100644 --- a/app/packages/looker/src/worker/shared.ts +++ b/app/packages/looker/src/worker/shared.ts @@ -1,3 +1,5 @@ +import { HEATMAP } from "@fiftyone/utilities"; + /** * Map the _id field to id */ @@ -8,3 +10,12 @@ export const mapId = (obj) => { } return obj; }; + +export const getOverlayFieldFromCls = (cls: string) => { + switch (cls) { + case HEATMAP: + return { canonical: "map", disk: "map_path" }; + default: + return { canonical: "mask", disk: "mask_path" }; + } +};