Skip to content

Commit

Permalink
made improvement for virtual background (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
jibon57 authored Jun 8, 2023
1 parent cfa4308 commit 7c0fbd7
Show file tree
Hide file tree
Showing 14 changed files with 159 additions and 55 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"devDependencies": {
"@bufbuild/protoc-gen-es": "1.2.1",
"@pmmmwh/react-refresh-webpack-plugin": "0.5.10",
"@types/emscripten": "^1.39.6",
"@types/lodash": "4.14.195",
"@types/react": "18.2.9",
"@types/react-dom": "18.2.4",
Expand Down
2 changes: 1 addition & 1 deletion src/assets/tflite/tflite-simd.js

Large diffs are not rendered by default.

Binary file modified src/assets/tflite/tflite-simd.wasm
Binary file not shown.
2 changes: 1 addition & 1 deletion src/assets/tflite/tflite.js

Large diffs are not rendered by default.

Binary file modified src/assets/tflite/tflite.wasm
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ export type SegmentationConfig = {
backend: SegmentationBackend;
inputResolution: InputResolution;
pipeline: PipelineName;
targetFps: number;
deferInputResizing: boolean;
};

export const defaultSegmentationConfig: SegmentationConfig = {
model: 'meet',
backend: 'wasmSimd',
inputResolution: '160x96',
pipeline: 'canvas2dCpu',
targetFps: 65,
deferInputResizing: true,
};

export function getTFLiteModelFileName(
Expand Down
48 changes: 48 additions & 0 deletions src/components/virtual-background/helpers/timerHelper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
type TimerData = {
callbackId: number;
};

export type TimerWorker = {
setTimeout(callback: () => void, timeoutMs?: number): number;
clearTimeout(callbackId: number): void;
terminate(): void;
};

export function createTimerWorker(): TimerWorker {
const callbacks = new Map<number, () => void>();

const worker = new Worker(new URL('./timerWorker', import.meta.url));

worker.onmessage = (event: MessageEvent<TimerData>) => {
const callback = callbacks.get(event.data.callbackId);
if (!callback) {
return;
}
callbacks.delete(event.data.callbackId);
callback();
};

let nextCallbackId = 1;

function setTimeout(callback: () => void, timeoutMs = 0) {
const callbackId = nextCallbackId++;
callbacks.set(callbackId, callback);
worker.postMessage({ callbackId, timeoutMs });
return callbackId;
}

function clearTimeout(callbackId: number) {
if (!callbacks.has(callbackId)) {
return;
}
worker.postMessage({ callbackId });
callbacks.delete(callbackId);
}

function terminate() {
callbacks.clear();
worker.terminate();
}

return { setTimeout, clearTimeout, terminate };
}
24 changes: 24 additions & 0 deletions src/components/virtual-background/helpers/timerWorker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
declare const self: DedicatedWorkerGlobalScope;

type TimerData = {
callbackId: number;
timeoutMs?: number;
};

const timeoutIds = new Map<number, number>();

self.onmessage = (event: MessageEvent<TimerData>) => {
if (event.data.timeoutMs !== undefined) {
const timeoutId = self.setTimeout(() => {
self.postMessage({ callbackId: event.data.callbackId });
timeoutIds.delete(event.data.callbackId);
}, event.data.timeoutMs);
timeoutIds.set(event.data.callbackId, timeoutId);
} else {
const timeoutId = timeoutIds.get(event.data.callbackId);
self.clearTimeout(timeoutId);
timeoutIds.delete(event.data.callbackId);
}
};

export {};
33 changes: 19 additions & 14 deletions src/components/virtual-background/hooks/useRenderingPipeline.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
/* eslint-disable */
/*eslint-disable*/

import { BodyPix } from '@tensorflow-models/body-pix';
import { useEffect, useRef, useState } from 'react';
import { buildCanvas2dPipeline } from '../pipelines/canvas2d/canvas2dPipeline';
import { buildWebGL2Pipeline } from '../pipelines/webgl2/webgl2Pipeline';
import { BackgroundConfig } from '../helpers/backgroundHelper';
import { RenderingPipeline } from '../helpers/renderingPipelineHelper';
import { SegmentationConfig } from '../helpers/segmentationHelper';
import { SourcePlayback } from '../helpers/sourceHelper';
import { TFLite } from './useTFLite';
import { createTimerWorker } from '../helpers/timerHelper';
import { buildWebGL2Pipeline } from '../pipelines/webgl2/webgl2Pipeline';
import { buildCanvas2dPipeline } from '../pipelines/canvas2d/canvas2dPipeline';
declare const IS_PRODUCTION: boolean;

function useRenderingPipeline(
Expand All @@ -24,17 +26,17 @@ function useRenderingPipeline(
const [durations, setDurations] = useState<number[]>([]);

useEffect(() => {
// The useEffect cleanup function is not enough to stop
// the rendering loop when the framerate is low
let shouldRender = true;
const targetTimerTimeoutMs = 1000 / segmentationConfig.targetFps;

let previousTime = 0;
let beginTime = 0;
let eventCount = 0;
let frameCount = 0;
const frameDurations: number[] = [];

let renderRequestId: number;
let renderTimeoutId: number;

const timerWorker = createTimerWorker();

const newPipeline =
segmentationConfig.pipeline === 'webgl2'
Expand All @@ -45,6 +47,7 @@ function useRenderingPipeline(
segmentationConfig,
canvasRef.current,
tflite,
timerWorker,
addFrameEvent,
)
: buildCanvas2dPipeline(
Expand All @@ -58,13 +61,16 @@ function useRenderingPipeline(
);

async function render() {
if (!shouldRender) {
return;
}
const startTime = performance.now();

beginFrame();
await newPipeline.render();
endFrame();
renderRequestId = requestAnimationFrame(render);

renderTimeoutId = timerWorker.setTimeout(
render,
Math.max(0, targetTimerTimeoutMs - (performance.now() - startTime)),
);
}

function beginFrame() {
Expand Down Expand Up @@ -104,10 +110,9 @@ function useRenderingPipeline(
setPipeline(newPipeline);

return () => {
shouldRender = false;
cancelAnimationFrame(renderRequestId);
timerWorker.clearTimeout(renderTimeoutId);
timerWorker.terminate();
newPipeline.cleanUp();

if (!IS_PRODUCTION) {
console.log(
'Animation stopped:',
Expand Down
16 changes: 5 additions & 11 deletions src/components/virtual-background/hooks/useTFLite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import {
declare function createTFLiteModule(): Promise<TFLite>;
declare function createTFLiteSIMDModule(): Promise<TFLite>;

// eslint-disable-next-line
// @ts-ignore
export interface TFLite extends EmscriptenModule {
_getModelBufferMemoryOffset(): number;
_getInputMemoryOffset(): number;
Expand All @@ -32,17 +30,13 @@ function useTFLite(segmentationConfig: SegmentationConfig) {

useEffect(() => {
async function loadTFLite() {
createTFLiteModule().then(setTFLite);
try {
const createdTFLiteSIMD = await createTFLiteSIMDModule();
setTFLiteSIMD(createdTFLiteSIMD);
setTFLite(createdTFLiteSIMD);
setSIMDSupported(true);
} catch (e) {
try {
createTFLiteModule().then(setTFLite);
} catch (error) {
console.warn('Failed to create TFLite WebAssembly module.', error);
}
} catch (error) {
console.warn('Failed to create TFLite SIMD WebAssembly module.', error);
}
}

Expand All @@ -63,7 +57,7 @@ function useTFLite(segmentationConfig: SegmentationConfig) {

setSelectedTFLite(undefined);

const newSelectedTFLite: any =
const newSelectedTFLite =
segmentationConfig.backend === 'wasmSimd' ? tfliteSIMD : tflite;

if (!newSelectedTFLite) {
Expand Down Expand Up @@ -116,7 +110,7 @@ function useTFLite(segmentationConfig: SegmentationConfig) {
}

loadTFLiteModel();
// eslint-disable-next-line
//eslint-disable-next-line
}, [
tflite,
tfliteSIMD,
Expand Down
28 changes: 22 additions & 6 deletions src/components/virtual-background/pipelines/helpers/webglHelper.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
/* eslint-disable */
/*eslint-disable*/

import { TimerWorker } from '../../helpers/timerHelper';

/**
* Use it along with boyswan.glsl-literal VSCode extension
* to get GLSL syntax highlighting.
Expand Down Expand Up @@ -82,6 +85,7 @@ export function createTexture(
}

export async function readPixelsAsync(
timerWorker: TimerWorker,
gl: WebGL2RenderingContext,
x: number,
y: number,
Expand All @@ -97,13 +101,21 @@ export async function readPixelsAsync(
gl.readPixels(x, y, width, height, format, type, 0);
gl.bindBuffer(gl.PIXEL_PACK_BUFFER, null);

await getBufferSubDataAsync(gl, gl.PIXEL_PACK_BUFFER, buf, 0, dest);
await getBufferSubDataAsync(
timerWorker,
gl,
gl.PIXEL_PACK_BUFFER,
buf,
0,
dest,
);

gl.deleteBuffer(buf);
return dest;
}

async function getBufferSubDataAsync(
timerWorker: TimerWorker,
gl: WebGL2RenderingContext,
target: number,
buffer: WebGLBuffer,
Expand All @@ -114,7 +126,7 @@ async function getBufferSubDataAsync(
) {
const sync = gl.fenceSync(gl.SYNC_GPU_COMMANDS_COMPLETE, 0)!;
gl.flush();
const res = await clientWaitAsync(gl, sync);
const res = await clientWaitAsync(timerWorker, gl, sync);
gl.deleteSync(sync);

if (res !== gl.WAIT_FAILED) {
Expand All @@ -124,7 +136,11 @@ async function getBufferSubDataAsync(
}
}

function clientWaitAsync(gl: WebGL2RenderingContext, sync: WebGLSync) {
function clientWaitAsync(
timerWorker: TimerWorker,
gl: WebGL2RenderingContext,
sync: WebGLSync,
) {
return new Promise<number>((resolve) => {
function test() {
const res = gl.clientWaitSync(sync, 0, 0);
Expand All @@ -133,11 +149,11 @@ function clientWaitAsync(gl: WebGL2RenderingContext, sync: WebGLSync) {
return;
}
if (res === gl.TIMEOUT_EXPIRED) {
requestAnimationFrame(test);
timerWorker.setTimeout(test);
return;
}
resolve(res);
}
requestAnimationFrame(test);
timerWorker.setTimeout(test);
});
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import {
inputResolutions,
SegmentationConfig,
} from '../../helpers/segmentationHelper';
import {
compileShader,
createPiplelineStageProgram,
createTexture,
glsl,
readPixelsAsync,
} from '../helpers/webglHelper';
import { TimerWorker } from '../../helpers/timerHelper';
import {
inputResolutions,
SegmentationConfig,
} from '../../helpers/segmentationHelper';
import { TFLite } from '../../hooks/useTFLite';

export function buildResizingStage(
timerWorker: TimerWorker,
gl: WebGL2RenderingContext,
vertexShader: WebGLShader,
positionBuffer: WebGLBuffer,
texCoordBuffer: WebGLBuffer,
segmentationConfig: SegmentationConfig,
tflite: any,
tflite: TFLite,
) {
const fragmentShaderSource = glsl`#version 300 es
Expand Down Expand Up @@ -69,14 +72,14 @@ export function buildResizingStage(
gl.useProgram(program);
gl.uniform1i(inputFrameLocation, 0);

function render() {
async function render() {
gl.viewport(0, 0, outputWidth, outputHeight);
gl.useProgram(program);
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);

// Downloads pixels asynchronously from GPU while rendering the current frame
readPixelsAsync(
const readPixelsPromise = readPixelsAsync(
timerWorker,
gl,
0,
0,
Expand All @@ -87,6 +90,14 @@ export function buildResizingStage(
outputPixels,
);

if (segmentationConfig.deferInputResizing) {
// Downloads pixels asynchronously from GPU while rendering the current frame.
// The pixels will be available in the next frame render which results
// in offsets in the segmentation output but increases the frame rate.
} else {
await readPixelsPromise;
}

for (let i = 0; i < outputPixelCount; i++) {
const tfliteIndex = tfliteInputMemoryOffset + i * 3;
const outputIndex = i * 4;
Expand Down
Loading

0 comments on commit 7c0fbd7

Please sign in to comment.