
Commit

Add quantized concat
sunnycase committed Mar 8, 2019
1 parent 6216c25 commit 54a6e21
Showing 4 changed files with 100 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/NnCase.Cli/Program.cs
@@ -159,8 +159,9 @@ static async Task Main(string[] args)
     new K210EliminateAddRemovePaddingTransform(),
     new QuantizedAddTransform(),
     new QuantizedMaxPool2dTransform(),
-    //new ExclusiveConcatenationTransform(),
+    new ExclusiveConcatenationTransform(),
     new QuantizedExclusiveConcatenationTransform(),
+    new QuantizedConcatenationTransform(),
     new EliminateQuantizeDequantizeTransform(),
     new EliminateInputQuantizeTransform(),
     new K210EliminateInputUploadTransform(),
@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using NnCase.Converter.K210.Converters.Stages.Convert;
using NnCase.Converter.K210.Converters.Stages.Inference;
using NnCase.Converter.Model.Layers;

namespace NnCase.Converter.K210.Converters.Layers
{
    [LayerConverter(typeof(QuantizedConcatenation), K210LayerType.QuantizedConcatenation)]
    public class QuantizedConcatenationConverter
    {
        // Convert only records how many inputs the concatenation takes; the
        // memory addresses are resolved later, during inference.
        public ConcatenationLayerArgument Convert(QuantizedConcatenation layer, ConvertContext context)
        {
            return new ConcatenationLayerArgument
            {
                InputCount = (uint)layer.Inputs.Count
            };
        }

        public void Infer(QuantizedConcatenation layer, ConcatenationLayerArgument argument, InferenceContext context)
        {
            var outputAlloc = context.MainMemoryMap[layer.Output];

            // The concatenated result is written to main memory.
            argument.Flags = K210LayerFlags.MainMemoryOutput;
            argument.MainMemoryOutputAddress = outputAlloc.GetAddress();
            // Record the main-memory range of every input so the runtime can copy
            // them back-to-back into the output buffer.
            argument.InputsMainMemory = (from i in layer.Inputs
                                         let a = context.MainMemoryMap[i.Connection.From]
                                         select new MemoryRange
                                         {
                                             Start = a.GetAddress(),
                                             Size = a.Size
                                         }).ToList();
        }
    }
}
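
The [LayerConverter] attribute is what ties this converter to the QuantizedConcatenation layer type. As a rough illustration of how attribute-based registration like this can be discovered, here is a minimal, hypothetical sketch; the enum, attribute, and ConverterRegistry below are trimmed stand-ins for illustration only and are not the actual NnCase.Converter.K210 definitions.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Trimmed stand-ins; the real types live in NnCase.Converter.K210 and may differ.
enum K210LayerType { Concatenation, QuantizedConcatenation }

[AttributeUsage(AttributeTargets.Class)]
sealed class LayerConverterAttribute : Attribute
{
    public Type LayerType { get; }
    public K210LayerType TargetType { get; }

    public LayerConverterAttribute(Type layerType, K210LayerType targetType)
    {
        LayerType = layerType;
        TargetType = targetType;
    }
}

static class ConverterRegistry
{
    // Scan an assembly for [LayerConverter]-tagged classes and index an instance
    // of each converter by the graph layer type it handles.
    public static Dictionary<Type, object> Build(Assembly assembly)
        => assembly.GetTypes()
               .Select(t => (Type: t, Attr: t.GetCustomAttribute<LayerConverterAttribute>()))
               .Where(x => x.Attr != null)
               .ToDictionary(x => x.Attr.LayerType, x => Activator.CreateInstance(x.Type));
}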
@@ -23,6 +23,7 @@ public enum K210LayerType
     L2Normalization,
     Softmax,
     Concatenation,
+    QuantizedConcatenation,
     K210Conv = 10240,
     K210AddPadding,
     K210RemovePadding,
60 changes: 60 additions & 0 deletions src/NnCase.Converter/Transforms/QuantizedConcatenationTransform.cs
@@ -0,0 +1,60 @@
using NnCase.Converter.Model;
using NnCase.Converter.Model.Layers;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace NnCase.Converter.Transforms
{
    public class QuantizedConcatenationTransform : Transform
    {
        protected override bool OnTryMatch(Layer layer, TransformContext context)
        {
            try
            {
                // Match a Concatenation whose inputs are all produced by Dequantize layers.
                if (layer is Concatenation concat)
                {
                    if (concat.Inputs.Select(x => x.Connection.From.Owner).All(x => x is Dequantize))
                    {
                        context.Inputs.AddRange(concat.Inputs);
                        context.Outputs.Add(concat.Output);

                        context.MatchedLayers.Add(layer);
                        return true;
                    }
                }

                return false;
            }
            catch
            {
                return false;
            }
        }

        public override void Process(TransformContext context)
        {
            var concat = (Concatenation)context.MatchedLayers[0];
            var output = concat.Output;

            // Build a QuantizedConcatenation with the same input shapes as the original concat.
            var exConcat = new QuantizedConcatenation(concat.Inputs.Select(x => new ReadOnlyMemory<int>(x.Dimensions.ToArray())));
            for (int i = 0; i < exConcat.Inputs.Count; i++)
            {
                // Feed each input through a Quantize + Requantize pair so the inputs
                // reach the quantized concat in a common quantization range.
                var input = concat.Inputs[i].Connection.From;
                var quantize = new Quantize(input.Dimensions);
                var requantize = new Requantize(quantize.Output.Dimensions);
                quantize.Input.SetConnection(input);
                requantize.Input.SetConnection(quantize.Output);
                exConcat.Inputs[i].SetConnection(requantize.Output);
            }

            // Dequantize the concatenated result and reconnect the original consumers to it.
            var dequantize = new Dequantize(exConcat.Output.Dimensions);
            dequantize.Input.SetConnection(exConcat.Output);

            var oldOuts = output.Connections.Select(o => o.To).ToList();
            foreach (var oldOut in oldOuts)
                oldOut.SetConnection(dequantize.Output);
        }
    }
}
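
Taken together, OnTryMatch and Process rewrite the graph roughly as follows (an illustrative sketch of the wiring, not code from the repository):

Before (matched pattern):

    Dequantize --+
    Dequantize --+--> Concatenation --> consumers
    Dequantize --+

After Process:

    Dequantize --> Quantize --> Requantize --+
    Dequantize --> Quantize --> Requantize --+--> QuantizedConcatenation --> Dequantize --> consumers
    Dequantize --> Quantize --> Requantize --+

The back-to-back Dequantize/Quantize pairs this creates are presumably what EliminateQuantizeDequantizeTransform (registered after this transform in Program.cs) folds away, leaving the concatenation to run entirely on quantized data.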
