Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 3aef5db

Browse files
committed
merge with conv
1 parent f0292df commit 3aef5db

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

lib/backends/webgl/ops/conv.ts

+38-28
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
6868

6969
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
7070
const hasBias = inputs.length > 2;
71-
const processBias = hasBias ? `dotProd += getBias(output_channel);` : ``;
71+
const processBias = hasBias ? `value += getBias(output_channel);` : ``;
7272
const xShape = inputs[0].dims.slice();
7373
const wShape = inputs[1].dims.slice();
7474
const outputChannelsPerGroup = wShape[0] / this.group;
@@ -87,18 +87,20 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
8787
const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides);
8888
const glsl = getGlsl(handler.session.backend.glContext.version);
8989

90+
const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
91+
9092
const shaderSource = `
9193
const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]});
9294
const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]});
93-
95+
${activationFunction}
9496
void main() {
9597
ivec4 coords = getOutputCoords();
9698
int batch = coords.x;
9799
int output_channel = coords.y;
98100
ivec2 xRCCorner = coords.zw * strides - pads;
99101
int group_id = output_channel / ${outputChannelsPerGroup};
100102
101-
float dotProd = 0.0;
103+
float value = 0.0;
102104
for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) {
103105
int input_channel = group_id * ${wShape[1]} + wInChannel;
104106
for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) {
@@ -116,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
116118
117119
float xVal = getX(batch, input_channel, xWidth, xHeight);
118120
float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
119-
dotProd += xVal*wVal;
121+
value += xVal*wVal;
120122
}
121123
}
122124
}
123125
${processBias}
124-
${glsl.output} = vec4(dotProd, .0, .0, .0);
126+
${applyActivation}
127+
${glsl.output} = vec4(value, .0, .0, .0);
125128
}
126129
`;
127130
return {
@@ -143,6 +146,28 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
143146
}
144147
}
145148

149+
function getActicationSnippet(activation: string) {
150+
let activationFunction = '';
151+
let activationName = '';
152+
switch (activation) {
153+
case 'Relu':
154+
activationName = glslRelu().name;
155+
activationFunction = glslRelu().body;
156+
break;
157+
case 'Sigmoid':
158+
activationName = glslSigmoid().name;
159+
activationFunction = glslSigmoid().body;
160+
break;
161+
default:
162+
activationName = '';
163+
activationFunction = '';
164+
}
165+
const applyActivation = activation ? `
166+
value = ${activationName}(value);` :
167+
'';
168+
return {activationFunction, applyActivation};
169+
}
170+
146171
export class WebGLUnpackedConv extends Conv {
147172
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
148173
const programManager = inferenceHandler.session.programManager;
@@ -250,6 +275,7 @@ export class WebGLUnpackedConv extends Conv {
250275
const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4);
251276
const outputLayout = inferenceHandler.createTextureLayoutFromShape(
252277
im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3});
278+
253279
const shaderSource = `
254280
const int XC = ${xshape[1]};
255281
const int XH = ${xshape[2]};
@@ -265,13 +291,12 @@ export class WebGLUnpackedConv extends Conv {
265291
const int KHKW = KH*KW;
266292
const int XCKHKW = XC * KHKW;
267293
const int outputChannels = 4;
268-
269294
vec4 process(int indices[${rank}]) {
270295
int b = indices[0]; // batch size
271296
int oh = indices[1] * strideH - padH; //output height
272297
int ow = indices[2] * strideW - padW; //output width
273298
int p = indices[3] * outputChannels; //patch
274-
vec4 v = vec4(0.0);
299+
vec4 value = vec4(0.0);
275300
for(int i=0; i < outputChannels; ++i) {
276301
if(p < XCKHKW) {
277302
int patchC = p / KHKW;
@@ -288,12 +313,12 @@ export class WebGLUnpackedConv extends Conv {
288313
xh2 < XH &&
289314
xw2 >= 0 &&
290315
xw2 < XW) {
291-
v[i] = _X(x);
316+
value[i] = _X(x);
292317
}
293318
}
294319
++p;
295320
}
296-
return v;
321+
return value;
297322
}
298323
`;
299324
return {
@@ -332,22 +357,7 @@ export class WebGLUnpackedConv extends Conv {
332357
samplers.push('B');
333358
}
334359

335-
let activationFunction = '';
336-
let activationName = '';
337-
switch (this.activation) {
338-
case 'Relu':
339-
activationName = glslRelu().name;
340-
activationFunction = glslRelu().body;
341-
break;
342-
case 'Sigmoid':
343-
activationName = glslSigmoid().name;
344-
activationFunction = glslSigmoid().body;
345-
break;
346-
default:
347-
activationName = '';
348-
activationFunction = '';
349-
}
350-
const applyActivation = this.activation ? `sum = ${activationName}(sum);` : '';
360+
const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
351361

352362
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
353363
const shaderSource = `
@@ -362,16 +372,16 @@ export class WebGLUnpackedConv extends Conv {
362372
int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${
363373
im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset;
364374
int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset;
365-
float sum = sharedDimOffset == 0 ? ${initValue} : 0.0;
375+
float value = sharedDimOffset == 0 ? ${initValue} : 0.0;
366376
for (int i = 0; i < ${sharedDimReadSize}; ++i) {
367377
vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height});
368378
vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height});
369-
sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
379+
value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
370380
++im2colOffset;
371381
++kernelOffset;
372382
}
373383
${applyActivation}
374-
return sum;
384+
return value;
375385
}`;
376386
return {
377387
inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],

0 commit comments

Comments
 (0)