Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit ff67621

Browse files
committed
fix conv fuse bug
1 parent 3aef5db commit ff67621

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

lib/backends/webgl/ops/conv.ts

+2-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import {Attribute} from '../../../attribute';
55
import {Logger} from '../../../instrument';
6-
import {Conv} from '../../../ops/conv';
6+
import {Conv, getActicationSnippet} from '../../../ops/conv';
77
import {Tensor} from '../../../tensor';
88
import {PoolConvUtil} from '../../../util';
99
import {getGlsl} from '../glsl-source';
@@ -12,7 +12,6 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t
1212
import {WebGLContext} from '../webgl-context';
1313

1414
import {WebGLConvPacked} from './conv-pack';
15-
import {glslRelu, glslSigmoid} from './unary-op';
1615

1716
export class WebGLConv extends Conv {
1817
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
@@ -146,28 +145,6 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
146145
}
147146
}
148147

149-
function getActicationSnippet(activation: string) {
150-
let activationFunction = '';
151-
let activationName = '';
152-
switch (activation) {
153-
case 'Relu':
154-
activationName = glslRelu().name;
155-
activationFunction = glslRelu().body;
156-
break;
157-
case 'Sigmoid':
158-
activationName = glslSigmoid().name;
159-
activationFunction = glslSigmoid().body;
160-
break;
161-
default:
162-
activationName = '';
163-
activationFunction = '';
164-
}
165-
const applyActivation = activation ? `
166-
value = ${activationName}(value);` :
167-
'';
168-
return {activationFunction, applyActivation};
169-
}
170-
171148
export class WebGLUnpackedConv extends Conv {
172149
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
173150
const programManager = inferenceHandler.session.programManager;
@@ -242,7 +219,6 @@ export class WebGLUnpackedConv extends Conv {
242219
let blend = false;
243220
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
244221
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);
245-
246222
if (k === sharedDimReadSize) {
247223
blend = true;
248224
gl.enable(gl.BLEND);
@@ -348,7 +324,7 @@ export class WebGLUnpackedConv extends Conv {
348324
const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
349325
const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
350326
const sharedDim = im2colLayout.shape[3];
351-
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
327+
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation;
352328
const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ?
353329
this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) :
354330
sharedDim;

lib/ops/conv.ts

+23
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import {Attribute} from '../attribute';
55
import {InferenceHandler} from '../backend';
6+
import {glslRelu, glslSigmoid} from '../backends/webgl/ops/unary-op';
67
import {Operator} from '../operators';
78
import {Tensor} from '../tensor';
89

@@ -91,3 +92,25 @@ export abstract class Conv implements Operator {
9192
protected strides: number[];
9293
protected activation: string;
9394
}
95+
96+
export function getActicationSnippet(activation: string) {
97+
let activationFunction = '';
98+
let activationName = '';
99+
switch (activation) {
100+
case 'Relu':
101+
activationName = glslRelu().name;
102+
activationFunction = glslRelu().body;
103+
break;
104+
case 'Sigmoid':
105+
activationName = glslSigmoid().name;
106+
activationFunction = glslSigmoid().body;
107+
break;
108+
default:
109+
activationName = '';
110+
activationFunction = '';
111+
}
112+
const applyActivation = activation ? `
113+
value = ${activationName}(value);` :
114+
'';
115+
return {activationFunction, applyActivation};
116+
}

0 commit comments

Comments
 (0)