Skip to content

Commit

Permalink
Add scirpt export
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Jan 7, 2019
1 parent fc3bd1a commit e65ae8a
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 24 deletions.
38 changes: 37 additions & 1 deletion src/NnCase.Cli/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Numerics.Tensors;
using System.Threading.Tasks;
using CommandLine;
using NnCase.Converter.Converters;
using NnCase.Converter.Data;
using NnCase.Converter.Model;
using NnCase.Converter.Model.Layers;
using NnCase.Converter.Model.Layers.K210;
using NnCase.Converter.Transforms;
using NnCase.Converter.Transforms.K210;

Expand Down Expand Up @@ -68,6 +71,28 @@ static async Task Main(string[] args)
graph = tfc.Graph;
break;
}
case "test":
{
var inputs = new[]
{
new InputLayer(new[]{-1,3,8,8}){ Name ="input" }
};
var conv2d = new K210Conv2d(inputs[0].Output.Dimensions, K210Conv2dType.Conv2d,
new DenseTensor<float>(new[] { 32, 3, 3, 3 }), null, K210PoolType.None, ActivationFunctionType.Relu);
conv2d.Input.SetConnection(inputs[0].Output);

var spconv2d = new K210SeparableConv2d(conv2d.Output.Dimensions, new DenseTensor<float>(new[] { 1, 32, 3, 3, 3 }),
new DenseTensor<float>(new[] { 32, 64, 1, 1 }), null, K210PoolType.LeftTop, ActivationFunctionType.Relu);
spconv2d.Input.SetConnection(conv2d.Output);

var outputs = new[]
{
new OutputLayer(spconv2d.Output.Dimensions){Name = "output"}
};
outputs[0].Input.SetConnection(spconv2d.Output);
graph = new Graph(inputs, outputs);
}
break;
default:
throw new ArgumentException("input-format");
}
Expand Down Expand Up @@ -104,7 +129,7 @@ static async Task Main(string[] args)
}

