From ecfd88a11bdf6480c0496564c6392463997429fc Mon Sep 17 00:00:00 2001
From: nihui
Date: Mon, 12 Aug 2024 19:33:39 +0800
Subject: [PATCH] pnnx2ncnn convert torch.roll with one or two shifts (#5623)

---
 tools/pnnx/src/CMakeLists.txt            |   1 +
 tools/pnnx/src/pass_ncnn/torch_roll.cpp  | 193 +++++++++++++++++++++++
 tools/pnnx/tests/CMakeLists.txt          |   1 +
 tools/pnnx/tests/ncnn/CMakeLists.txt     |   1 +
 tools/pnnx/tests/ncnn/test_torch_roll.py |  64 ++++++++
 tools/pnnx/tests/onnx/CMakeLists.txt     |   1 +
 tools/pnnx/tests/onnx/test_torch_roll.py |  64 ++++++++
 tools/pnnx/tests/test_torch_roll.py      |  61 +++++++
 8 files changed, 386 insertions(+)
 create mode 100644 tools/pnnx/src/pass_ncnn/torch_roll.cpp
 create mode 100644 tools/pnnx/tests/ncnn/test_torch_roll.py
 create mode 100644 tools/pnnx/tests/onnx/test_torch_roll.py
 create mode 100644 tools/pnnx/tests/test_torch_roll.py

diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 27dfdef52f8..c5c6228dee7 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -572,6 +572,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/torch_mm.cpp
     pass_ncnn/torch_norm.cpp
     pass_ncnn/torch_prod.cpp
+    pass_ncnn/torch_roll.cpp
     pass_ncnn/torch_slice_scatter.cpp
     pass_ncnn/torch_squeeze.cpp
     pass_ncnn/torch_sum.cpp
diff --git a/tools/pnnx/src/pass_ncnn/torch_roll.cpp b/tools/pnnx/src/pass_ncnn/torch_roll.cpp
new file mode 100644
index 00000000000..c7c29593333
--- /dev/null
+++ b/tools/pnnx/src/pass_ncnn/torch_roll.cpp
@@ -0,0 +1,193 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_ncnn.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+class torch_roll : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+torch.roll op_0 1 1 input out dims=%dims shifts=%shifts
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+Slice slice 1 2 input a b
+Concat concat 2 1 b a out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.at("dims").type != 5)
+            return false;
+
+        if (captured_params.at("dims").ai.size() != 1)
+            return false;
+
+        if (captured_params.at("shifts").type != 5)
+            return false;
+
+        if (captured_params.at("shifts").ai.size() != 1)
+            return false;
+
+        return true;
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        GraphRewriterPass::write(ops, captured_params, captured_attrs);
+
+        const Operand* in = ops.at("slice")->inputs[0];
+
+        const int batch_index = in->params.at("__batch_index").i;
+
+        int axis = captured_params.at("dims").ai[0];
+        if (axis == batch_index)
+        {
+            fprintf(stderr, "roll along batch axis %d is not supported\n", batch_index);
+        }
+
+        if (axis < 0)
+        {
+            int input_rank = in->shape.size();
+            axis = input_rank + axis;
+        }
+
+        if (axis > batch_index)
+            axis -= 1;
+
+        ops.at("slice")->params["1"] = axis;
+
+        ops.at("concat")->params["0"] = axis;
+
+        const int shift = captured_params.at("shifts").ai[0];
+        ops.at("slice")->params["2"] = std::vector<int>{-shift};
+    }
+};
+
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_roll, 20)
+
+class torch_roll_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+torch.roll op_0 1 1 input out dims=%dims shifts=%shifts
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input 0 1 input
+Slice slice 1 2 input a b
+Slice slice_a 1 2 a a0 a1
+Slice slice_b 1 2 b b0 b1
+Concat concat_a 2 1 a1 a0 a10
+Concat concat_b 2 1 b1 b0 b10
+Concat concat 2 1 b10 a10 out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.at("dims").type != 5)
+            return false;
+
+        if (captured_params.at("dims").ai.size() != 2)
+            return false;
+
+        if (captured_params.at("shifts").type != 5)
+            return false;
+
+        if (captured_params.at("shifts").ai.size() != 2)
+            return false;
+
+        return true;
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        GraphRewriterPass::write(ops, captured_params, captured_attrs);
+
+        const Operand* in = ops.at("slice")->inputs[0];
+
+        const int batch_index = in->params.at("__batch_index").i;
+
+        int axis0 = captured_params.at("dims").ai[0];
+        int axis1 = captured_params.at("dims").ai[1];
+        if (axis0 == batch_index || axis1 == batch_index)
+        {
+            fprintf(stderr, "roll along batch axis %d is not supported\n", batch_index);
+        }
+
+        if (axis0 < 0)
+        {
+            int input_rank = in->shape.size();
+            axis0 = input_rank + axis0;
+        }
+
+        if (axis0 > batch_index)
+            axis0 -= 1;
+
+        if (axis1 < 0)
+        {
+            int input_rank = in->shape.size();
+            axis1 = input_rank + axis1;
+        }
+        if (axis1 > batch_index)
+            axis1 -= 1;
+
+        ops.at("slice")->params["1"] = axis0;
+        ops.at("slice_a")->params["1"] = axis1;
ops.at("slice_b")->params["1"] = axis1; + + ops.at("concat_a")->params["0"] = axis1; + ops.at("concat_b")->params["0"] = axis1; + ops.at("concat")->params["0"] = axis0; + + const int shift0 = captured_params.at("shifts").ai[0]; + const int shift1 = captured_params.at("shifts").ai[1]; + ops.at("slice")->params["2"] = std::vector{-shift0}; + ops.at("slice_a")->params["2"] = std::vector{-shift1}; + ops.at("slice_b")->params["2"] = std::vector{-shift1}; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_roll_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 7bbf1c6ea9c..a5522a70bb2 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -234,6 +234,7 @@ pnnx_add_test(torch_ones_like) pnnx_add_test(torch_positive) pnnx_add_test(torch_prod) pnnx_add_test(torch_repeat_interleave) +pnnx_add_test(torch_roll) pnnx_add_test(torch_scatter_add) pnnx_add_test(torch_slice_scatter) pnnx_add_test(torch_sum) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index a682e42835b..a60e63eb54b 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -162,6 +162,7 @@ pnnx_ncnn_add_test(torch_min) pnnx_ncnn_add_test(torch_mm) pnnx_ncnn_add_test(torch_norm) pnnx_ncnn_add_test(torch_prod) +pnnx_ncnn_add_test(torch_roll) pnnx_ncnn_add_test(torch_slice_scatter) pnnx_ncnn_add_test(torch_sum) pnnx_ncnn_add_test(torch_squeeze) diff --git a/tools/pnnx/tests/ncnn/test_torch_roll.py b/tools/pnnx/tests/ncnn/test_torch_roll.py new file mode 100644 index 00000000000..6412ee6ba60 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_roll.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z):
+        x = torch.roll(x, 3, 1)
+        y = torch.roll(y, -2, -1)
+        z = torch.roll(z, shifts=(2,1), dims=(0,1))
+        return x, y, z
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(3, 16)
+    y = torch.rand(5, 9, 11)
+    z = torch.rand(8, 5, 9, 10)
+
+    a = net(x, y, z)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z))
+    mod.save("test_torch_roll.pt")
+
+    # torchscript to ncnn
+    import os
+    os.system("../../src/pnnx test_torch_roll.pt inputshape=[3,16],[5,9,11],[8,5,9,10]")
+
+    # ncnn inference
+    import test_torch_roll_ncnn
+    b = test_torch_roll_ncnn.test_inference()
+
+    print(x)
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            print(a0)
+            print(b0)
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt
index f4756740a79..673fa0434d9 100644
--- a/tools/pnnx/tests/onnx/CMakeLists.txt
+++ b/tools/pnnx/tests/onnx/CMakeLists.txt
@@ -145,6 +145,7 @@ pnnx_onnx_add_test(torch_mean)
 pnnx_onnx_add_test(torch_min)
 pnnx_onnx_add_test(torch_minimum)
 pnnx_onnx_add_test(torch_prod)
+pnnx_onnx_add_test(torch_roll)
 pnnx_onnx_add_test(torch_split)
 pnnx_onnx_add_test(torch_squeeze)
 pnnx_onnx_add_test(torch_stack)
diff --git a/tools/pnnx/tests/onnx/test_torch_roll.py b/tools/pnnx/tests/onnx/test_torch_roll.py
new file mode 100644
index 00000000000..06b8d579649
--- /dev/null
+++ b/tools/pnnx/tests/onnx/test_torch_roll.py
@@ -0,0 +1,64 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z):
+        x = torch.roll(x, 3, -1)
+        y = torch.roll(y, -2, -1)
+        z = torch.roll(z, shifts=(2,1), dims=(0,1))
+        return x, y, z
+
+def test():
+    if version.parse(torch.__version__) < version.parse('1.10'):
+        return True
+
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(1, 3, 16)
+    y = torch.rand(1, 5, 9, 11)
+    z = torch.rand(14, 8, 5, 9, 10)
+
+    a = net(x, y, z)
+
+    # export onnx
+    torch.onnx.export(net, (x, y, z), "test_torch_roll.onnx")
+
+    # onnx to pnnx
+    import os
+    os.system("../../src/pnnx test_torch_roll.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")
+
+    # pnnx inference
+    import test_torch_roll_pnnx
+    b = test_torch_roll_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
diff --git a/tools/pnnx/tests/test_torch_roll.py b/tools/pnnx/tests/test_torch_roll.py
new file mode 100644
index 00000000000..32e3bde38e1
--- /dev/null
+++ b/tools/pnnx/tests/test_torch_roll.py
@@ -0,0 +1,61 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z):
+        x = torch.roll(x, 3)
+        y = torch.roll(y, -2, -1)
+        z = torch.roll(z, shifts=(2,1), dims=(0,1))
+        return x, y, z
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(1, 3, 16)
+    y = torch.rand(1, 5, 9, 11)
+    z = torch.rand(14, 8, 5, 9, 10)
+
+    a = net(x, y, z)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z))
+    mod.save("test_torch_roll.pt")
+
+    # torchscript to pnnx
+    import os
+    os.system("../src/pnnx test_torch_roll.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")
+
+    # pnnx inference
+    import test_torch_roll_pnnx
+    b = test_torch_roll_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
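
Editor's note (not part of the patch): the rewrite above relies on the identity that rolling a tensor by shift along one dimension is the same as splitting it at index -shift and concatenating the two pieces in swapped order; as I read the pass, that is what the emitted ncnn Slice + Concat pair reproduces, and the two-shift variant simply nests this per axis. Below is a minimal PyTorch sketch of that identity for readers checking the conversion by hand; the helper name roll_as_slice_concat is hypothetical and only used for illustration.

import torch

def roll_as_slice_concat(x, shift, dim):
    # Illustrative only: torch.roll(x, shift, dim) equals splitting x at
    # index n - (shift mod n) along dim (i.e. at -shift) and concatenating
    # the two pieces in swapped order.
    n = x.size(dim)
    s = shift % n  # normalize negative or oversized shifts into [0, n)
    a, b = torch.split(x, [n - s, s], dim=dim)  # a = first n-s slices, b = last s slices
    return torch.cat((b, a), dim=dim)

# Spot-check against the shifts exercised by the tests above.
x = torch.rand(5, 9, 11)
assert torch.equal(roll_as_slice_concat(x, -2, -1), torch.roll(x, -2, -1))
z = torch.rand(8, 5, 9, 10)
assert torch.equal(roll_as_slice_concat(roll_as_slice_concat(z, 2, 0), 1, 1),
                   torch.roll(z, shifts=(2, 1), dims=(0, 1)))

The -shift slice point written by the pass matches this sign convention: counting the split position from the end of the axis handles positive and negative shifts uniformly.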