Skip to content

Commit

Permalink
Add new APIs for conv, deconv and fc
Browse files Browse the repository at this point in the history
The new apis remvoe weights, oc_count and ksize.

Signed-off-by: zhao.xia <[email protected]>
  • Loading branch information
nightingalei authored and thezha committed Jun 7, 2021
1 parent 8d35c4d commit 0ed1e89
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 15 deletions.
8 changes: 8 additions & 0 deletions include/tim/vx/ops/conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ namespace ops {

class Conv1d : public Operation {
public:
Conv1d(Graph* graph, PadType padding, uint32_t stride,
uint32_t dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv1d(Graph* graph, int32_t weights, PadType padding,
uint32_t ksize, uint32_t stride,
uint32_t dilation, int32_t multiplier = 0,
Expand Down
10 changes: 10 additions & 0 deletions include/tim/vx/ops/conv2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ namespace ops {

class Conv2d : public Operation {
public:
Conv2d(Graph* graph, PadType padding,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv2d(Graph* graph, const std::array<uint32_t, 4> pad,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier = 0,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
Conv2d(Graph* graph, int32_t weights, PadType padding,
const std::array<uint32_t, 2>& ksize,
const std::array<uint32_t, 2>& stride,
Expand Down
13 changes: 13 additions & 0 deletions include/tim/vx/ops/deconv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ namespace ops {

class DeConv1d : public Operation {
public:
DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding, uint32_t group = 1,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
DeConv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t output_padding, uint32_t group = 1,
DataLayout input_layout = DataLayout::WHCN,
DataLayout kernel_layout = DataLayout::WHIcOc);
DeConv1d(Graph* graph, int32_t oc_count_, PadType pad_type,
uint32_t ksize,
uint32_t stride,
Expand All @@ -61,6 +69,10 @@ class DeConv1d : public Operation {
uint32_t output_padding,
const std::array<uint32_t, 2>& pad,
uint32_t group = 1);
DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding,
const std::array<uint32_t, 2>& pad, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout);

protected:
const uint32_t oc_count_; // output channel count
Expand All @@ -70,6 +82,7 @@ class DeConv1d : public Operation {
const uint32_t output_padding_;
const std::array<uint32_t, 2> pad_;
const uint32_t group_;
const DataLayout kernel_layout_;
};

} // namespace ops
Expand Down
1 change: 1 addition & 0 deletions include/tim/vx/ops/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace ops {

class FullyConnected : public Operation {
public:
FullyConnected(Graph* graph, uint32_t axis);
FullyConnected(Graph* graph, uint32_t axis, uint32_t weights);

protected:
Expand Down
14 changes: 12 additions & 2 deletions src/tim/vx/ops/conv1d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ namespace tim {
namespace vx {
namespace ops {

Conv1d::Conv1d(Graph* graph, PadType padding, uint32_t stride,
uint32_t dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv1d(graph, 0, padding, 0, stride, dilation, {0, 0},
multiplier, input_layout, kernel_layout) {}

Conv1d::Conv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv1d(graph, 0, PadType::AUTO, 0, stride, dilation, pad,
multiplier, input_layout, kernel_layout) {}

Conv1d::Conv1d(Graph* graph, int32_t weights, PadType padding,
uint32_t ksize, uint32_t stride,
uint32_t dilation, int32_t multiplier,
Expand All @@ -51,10 +63,8 @@ Conv1d::Conv1d(Graph* graph, int32_t weights, PadType padding,
pad_(pad),
multiplier_(multiplier),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.conv1d.ksize = ksize_;
this->impl()->node()->nn_param.conv1d.stride = stride_;
this->impl()->node()->nn_param.conv1d.pad_type = TranslatePadType(padding_);
this->impl()->node()->nn_param.conv1d.weights = weights;
this->impl()->node()->nn_param.conv1d.group = 1;
this->impl()->node()->nn_param.conv1d.dilation = dilation_;
this->impl()->node()->nn_param.conv1d.pad[0] = pad_[0];
Expand Down
17 changes: 14 additions & 3 deletions src/tim/vx/ops/conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ namespace tim {
namespace vx {
namespace ops {

Conv2d::Conv2d(Graph* graph, PadType padding,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv2d(graph, 0, padding, {0, 0}, stride, dilation, {0, 0, 0, 0},
multiplier, input_layout, kernel_layout) {}

Conv2d::Conv2d(Graph* graph, const std::array<uint32_t, 4> pad,
const std::array<uint32_t, 2>& stride,
const std::array<uint32_t, 2>& dilation, int32_t multiplier,
DataLayout input_layout, DataLayout kernel_layout)
: Conv2d(graph, 0, PadType::AUTO, {0, 0}, stride, dilation, pad,
multiplier, input_layout, kernel_layout) {}

Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
const std::array<uint32_t, 2>& ksize,
const std::array<uint32_t, 2>& stride,
Expand All @@ -54,12 +68,9 @@ Conv2d::Conv2d(Graph* graph, int32_t weights, PadType padding,
pad_(pad),
multiplier_(multiplier),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.conv2d.ksize[0] = ksize_[0];
this->impl()->node()->nn_param.conv2d.ksize[1] = ksize_[1];
this->impl()->node()->nn_param.conv2d.stride[0] = stride_[0];
this->impl()->node()->nn_param.conv2d.stride[1] = stride_[1];
this->impl()->node()->nn_param.conv2d.pad_type = TranslatePadType(padding_);
this->impl()->node()->nn_param.conv2d.weights = weights;
this->impl()->node()->nn_param.conv2d.group = 1;
this->impl()->node()->nn_param.conv2d.dilation[0] = dilation_[0];
this->impl()->node()->nn_param.conv2d.dilation[1] = dilation_[1];
Expand Down
42 changes: 33 additions & 9 deletions src/tim/vx/ops/deconv1d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,52 @@ namespace tim {
namespace vx {
namespace ops {

DeConv1d::DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout)
: DeConv1d(graph, pad_type, stride, output_padding, {0, 0}, group,
input_layout, kernel_layout) {
}

DeConv1d::DeConv1d(Graph* graph, const std::array<uint32_t, 2>& pad,
uint32_t stride, uint32_t output_padding, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout)
: DeConv1d(graph, PadType::AUTO, stride, output_padding, pad, group,
input_layout, kernel_layout) {
}

DeConv1d::DeConv1d(Graph* graph, int32_t oc_count, PadType pad_type,
uint32_t ksize, uint32_t stride, uint32_t output_padding)
: DeConv1d(graph, oc_count, pad_type, ksize, stride, output_padding,
{0, 0}) {
: DeConv1d(graph, pad_type, stride, output_padding,
{0, 0}, 1, DataLayout::WHCN, DataLayout::WHIcOc) {
(void)ksize;
(void)oc_count;
}

DeConv1d::DeConv1d(Graph* graph, int32_t oc_count, PadType pad_type,
uint32_t ksize, uint32_t stride, uint32_t output_padding,
const std::array<uint32_t, 2>& pad, uint32_t group)
: Operation(graph, VSI_NN_OP_DECONVOLUTION1D),
oc_count_(oc_count),
: DeConv1d(graph, pad_type, stride, output_padding,
pad, group, DataLayout::WHCN, DataLayout::WHIcOc) {
(void)ksize;
(void)oc_count;
}

DeConv1d::DeConv1d(Graph* graph, PadType pad_type,
uint32_t stride, uint32_t output_padding,
const std::array<uint32_t, 2>& pad, uint32_t group,
DataLayout input_layout, DataLayout kernel_layout)
: Operation(graph, VSI_NN_OP_DECONVOLUTION1D, 3, 1, input_layout),
oc_count_(0),
pad_type_(pad_type),
ksize_(ksize),
ksize_(0),
stride_(stride),
output_padding_(output_padding),
pad_(pad),
group_(group) {

this->impl()->node()->nn_param.deconvolution1d.ksize = ksize_;
group_(group),
kernel_layout_(kernel_layout) {
this->impl()->node()->nn_param.deconvolution1d.stride = stride_;
this->impl()->node()->nn_param.deconvolution1d.pad_type = TranslatePadType(pad_type_);
this->impl()->node()->nn_param.deconvolution1d.weights = oc_count_;
this->impl()->node()->nn_param.deconvolution1d.group = group_;
this->impl()->node()->nn_param.deconvolution1d.output_padding = output_padding_;
this->impl()->node()->nn_param.deconvolution1d.pad[0] = pad_[0];
Expand Down
6 changes: 5 additions & 1 deletion src/tim/vx/ops/fullyconnected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ namespace tim {
namespace vx {
namespace ops {

FullyConnected::FullyConnected(Graph* graph, uint32_t axis)
: FullyConnected(graph, axis, 0) {
}

FullyConnected::FullyConnected(Graph* graph, uint32_t axis, uint32_t weights)
: Operation(graph, VSI_NN_OP_FCL2) {
(void)weights;
this->impl()->node()->nn_param.fcl.axis = axis;
this->impl()->node()->nn_param.fcl.weights = weights;
}

} // namespace ops
Expand Down

0 comments on commit 0ed1e89

Please sign in to comment.