@@ -68,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
68
68
69
69
createProgramInfo ( handler : WebGLInferenceHandler , inputs : Tensor [ ] ) : ProgramInfo {
70
70
const hasBias = inputs . length > 2 ;
71
- const processBias = hasBias ? `dotProd += getBias(output_channel);` : `` ;
71
+ const processBias = hasBias ? `value += getBias(output_channel);` : `` ;
72
72
const xShape = inputs [ 0 ] . dims . slice ( ) ;
73
73
const wShape = inputs [ 1 ] . dims . slice ( ) ;
74
74
const outputChannelsPerGroup = wShape [ 0 ] / this . group ;
@@ -87,18 +87,20 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
87
87
const outputShape = WebGLConv . calcOutputShape ( xShape , wShape , this . dilations , this . pads , this . strides ) ;
88
88
const glsl = getGlsl ( handler . session . backend . glContext . version ) ;
89
89
90
+ const { activationFunction, applyActivation} = getActicationSnippet ( this . activation ) ;
91
+
90
92
const shaderSource = `
91
93
const ivec2 strides = ivec2(${ this . strides [ 0 ] } , ${ this . strides [ 1 ] } );
92
94
const ivec2 pads = ivec2(${ this . pads [ 0 ] } , ${ this . pads [ 1 ] } );
93
-
95
+ ${ activationFunction }
94
96
void main() {
95
97
ivec4 coords = getOutputCoords();
96
98
int batch = coords.x;
97
99
int output_channel = coords.y;
98
100
ivec2 xRCCorner = coords.zw * strides - pads;
99
101
int group_id = output_channel / ${ outputChannelsPerGroup } ;
100
102
101
- float dotProd = 0.0;
103
+ float value = 0.0;
102
104
for (int wInChannel = 0; wInChannel < ${ wShape [ 1 ] } ; wInChannel++) {
103
105
int input_channel = group_id * ${ wShape [ 1 ] } + wInChannel;
104
106
for (int wHeight = 0; wHeight < ${ wShape [ 2 ] } ; wHeight++) {
@@ -116,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
116
118
117
119
float xVal = getX(batch, input_channel, xWidth, xHeight);
118
120
float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
119
- dotProd += xVal*wVal;
121
+ value += xVal*wVal;
120
122
}
121
123
}
122
124
}
123
125
${ processBias }
124
- ${ glsl . output } = vec4(dotProd, .0, .0, .0);
126
+ ${ applyActivation }
127
+ ${ glsl . output } = vec4(value, .0, .0, .0);
125
128
}
126
129
` ;
127
130
return {
@@ -143,6 +146,28 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
143
146
}
144
147
}
145
148
149
/**
 * Builds the GLSL fragments used to fuse an activation into a convolution shader.
 *
 * NOTE: the misspelling in the function name is kept on purpose; the call sites
 * in this file refer to it by this exact name.
 *
 * @param activation Name of the fused activation. Only 'Relu' and 'Sigmoid' are
 *     recognised; any other value (including an empty string) produces empty
 *     snippets so the generated shader is left untouched.
 * @returns `activationFunction` is the GLSL source of the activation routine,
 *     meant to be spliced into the shader ahead of the function that uses it.
 *     `applyActivation` is a GLSL statement that passes the shader-local
 *     variable `value` through that routine and stores the result back.
 */
