3
3
4
4
import { Attribute } from '../../../attribute' ;
5
5
import { Logger } from '../../../instrument' ;
6
- import { Conv } from '../../../ops/conv' ;
6
+ import { Conv , getActicationSnippet } from '../../../ops/conv' ;
7
7
import { Tensor } from '../../../tensor' ;
8
8
import { PoolConvUtil } from '../../../util' ;
9
9
import { getGlsl } from '../glsl-source' ;
@@ -12,7 +12,6 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t
12
12
import { WebGLContext } from '../webgl-context' ;
13
13
14
14
import { WebGLConvPacked } from './conv-pack' ;
15
- import { glslRelu , glslSigmoid } from './unary-op' ;
16
15
17
16
export class WebGLConv extends Conv {
18
17
unpackedGroupedConvImpl : WebGLUnpackedGroupedConv ;
@@ -146,28 +145,6 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
146
145
}
147
146
}
148
147
149
- function getActicationSnippet ( activation : string ) {
150
- let activationFunction = '' ;
151
- let activationName = '' ;
152
- switch ( activation ) {
153
- case 'Relu' :
154
- activationName = glslRelu ( ) . name ;
155
- activationFunction = glslRelu ( ) . body ;
156
- break ;
157
- case 'Sigmoid' :
158
- activationName = glslSigmoid ( ) . name ;
159
- activationFunction = glslSigmoid ( ) . body ;
160
- break ;
161
- default :
162
- activationName = '' ;
163
- activationFunction = '' ;
164
- }
165
- const applyActivation = activation ? `
166
- value = ${ activationName } (value);` :
167
- '' ;
168
- return { activationFunction, applyActivation} ;
169
- }
170
-
171
148
export class WebGLUnpackedConv extends Conv {
172
149
run ( inferenceHandler : WebGLInferenceHandler , inputs : Tensor [ ] ) : Tensor [ ] {
173
150
const programManager = inferenceHandler . session . programManager ;
@@ -242,7 +219,6 @@ export class WebGLUnpackedConv extends Conv {
242
219
let blend = false ;
243
220
for ( let k = 0 ; k < sharedDim ; k += sharedDimReadSize ) {
244
221
Logger . verbose ( 'MatMul2D' , `k = ${ k } , sharedDim: ${ sharedDim } , readSize = ${ sharedDimReadSize } ` ) ;
245
-
246
222
if ( k === sharedDimReadSize ) {
247
223
blend = true ;
248
224
gl . enable ( gl . BLEND ) ;
@@ -348,7 +324,7 @@ export class WebGLUnpackedConv extends Conv {
348
324
const outputLayout = inferenceHandler . createTextureLayoutFromShape ( outputShape ) ;
349
325
const initValue = ( inputs . length < 3 ) ? '0.0' : '_B(b)' ;
350
326
const sharedDim = im2colLayout . shape [ 3 ] ;
351
- const blendEnabled = inferenceHandler . session . backend . glContext . isBlendSupported ;
327
+ const blendEnabled = inferenceHandler . session . backend . glContext . isBlendSupported && ! this . activation ;
352
328
const sharedDimReadSize = blendEnabled && inferenceHandler . session . backend . matmulMaxBatchSize ?
353
329
this . calcSharedDimReadSize ( inferenceHandler . session . backend . matmulMaxBatchSize , sharedDim ) :
354
330
sharedDim ;
0 commit comments