pnnx2ncnn convert torch.roll with one or two shifts (Tencent#5623)
nihui authored Aug 12, 2024
1 parent f3cd4c2 commit ecfd88a
Showing 8 changed files with 386 additions and 0 deletions.
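
For context on the conversion itself: a single-axis torch.roll is equivalent to splitting the tensor at index -shift and concatenating the two pieces in swapped order, which is exactly the Slice + Concat pair the new pass emits. A minimal PyTorch sketch of that equivalence (illustrative only, not part of this commit; shapes and shift taken from the tests):

import torch

x = torch.rand(3, 16)
shift, dim = 3, 1

# split the rolled axis at -shift, then stitch the two pieces back in swapped order
a, b = x[:, :-shift], x[:, -shift:]
assert torch.equal(torch.cat((b, a), dim=dim), torch.roll(x, shift, dim))

The same swap also covers negative shifts, since x[:, :-shift] / x[:, -shift:] then split at the corresponding positive index.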
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
@@ -572,6 +572,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_mm.cpp
pass_ncnn/torch_norm.cpp
pass_ncnn/torch_prod.cpp
pass_ncnn/torch_roll.cpp
pass_ncnn/torch_slice_scatter.cpp
pass_ncnn/torch_squeeze.cpp
pass_ncnn/torch_sum.cpp
193 changes: 193 additions & 0 deletions tools/pnnx/src/pass_ncnn/torch_roll.cpp
@@ -0,0 +1,193 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_roll : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.roll op_0 1 1 input out dims=%dims shifts=%shifts
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
Slice slice 1 2 input a b
Concat concat 2 1 b a out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        if (captured_params.at("dims").type != 5)
            return false;

        if (captured_params.at("dims").ai.size() != 1)
            return false;

        if (captured_params.at("shifts").type != 5)
            return false;

        if (captured_params.at("shifts").ai.size() != 1)
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        GraphRewriterPass::write(ops, captured_params, captured_attrs);

        const Operand* in = ops.at("slice")->inputs[0];

        const int batch_index = in->params.at("__batch_index").i;

        int axis = captured_params.at("dims").ai[0];
        if (axis == batch_index)
        {
            fprintf(stderr, "roll along batch axis %d is not supported\n", batch_index);
        }

        if (axis < 0)
        {
            int input_rank = in->shape.size();
            axis = input_rank + axis;
        }

        if (axis > batch_index)
            axis -= 1;

        ops.at("slice")->params["1"] = axis;

        ops.at("concat")->params["0"] = axis;

        const int shift = captured_params.at("shifts").ai[0];
        ops.at("slice")->params["2"] = std::vector<int>{-shift};
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_roll, 20)

class torch_roll_1 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.roll op_0 1 1 input out dims=%dims shifts=%shifts
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        return R"PNNXIR(7767517
8 7
pnnx.Input input 0 1 input
Slice slice 1 2 input a b
Slice slice_a 1 2 a a0 a1
Slice slice_b 1 2 b b0 b1
Concat concat_a 2 1 a1 a0 a10
Concat concat_b 2 1 b1 b0 b10
Concat concat 2 1 b10 a10 out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        if (captured_params.at("dims").type != 5)
            return false;

        if (captured_params.at("dims").ai.size() != 2)
            return false;

        if (captured_params.at("shifts").type != 5)
            return false;

        if (captured_params.at("shifts").ai.size() != 2)
            return false;

        return true;
    }

    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
    {
        GraphRewriterPass::write(ops, captured_params, captured_attrs);

        const Operand* in = ops.at("slice")->inputs[0];

        const int batch_index = in->params.at("__batch_index").i;

        int axis0 = captured_params.at("dims").ai[0];
        int axis1 = captured_params.at("dims").ai[1];
        if (axis0 == batch_index || axis1 == batch_index)
        {
            fprintf(stderr, "roll along batch axis %d is not supported\n", batch_index);
        }

        if (axis0 < 0)
        {
            int input_rank = in->shape.size();
            axis0 = input_rank + axis0;
        }

        if (axis0 > batch_index)
            axis0 -= 1;

        if (axis1 < 0)
        {
            int input_rank = in->shape.size();
            axis1 = input_rank + axis1;
        }

        if (axis1 > batch_index)
            axis1 -= 1;

        ops.at("slice")->params["1"] = axis0;
        ops.at("slice_a")->params["1"] = axis1;
        ops.at("slice_b")->params["1"] = axis1;

        ops.at("concat_a")->params["0"] = axis1;
        ops.at("concat_b")->params["0"] = axis1;
        ops.at("concat")->params["0"] = axis0;

        const int shift0 = captured_params.at("shifts").ai[0];
        const int shift1 = captured_params.at("shifts").ai[1];
        ops.at("slice")->params["2"] = std::vector<int>{-shift0};
        ops.at("slice_a")->params["2"] = std::vector<int>{-shift1};
        ops.at("slice_b")->params["2"] = std::vector<int>{-shift1};
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_roll_1, 20)

} // namespace ncnn

} // namespace pnnx
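
The second rewriter, torch_roll_1, handles rolls over two dims by nesting the same trick: the outer Slice/Concat pair rolls axis0, and each of the two halves is sliced and re-concatenated again along axis1. Note also the axis bookkeeping in write(): ncnn blobs do not carry the batch dimension, so an axis greater than __batch_index is shifted down by one before being written into the Slice/Concat parameters. A PyTorch sketch mirroring the replacement graph's three Slices and three Concats (illustrative only, shapes and shifts taken from the tests):

import torch

z = torch.rand(8, 5, 9, 10)
s0, s1 = 2, 1                       # shifts for dims (0, 1), as in the tests

a, b = z[:-s0], z[-s0:]             # slice: split dim 0 at -s0
a0, a1 = a[:, :-s1], a[:, -s1:]     # slice_a: split dim 1 at -s1
b0, b1 = b[:, :-s1], b[:, -s1:]     # slice_b: split dim 1 at -s1
a10 = torch.cat((a1, a0), dim=1)    # concat_a: swap the halves along dim 1
b10 = torch.cat((b1, b0), dim=1)    # concat_b: swap the halves along dim 1
out = torch.cat((b10, a10), dim=0)  # concat: swap the halves along dim 0

assert torch.equal(out, torch.roll(z, shifts=(s0, s1), dims=(0, 1)))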
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
@@ -234,6 +234,7 @@ pnnx_add_test(torch_ones_like)
pnnx_add_test(torch_positive)
pnnx_add_test(torch_prod)
pnnx_add_test(torch_repeat_interleave)
pnnx_add_test(torch_roll)
pnnx_add_test(torch_scatter_add)
pnnx_add_test(torch_slice_scatter)
pnnx_add_test(torch_sum)
1 change: 1 addition & 0 deletions tools/pnnx/tests/ncnn/CMakeLists.txt
@@ -162,6 +162,7 @@ pnnx_ncnn_add_test(torch_min)
pnnx_ncnn_add_test(torch_mm)
pnnx_ncnn_add_test(torch_norm)
pnnx_ncnn_add_test(torch_prod)
pnnx_ncnn_add_test(torch_roll)
pnnx_ncnn_add_test(torch_slice_scatter)
pnnx_ncnn_add_test(torch_sum)
pnnx_ncnn_add_test(torch_squeeze)
64 changes: 64 additions & 0 deletions tools/pnnx/tests/ncnn/test_torch_roll.py
@@ -0,0 +1,64 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        x = torch.roll(x, 3, 1)
        y = torch.roll(y, -2, -1)
        z = torch.roll(z, shifts=(2,1), dims=(0,1))
        return x, y, z

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 16)
    y = torch.rand(5, 9, 11)
    z = torch.rand(8, 5, 9, 10)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_roll.pt")

    # torchscript to ncnn
    import os
    os.system("../../src/pnnx test_torch_roll.pt inputshape=[3,16],[5,9,11],[8,5,9,10]")

    # ncnn inference
    import test_torch_roll_ncnn
    b = test_torch_roll_ncnn.test_inference()

    print(x)
    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            print(a0)
            print(b0)
            return False
    return True

if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)
1 change: 1 addition & 0 deletions tools/pnnx/tests/onnx/CMakeLists.txt
@@ -145,6 +145,7 @@ pnnx_onnx_add_test(torch_mean)
pnnx_onnx_add_test(torch_min)
pnnx_onnx_add_test(torch_minimum)
pnnx_onnx_add_test(torch_prod)
pnnx_onnx_add_test(torch_roll)
pnnx_onnx_add_test(torch_split)
pnnx_onnx_add_test(torch_squeeze)
pnnx_onnx_add_test(torch_stack)
64 changes: 64 additions & 0 deletions tools/pnnx/tests/onnx/test_torch_roll.py
@@ -0,0 +1,64 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        x = torch.roll(x, 3, -1)
        y = torch.roll(y, -2, -1)
        z = torch.roll(z, shifts=(2,1), dims=(0,1))
        return x, y, z

def test():
    if version.parse(torch.__version__) < version.parse('1.10'):
        return True

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 3, 16)
    y = torch.rand(1, 5, 9, 11)
    z = torch.rand(14, 8, 5, 9, 10)

    a = net(x, y, z)

    # export onnx
    torch.onnx.export(net, (x, y, z), "test_torch_roll.onnx")

    # onnx to pnnx
    import os
    os.system("../../src/pnnx test_torch_roll.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")

    # pnnx inference
    import test_torch_roll_pnnx
    b = test_torch_roll_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True

if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)