function getActicationSnippet(activation: string): {activationFunction: string; applyActivation: string} {
  let activationFunction = '';
  let activationName = '';
  switch (activation) {
    case 'Relu': {
      // Call the helper once and read both fields from the same result.
      const relu = glslRelu();
      activationName = relu.name;
      activationFunction = relu.body;
      break;
    }
    case 'Sigmoid': {
      const sigmoid = glslSigmoid();
      activationName = sigmoid.name;
      activationFunction = sigmoid.body;
      break;
    }
    default:
      // Unknown or missing activation: leave both snippets empty.
      break;
  }
  // Gate on the resolved routine name rather than on the raw `activation`
  // string, so an unrecognised name never emits `value = (value);`.
  const applyActivation = activationName ? `value = ${activationName}(value);` : '';
  return {activationFunction, applyActivation};
}
170
+
146
171
export class WebGLUnpackedConv extends Conv {
147
172
run ( inferenceHandler : WebGLInferenceHandler , inputs : Tensor [ ] ) : Tensor [ ] {
148
173
const programManager = inferenceHandler . session . programManager ;
@@ -250,6 +275,7 @@ export class WebGLUnpackedConv extends Conv {
250
275
const im2colDims = WebGLUnpackedConv . calcIm2ColDims ( xshape , kshape , outputShape , 4 ) ;
251
276
const outputLayout = inferenceHandler . createTextureLayoutFromShape (
252
277
im2colDims , 4 , [ im2colDims [ 0 ] , im2colDims [ 1 ] , im2colDims [ 2 ] , im2colDims [ 3 ] * 4 ] , { breakAxis : 3 } ) ;
278
+
253
279
const shaderSource = `
254
280
const int XC = ${ xshape [ 1 ] } ;
255
281
const int XH = ${ xshape [ 2 ] } ;
@@ -265,13 +291,12 @@ export class WebGLUnpackedConv extends Conv {
265
291
const int KHKW = KH*KW;
266
292
const int XCKHKW = XC * KHKW;
267
293
const int outputChannels = 4;
268
-
269
294
vec4 process(int indices[${ rank } ]) {
270
295
int b = indices[0]; // batch size
271
296
int oh = indices[1] * strideH - padH; //output height
272
297
int ow = indices[2] * strideW - padW; //output width
273
298
int p = indices[3] * outputChannels; //patch
274
- vec4 v = vec4(0.0);
299
+ vec4 value = vec4(0.0);
275
300
for(int i=0; i < outputChannels; ++i) {
276
301
if(p < XCKHKW) {
277
302
int patchC = p / KHKW;
@@ -288,12 +313,12 @@ export class WebGLUnpackedConv extends Conv {
288
313
xh2 < XH &&
289
314
xw2 >= 0 &&
290
315
xw2 < XW) {
291
- v [i] = _X(x);
316
+ value [i] = _X(x);
292
317
}
293
318
}
294
319
++p;
295
320
}
296
- return v ;
321
+ return value ;
297
322
}
298
323
` ;
299
324
return {
@@ -332,22 +357,7 @@ export class WebGLUnpackedConv extends Conv {
332
357
samplers . push ( 'B' ) ;
333
358
}
334
359
335
- let activationFunction = '' ;
336
- let activationName = '' ;
337
- switch ( this . activation ) {
338
- case 'Relu' :
339
- activationName = glslRelu ( ) . name ;
340
- activationFunction = glslRelu ( ) . body ;
341
- break ;
342
- case 'Sigmoid' :
343
- activationName = glslSigmoid ( ) . name ;
344
- activationFunction = glslSigmoid ( ) . body ;
345
- break ;
346
- default :
347
- activationName = '' ;
348
- activationFunction = '' ;
349
- }
350
- const applyActivation = this . activation ? `sum = ${ activationName } (sum);` : '' ;
360
+ const { activationFunction, applyActivation} = getActicationSnippet ( this . activation ) ;
351
361
352
362
const glsl = getGlsl ( inferenceHandler . session . backend . glContext . version ) ;
353
363
const shaderSource = `
@@ -362,16 +372,16 @@ export class WebGLUnpackedConv extends Conv {
362
372
int im2colOffset = im2col[0] * ${ im2colLayout . strides [ 0 ] } + im2col[1] * ${
363
373
im2colLayout . strides [ 1 ] } + im2col[2] * ${ im2colLayout . strides [ 2 ] } + sharedDimOffset;
364
374
int kernelOffset = indices[1] * ${ kLayout . strides [ 0 ] } + sharedDimOffset;
365
- float sum = sharedDimOffset == 0 ? ${ initValue } : 0.0;
375
+ float value = sharedDimOffset == 0 ? ${ initValue } : 0.0;
366
376
for (int i = 0; i < ${ sharedDimReadSize } ; ++i) {
367
377
vec2 im2colCoords = offsetToCoords(im2colOffset, ${ im2colLayout . width } , ${ im2colLayout . height } );
368
378
vec2 kernelCoords = offsetToCoords(kernelOffset, ${ kLayout . width } , ${ kLayout . height } );
369
- sum += dot(${ glsl . texture2D } (Im2Col, im2colCoords), ${ glsl . texture2D } (K, kernelCoords));
379
+ value += dot(${ glsl . texture2D } (Im2Col, im2colCoords), ${ glsl . texture2D } (K, kernelCoords));
370
380
++im2colOffset;
371
381
++kernelOffset;
372
382
}
373
383
${ applyActivation }
374
- return sum ;
384
+ return value ;
375
385
}` ;
376
386
return {
377
387
inputLayouts : inputs . length === 3 ? [ im2colLayout , kLayout , bLayout ! ] : [ im2colLayout , kLayout ] ,
0 commit comments