Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 2927e7d

Browse files
committed
add packed conv fuse
1 parent ff67621 commit 2927e7d

File tree

13 files changed

+46
-30
lines changed

13 files changed

+46
-30
lines changed

lib/backends/webgl/ops/concat_packed.ts

+4-3
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,9 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
110110
`;
111111

112112
return {
113-
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
114-
outputLayout: handler.createTextureLayoutFromShape(outputShape),
113+
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t, 4, true, t.dims, true)),
114+
outputLayout:
115+
handler.createTextureLayoutFromShape(outputShape, 4, outputShape, {isPacked: true, reverseWH: true}),
115116
samplers,
116117
shaderSource,
117118
hasMain: true,
@@ -120,7 +121,7 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
120121
};
121122
}
122123
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
123-
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
124+
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i], true));
124125
return {
125126
inputTextureDatas: inputTDs,
126127
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),

lib/backends/webgl/ops/conv-pack.ts

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT license.
33

4+
import {Attribute} from '../../../attribute';
45
import {Logger} from '../../../instrument';
56
import {Conv} from '../../../ops/conv';
67
import {Tensor} from '../../../tensor';
@@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv {
3637
const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
3738
const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
3839
const matmul = new WebGLMatMulPacked();
40+
if (!!this.activation) {
41+
const attributes = new Attribute(undefined);
42+
attributes.set('__internal_activation', 'string', (this.activation));
43+
matmul.initialize(attributes);
44+
}
3945
const reshape = new WebGLReshapePacked();
4046
// shape for kernel reshape
4147
const shape =

lib/backends/webgl/ops/conv.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import {Attribute} from '../../../attribute';
55
import {Logger} from '../../../instrument';
6-
import {Conv, getActicationSnippet} from '../../../ops/conv';
6+
import {Conv} from '../../../ops/conv';
77
import {Tensor} from '../../../tensor';
88
import {PoolConvUtil} from '../../../util';
99
import {getGlsl} from '../glsl-source';
@@ -12,6 +12,7 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t
1212
import {WebGLContext} from '../webgl-context';
1313

1414
import {WebGLConvPacked} from './conv-pack';
15+
import {getActicationSnippet} from './fuse_utils';
1516

1617
export class WebGLConv extends Conv {
1718
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;

lib/backends/webgl/ops/fuse_utils.ts

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import {glslRelu, glslSigmoid} from './unary-op';
2+
3+
export function getActicationSnippet(activation: string) {
4+
let activationFunction = '';
5+
let activationName = '';
6+
switch (activation) {
7+
case 'Relu':
8+
activationName = glslRelu().name;
9+
activationFunction = glslRelu().body;
10+
break;
11+
case 'Sigmoid':
12+
activationName = glslSigmoid().name;
13+
activationFunction = glslSigmoid().body;
14+
break;
15+
default:
16+
activationName = '';
17+
activationFunction = '';
18+
}
19+
const applyActivation = activation ? `
20+
value = ${activationName}(value);` :
21+
'';
22+
return {activationFunction, applyActivation};
23+
}

lib/backends/webgl/ops/matmul-pack.ts

+5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor';
66
import {BroadcastUtil} from '../../../util';
77
import {WebGLInferenceHandler} from '../inference-handler';
88
import {ProgramInfo, RunData, WebGLOperator} from '../types';
9+
import {getActicationSnippet} from './fuse_utils';
910

1011
export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
1112
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
@@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
2526
const aRank = aShape.length;
2627
const bRank = bShape.length;
2728
const sharedDim = aShape[aShape.length - 1];
29+
30+
const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
2831
// TODO:fix broadcasting
2932
const shaderSource = `
33+
${activationFunction}
3034
vec4 process(int indices[${rank}]) {
3135
int a[${aRank}];
3236
int b[${bRank}];
@@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
4145
value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba;
4246
}
4347
${processBias}
48+
${applyActivation}
4449
return value;
4550
}`;
4651
return {

lib/ops/conv.ts

-23
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import {Attribute} from '../attribute';
55
import {InferenceHandler} from '../backend';
6-
import {glslRelu, glslSigmoid} from '../backends/webgl/ops/unary-op';
76
import {Operator} from '../operators';
87
import {Tensor} from '../tensor';
98

@@ -92,25 +91,3 @@ export abstract class Conv implements Operator {
9291
protected strides: number[];
9392
protected activation: string;
9493
}
95-
96-
export function getActicationSnippet(activation: string) {
97-
let activationFunction = '';
98-
let activationName = '';
99-
switch (activation) {
100-
case 'Relu':
101-
activationName = glslRelu().name;
102-
activationFunction = glslRelu().body;
103-
break;
104-
case 'Sigmoid':
105-
activationName = glslSigmoid().name;
106-
activationFunction = glslSigmoid().body;
107-
break;
108-
default:
109-
activationName = '';
110-
activationFunction = '';
111-
}
112-
const applyActivation = activation ? `
113-
value = ${activationName}(value);` :
114-
'';
115-
return {activationFunction, applyActivation};
116-
}

lib/ops/matmul.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import {Tensor} from '../tensor';
99
export abstract class MatMul implements Operator {
1010
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
1111

12-
initialize(attributes: Attribute): void {}
12+
initialize(attributes: Attribute): void {
13+
this.activation = attributes.getString('__internal_activation', '');
14+
}
1315

1416
checkInputs(inputs: Tensor[]): boolean {
1517
if (!inputs || inputs.length !== 2) {
@@ -38,4 +40,5 @@ export abstract class MatMul implements Operator {
3840

3941
return true;
4042
}
43+
protected activation: string;
4144
}
Binary file not shown.
Binary file not shown.
734 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.

test/unittests/backends/webgl/test_concat_packed.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
9898
texture: webglTextureB!
9999
};
100100

101-
webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA);
102-
webglInferenceHandler.setTextureData(inputTensorB.dataId, textureDataB);
101+
webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA, true);
102+
webglInferenceHandler.setTextureData(inputTensorB.dataId, textureDataB, true);
103103

104104
// compile shader code
105105
const programInfo =

0 commit comments

Comments
 (0)