Transform.Process(graph, new Transform[] {
new K210SeprableConv2dTransform(),
new K210SeparableConv2dTransform(),
new K210SpaceToBatchNdAndValidConv2dTransform(),
new K210SameConv2dTransform(),
new K210Stride2Conv2dTransform(),
Expand All @@ -131,6 +156,17 @@ await k210c.ConvertAsync(new ImageDataset(
}
break;
}
case "k210script":
{
{
var dim = graph.Inputs.First().Output.Dimensions.ToArray();
var k210c = new GraphToScriptConverter(graph);
await k210c.ConvertAsync(
Path.GetDirectoryName(options.Output),
Path.GetFileNameWithoutExtension(options.Output));
}
break;
}
default:
throw new ArgumentException("output-format");
}
Expand Down
2 changes: 1 addition & 1 deletion src/NnCase.Converter/Converters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static async Task ExportK210Code(string modelPath, string datasetDir, str
tfc.Convert();
var graph = tfc.Graph;
Transform.Process(graph, new Transform[] {
new K210SeprableConv2dTransform(),
new K210SeparableConv2dTransform(),
new K210SpaceToBatchNdAndValidConv2dTransform(),
new K210SameConv2dTransform(),
new K210Stride2Conv2dTransform(),
Expand Down
190 changes: 189 additions & 1 deletion src/NnCase.Converter/Converters/GraphToScriptConverter.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,78 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using NnCase.Converter.Model;
using NnCase.Converter.Model.Layers;
using NnCase.Converter.Model.Layers.K210;
using RazorLight;

namespace NnCase.Converter.Converters
{
public abstract class ScriptLayerConfig
{

}

public class ScriptInputLayerConfig : ScriptLayerConfig
{
public string Name { get; set; }

public long[] Dimensions { get; set; }

public int Output { get; set; }
}

public class ScriptOutputLayerConfig : ScriptLayerConfig
{
public string Name { get; set; }

public int Input { get; set; }
}

public class ScriptConv2dLayerConfig : ScriptLayerConfig
{
public int KernelSize { get; set; }

public int Filters { get; set; }

public int Stride { get; set; }

public ActivationFunctionType Activation { get; set; }

public int Input { get; set; }

public int Output { get; set; }
}

public class ScriptSeparableConv2dLayerConfig : ScriptLayerConfig
{
public int KernelSize { get; set; }

public int Filters { get; set; }

public int Stride { get; set; }

public ActivationFunctionType Activation { get; set; }

public int Input { get; set; }

public int Output { get; set; }
}

public class ScriptGenerationContext
{
public string Prefix { get; set; }

public IReadOnlyList<ScriptLayerConfig> Layers { get; set; }

public IReadOnlyDictionary<OutputConnector, int> Outputs { get; set; }

public string OutputName { get; set; }
}

public class GraphToScriptConverter
{
private readonly Graph _graph;
Expand All @@ -22,9 +87,132 @@ public GraphToScriptConverter(Graph graph)
.Build();
}

public async Task ConvertAsync(Graph graph, string outputDir, string prefix)
public async Task ConvertAsync(string outputDir, string prefix)
{
var context = new ConvertContext();

foreach (var layer in _graph.Outputs)
ConvertLayer(layer, context);

var scriptGenContext = new ScriptGenerationContext
{
Prefix = prefix,
Layers = context.Layers,
Outputs = context.Outputs,
OutputName = _graph.Outputs.First().Name
};
var code = await _templateEngine.CompileRenderAsync("Model", scriptGenContext);
File.WriteAllText(Path.Combine(outputDir, $"{prefix}.py"), code);
}

private void ConvertLayer(Layer layer, ConvertContext context)
{
if (!context.ProcessMap.GetValueOrDefault(layer))
{
context.ProcessMap[layer] = true;

foreach (var conn in layer.InputConnectors)
{
var nextLayer = conn.Connection?.From.Owner;
if (nextLayer != null)
ConvertLayer(nextLayer, context);
}

switch (layer)
{
case InputLayer l:
ConvertInputLayer(l, context);
break;
case OutputLayer l:
ConvertOutputLayer(l, context);
break;
case K210Conv2d l:
ConvertK210Conv2d(l, context);
break;
case K210SeparableConv2d l:
ConvertK210SeparableConv2d(l, context);
break;
default:
throw new NotSupportedException(nameof(layer));
}
}
}

private void ConvertInputLayer(InputLayer layer, ConvertContext context)
{
context.Layers.Add(new ScriptInputLayerConfig
{
Name = layer.Name,
Dimensions = layer.Output.Dimensions.ToNHWC().ToArray(),
Output = context.AddOutput(layer.Output)
});
}

private void ConvertOutputLayer(OutputLayer layer, ConvertContext context)
{
context.Layers.Add(new ScriptOutputLayerConfig
{
Name = layer.Name,
Input = context.Outputs[layer.Input.Connection.From]
});
}

private void ConvertK210Conv2d(K210Conv2d layer, ConvertContext context)
{
if (layer.Conv2dType != K210Conv2dType.Conv2d)
throw new NotSupportedException("Depthwise conv2d is not supported.");

context.Layers.Add(new ScriptConv2dLayerConfig
{
KernelSize = layer.KernelWidth,
Stride = GetStride(layer.PoolType),
Filters = layer.OutputChannels,
Activation = layer.FusedActivationFunction,
Input = context.Outputs[layer.Input.Connection.From],
Output = context.AddOutput(layer.Output)
});
}

private void ConvertK210SeparableConv2d(K210SeparableConv2d layer, ConvertContext context)
{
context.Layers.Add(new ScriptSeparableConv2dLayerConfig
{
KernelSize = layer.KernelWidth,
Stride = GetStride(layer.PoolType),
Filters = layer.OutputChannels,
Activation = layer.FusedActivationFunction,
Input = context.Outputs[layer.Input.Connection.From],
Output = context.AddOutput(layer.Output)
});
}

private static int GetStride(K210PoolType poolType)
{
switch (poolType)
{
case K210PoolType.None:
return 1;
case K210PoolType.LeftTop:
return 2;
default:
throw new NotSupportedException(nameof(poolType));
}
}

private class ConvertContext
{
public Dictionary<Layer, bool> ProcessMap = new Dictionary<Layer, bool>();

public List<ScriptLayerConfig> Layers = new List<ScriptLayerConfig>();

public Dictionary<OutputConnector, int> Outputs = new Dictionary<OutputConnector, int>();

public int AddOutput(OutputConnector output)
{
var id = Outputs.Count;
Outputs.Add(output, id);
return id;
}
}
}
}
2 changes: 0 additions & 2 deletions src/NnCase.Converter/Model/Layers/K210/K210SeparableConv2d.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ public class K210SeparableConv2d : Layer

public OutputConnector Output { get; }

public K210Conv2dType Conv2dType { get; }

public Tensor<float> DwWeights { get; }

public Tensor<float> PwWeights { get; }
Expand Down
9 changes: 5 additions & 4 deletions src/NnCase.Converter/NnCase.Converter.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
<DefineConstants>$(DefineConstants);ENABLE_SPAN_T;UNSAFE_BYTEBUFFER;CHANNEL_WISE</DefineConstants>
</PropertyGroup>

<ItemGroup>
<None Remove="Templates\Script\Model.cshtml" />
</ItemGroup>

<ItemGroup>
<EmbeddedResource Include="Templates\K210\Model.cshtml" />
<EmbeddedResource Include="Templates\Script\Model.cshtml" />
</ItemGroup>

<ItemGroup>
Expand All @@ -34,8 +39,4 @@
<ProjectReference Include="..\FlatBuffers\FlatBuffers.csproj" />
</ItemGroup>

<ItemGroup>
<Folder Include="Templates\Script\" />
</ItemGroup>

</Project>
Loading

0 comments on commit e65ae8a

Please sign in to comment.