From 7c228e552b70ee3271b6f675698502a51271bc1c Mon Sep 17 00:00:00 2001 From: chengxiangming <781326019@qq.com> Date: Thu, 31 Oct 2024 11:21:49 +0000 Subject: [PATCH 1/3] Add pixel_unshuffle opencl support test=develop --- .../cl_kernel/image/pixel_unshuffle_kernel.cl | 143 +++++++++++++ lite/kernels/opencl/CMakeLists.txt | 1 + .../opencl/pixel_unshuffle_image_compute.cc | 193 ++++++++++++++++++ .../pixel_unshuffle_image_compute_test.cc | 153 ++++++++++++++ 4 files changed, 490 insertions(+) create mode 100644 lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl create mode 100644 lite/kernels/opencl/pixel_unshuffle_image_compute.cc create mode 100644 lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc diff --git a/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl b/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl new file mode 100644 index 00000000000..51b9cafc947 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl @@ -0,0 +1,143 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +__kernel void pixel_unshuffle(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int in_N, + __private const int in_C, + __private const int in_H, + __private const int in_W, + __private const int out_N, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int downscale_factor) { + + const int in_c4 = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + + int in_h = in_nh % in_H; + int in_n = in_nh / in_H; + + CL_DTYPE4 res = (CL_DTYPE4)(0, 0, 0, 0); + CL_DTYPE4 in; + + int in_c0 = in_c4 * 4 + 0; + int out_c0 = in_c0 / (downscale_factor * downscale_factor); + int offset0 = in_c0 % (downscale_factor * downscale_factor); + int offset_h0 = offset0 / downscale_factor; + int offset_w0 = offset0 % downscale_factor; + + int out_w0 = in_w * downscale_factor + offset_w0; + int out_h0 = in_h * downscale_factor + offset_h0; + int out_nh0 = in_n * out_H + out_h0; + + int2 out_pos0; + out_pos0.x = out_w0 + (out_c0 / 4) * in_W; + out_pos0.y = out_nh0; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos0); + if (out_c0 % 4 == 0) { + res.x = in.x; + } else if (out_c0 % 4 == 1) { + res.x = in.y; + } else if (out_c0 % 4 == 2) { + res.x = in.z; + } else if (out_c0 % 4 == 3) { + res.x = in.w; + } + + int in_c1 = in_c4 * 4 + 1; + int out_c1 = in_c1 / (downscale_factor * downscale_factor); + int offset1 = in_c1 % (downscale_factor * downscale_factor); + int offset_h1 = offset1 / downscale_factor; + int offset_w1 = offset1 % downscale_factor; + + int out_w1 = in_w * downscale_factor + offset_w1; + int out_h1 = in_h * downscale_factor + offset_h1; + int out_nh1 = in_n * out_H + out_h1; + + int2 out_pos1; + out_pos1.x = out_w1 + (out_c1 / 4) * in_W; + out_pos1.y = out_nh1; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos1); + if (out_c1 % 4 == 0) { + res.y = in.x; + } else if (out_c1 % 4 == 1) { + res.y = in.y; + } else if (out_c1 % 4 == 2) { + res.y = in.z; + } else if (out_c1 % 4 == 3) { + res.y = in.w; + } + + int in_c2 = in_c4 * 4 + 2; + int out_c2 = in_c2 / (downscale_factor * downscale_factor); + int offset2 = in_c2 % (downscale_factor * downscale_factor); + int offset_h2 = offset2 / downscale_factor; + int offset_w2 = offset2 % downscale_factor; + + int out_w2 = in_w * downscale_factor + offset_w2; + int out_h2 = in_h * downscale_factor + offset_h2; + int out_nh2 = in_n * out_H + out_h2; + + int2 out_pos2; + out_pos2.x = out_w2 + (out_c2 / 4) * in_W; + out_pos2.y = out_nh2; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos2); + if (out_c2 % 4 == 0) { + res.z = in.x; + } else if (out_c2 % 4 == 1) { + res.z = in.y; + } else if (out_c2 % 4 == 2) { + res.z = in.z; + } else if (out_c2 % 4 == 3) { + res.z = in.w; + } + + int in_c3 = in_c4 * 4 + 3; + int out_c3 = in_c3 / (downscale_factor * downscale_factor); + int offset3 = in_c3 % (downscale_factor * downscale_factor); + int offset_h3 = offset3 / downscale_factor; + int offset_w3 = offset3 % downscale_factor; + + int out_w3 = in_w * downscale_factor + offset_w3; + int out_h3 = in_h * downscale_factor + offset_h3; + int out_nh3 = in_n * out_H + out_h3; + + int2 out_pos3; + out_pos3.x = out_w3 + (out_c3 / 4) * in_W; + out_pos3.y = out_nh3; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos3); + if (out_c3 % 4 == 0) { + res.w = in.x; + } else if (out_c3 % 4 == 1) { + res.w = in.y; + } else if (out_c3 % 4 == 2) { + res.w = in.z; + } else if (out_c3 % 4 == 3) { + res.w = in.w; + } + + int2 in_pos; + in_pos.x = in_c4 * (in_W / downscale_factor) + in_w; + in_pos.y = in_nh; + if (in_pos.x < out_W * ((out_C + 3) / 4) && in_pos.y < out_H * out_N) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, in_pos, res); + } +} \ No newline at end of file diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 857e08ca6d7..c752b5ad060 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -42,6 +42,7 @@ add_kernel(dropout_opencl_image OPENCL basic SRCS dropout_image_compute.cc) add_kernel(pad2d_opencl_image OPENCL basic SRCS pad2d_image_compute.cc) add_kernel(box_coder_opencl_image OPENCL basic SRCS box_coder_image_compute.cc) add_kernel(pixel_shuffle_opencl_image OPENCL basic SRCS pixel_shuffle_image_compute.cc) +add_kernel(pixel_unshuffle_opencl_image OPENCL basic SRCS pixel_unshuffle_image_compute.cc) add_kernel(expand_opencl_image OPENCL basic SRCS expand_image_compute.cc) add_kernel(shuffle_channel_opencl_image OPENCL basic SRCS shuffle_channel_image_compute.cc) add_kernel(trigonometric_opencl_image OPENCL basic SRCS trigonometric_image_compute.cc) diff --git a/lite/kernels/opencl/pixel_unshuffle_image_compute.cc b/lite/kernels/opencl/pixel_unshuffle_image_compute.cc new file mode 100644 index 00000000000..2c8fd66714b --- /dev/null +++ b/lite/kernels/opencl/pixel_unshuffle_image_compute.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/profiler.h" +#endif +#include "lite/backends/opencl/cl_utility.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class PixelUnShuffleComputeImage2D + : public KernelLite { + public: + using param_t = operators::PixelUnShuffleParam; + + std::string doc() const override { + return "PixelUnShuffle using cl::Image2D, kFP16"; + } + + void PrepareForRun() override { + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + + auto& context = ctx_->As(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/pixel_unshuffle_kernel.cl", + build_options_, + time_stamp_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } + + void ReInitWhenNeeded() override { + VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_; + pixel_unshuffle_param_ = param_.get_mutable(); + auto x_dims = pixel_unshuffle_param_->x->dims(); + auto out_dims = pixel_unshuffle_param_->output->dims(); + VLOG(1) << "x_dims: " << x_dims; + VLOG(1) << "out_dims: " << out_dims; + VLOG(1) << "downscale_factor: " << pixel_unshuffle_param_->downscale_factor; + + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + out_img_shape_ = default_convertor.InitImageDimInfoWith( + pixel_unshuffle_param_->output->dims()); + VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " " + << out_img_shape_[1]; + + // compute global work size + auto image_width = out_dims[3] * ((out_dims[1] + 3) / 4); + size_t work_size_0 = image_width / out_dims[3]; + size_t work_size_1 = out_dims[3]; + size_t work_size_2 = out_dims[0] * out_dims[2]; + global_work_size_ = cl::NDRange{work_size_0, work_size_1, work_size_2}; + VLOG(1) << "global_work_size_: " << global_work_size_[0] << " " + << global_work_size_[1] << " " << global_work_size_[2]; + } + } + + void Run() override { + auto* x_img = GET_DATA_GPU(pixel_unshuffle_param_->x); + auto* out_img = MUTABLE_DATA_GPU(pixel_unshuffle_param_->output, + out_img_shape_[0], + out_img_shape_[1], + nullptr); + + auto x_dims = pixel_unshuffle_param_->x->dims(); + + int in_n = x_dims[0]; + int in_c = x_dims[1]; + int in_h = x_dims[2]; + int in_w = x_dims[3]; + + auto out_dims = pixel_unshuffle_param_->output->dims(); + + int out_n = out_dims[0]; + int out_c = out_dims[1]; + int out_h = out_dims[2]; + int out_w = out_dims[3]; + + const int downscale_factor = pixel_unshuffle_param_->downscale_factor; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto kernel = kernel_; + cl_int status; + status = kernel.setArg(0, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(1, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(2, in_n); + CL_CHECK_FATAL(status); + status = kernel.setArg(3, in_c); + CL_CHECK_FATAL(status); + status = kernel.setArg(4, in_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(5, in_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(6, out_n); + CL_CHECK_FATAL(status); + status = kernel.setArg(7, out_c); + CL_CHECK_FATAL(status); + status = kernel.setArg(8, out_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(9, out_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(10, downscale_factor); + CL_CHECK_FATAL(status); + + + status = EnqueueNDRangeKernel(context, + kernel, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status); + } + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + ch->cl_event = + event_; // `event_` defined in `kernel.h`, valid after kernel::Run + } +#endif + private: + std::string kernel_func_name_{"pixel_unshuffle"}; + std::string build_options_{""}; + std::string time_stamp_{GetTimeStamp()}; + + param_t* pixel_unshuffle_param_{nullptr}; + cl::Kernel kernel_; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pixel_unshuffle, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::PixelUnShuffleComputeImage2D, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); \ No newline at end of file diff --git a/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc b/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc new file mode 100644 index 00000000000..22d30f850c5 --- /dev/null +++ b/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (5e-1) + +namespace paddle { +namespace lite { + +TEST(pixel_unshuffle_image2d, compute) { + LOG(INFO) << "create kernel ..."; + auto kernels = KernelRegistry::Global().Create("pixel_unshuffle", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + const int INPUT_N = 1; + const int INPUT_C = 1; + const int INPUT_H = 4; + const int INPUT_W = 4; + const int DOWNSCALE_FACTOR = 2; // 对应下采样因子 + + auto kernel = std::move(kernels.front()); + + LOG(INFO) << "prepare to test kernel ====> " << kernel->doc(); + + lite::Tensor x, out; + operators::PixelUnShuffleParam param; + param.x = &x; + param.output = &out; + param.downscale_factor = DOWNSCALE_FACTOR; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr pixel_unshuffle_context(new KernelContext); + context->As().CopySharedTo( + &(pixel_unshuffle_context->As())); + + kernel->SetContext(std::move(pixel_unshuffle_context)); + + const DDim in_dim = + DDim(std::vector{INPUT_N, INPUT_C, INPUT_H, INPUT_W}); + const DDim out_dim = DDim( + std::vector{INPUT_N, + INPUT_C * DOWNSCALE_FACTOR * DOWNSCALE_FACTOR, + INPUT_H / DOWNSCALE_FACTOR, + INPUT_W / DOWNSCALE_FACTOR}); + LOG(INFO) << "in_dim: " << in_dim; + LOG(INFO) << "DOWNSCALE_FACTOR: " << DOWNSCALE_FACTOR; + LOG(INFO) << "out_dim: " << out_dim; + + x.Resize(in_dim); + out.Resize(out_dim); + + std::default_random_engine engine; + std::uniform_real_distribution dist(-2, 2); + std::vector input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W); + + int index = 0; + for (auto& i : input_v) { + i = index++; + } + VLOG(1) << "input_v ..... "; + for (size_t i = 0; i < input_v.size(); i++) { + VLOG(10) << input_v[i]; + } + + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " " + << x_image_shape[1]; + std::vector x_image_data(x_image_shape.production() * 4); // 4 : RGBA + default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + x_image_shape[0], x_image_shape[1], x_image_data.data()); + VLOG(1) << "x_image_data ..... "; + for (size_t i = 0; i < x_image_data.size(); i++) { + VLOG(10) << Half2Float(x_image_data[i]); + } + DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dim); + LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + auto* out_image = out.mutable_data(out_image_shape[0], + out_image_shape[1]); + kernel->Launch(); + CLRuntime::Global()->command_queue().finish(); + std::vector out_data_v{ + 0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + half_t* out_image_data = new half_t[out_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + VLOG(1) << "out_image_data ..... "; + for (size_t i = 0; i < out_image_shape.production() * 4; i++) { + VLOG(10) << Half2Float(out_image_data[i]); + } + float* out_data = new float[out_image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, out_image_shape, out_dim); + + VLOG(1) << "out_data ..... "; + for (int i = 0; i < out_dim.production(); i++) { + VLOG(10) << out_data[i]; + } + + for (int i = 0; i < out_dim.production(); i++) { + auto abs_diff = abs(out_data[i] - out_data_v[i]); + auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]); + EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << " out_data[" << i + << "]:" << out_data[i] << " " + "out_ref[" + << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff + << " relative_diff:" << relative_diff + << " FP16_MAX_DIFF:" << FP16_MAX_DIFF; + } + } +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(pixel_unshuffle, kOpenCL, kFP16, kImageDefault, image2d); \ No newline at end of file From 29c4a7eb7d5835cd90c3f4e17522f6de9a486b59 Mon Sep 17 00:00:00 2001 From: chengxiangming <781326019@qq.com> Date: Thu, 31 Oct 2024 11:21:49 +0000 Subject: [PATCH 2/3] Add pixel_unshuffle opencl op support test=develop --- .../cl_kernel/image/pixel_unshuffle_kernel.cl | 148 ++++++++++++++ lite/kernels/opencl/CMakeLists.txt | 1 + .../opencl/pixel_unshuffle_image_compute.cc | 193 ++++++++++++++++++ .../pixel_unshuffle_image_compute_test.cc | 153 ++++++++++++++ 4 files changed, 495 insertions(+) create mode 100644 lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl create mode 100644 lite/kernels/opencl/pixel_unshuffle_image_compute.cc create mode 100644 lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc diff --git a/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl b/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl new file mode 100644 index 00000000000..948982944b0 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/pixel_unshuffle_kernel.cl @@ -0,0 +1,148 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +__kernel void pixel_unshuffle(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int in_N, + __private const int in_C, + __private const int in_H, + __private const int in_W, + __private const int out_N, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int downscale_factor) { + + const int in_c4 = get_global_id(0); + const int in_w = get_global_id(1); + const int in_nh = get_global_id(2); + + int in_h = in_nh % in_H; + int in_n = in_nh / in_H; + + CL_DTYPE4 res = (CL_DTYPE4)(0, 0, 0, 0); + CL_DTYPE4 in; + int in_c; + int out_c; + int offset; + int offset_h; + int offset_w; + int out_w; + int out_h; + int out_nh; + int2 out_pos; + int2 in_pos; + + in_c = in_c4 * 4 + 0; + out_c = in_c / (downscale_factor * downscale_factor); + offset = in_c % (downscale_factor * downscale_factor); + offset_h = offset / downscale_factor; + offset_w = offset % downscale_factor; + + out_w = in_w * downscale_factor + offset_w; + out_h = in_h * downscale_factor + offset_h; + out_nh = in_n * out_H + out_h; + + out_pos.x = out_w + (out_c / 4) * in_W; + out_pos.y = out_nh; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (out_c % 4 == 0) { + res.x = in.x; + } else if (out_c % 4 == 1) { + res.x = in.y; + } else if (out_c % 4 == 2) { + res.x = in.z; + } else if (out_c % 4 == 3) { + res.x = in.w; + } + + in_c = in_c4 * 4 + 1; + out_c = in_c / (downscale_factor * downscale_factor); + offset = in_c % (downscale_factor * downscale_factor); + offset_h = offset / downscale_factor; + offset_w = offset % downscale_factor; + + out_w = in_w * downscale_factor + offset_w; + out_h = in_h * downscale_factor + offset_h; + out_nh = in_n * out_H + out_h; + + out_pos.x = out_w + (out_c / 4) * in_W; + out_pos.y = out_nh; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (out_c % 4 == 0) { + res.y = in.x; + } else if (out_c % 4 == 1) { + res.y = in.y; + } else if (out_c % 4 == 2) { + res.y = in.z; + } else if (out_c % 4 == 3) { + res.y = in.w; + } + + in_c = in_c4 * 4 + 2; + out_c = in_c / (downscale_factor * downscale_factor); + offset = in_c % (downscale_factor * downscale_factor); + offset_h = offset / downscale_factor; + offset_w = offset % downscale_factor; + + out_w = in_w * downscale_factor + offset_w; + out_h = in_h * downscale_factor + offset_h; + out_nh = in_n * out_H + out_h; + + out_pos.x = out_w + (out_c / 4) * in_W; + out_pos.y = out_nh; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (out_c % 4 == 0) { + res.z = in.x; + } else if (out_c % 4 == 1) { + res.z = in.y; + } else if (out_c % 4 == 2) { + res.z = in.z; + } else if (out_c % 4 == 3) { + res.z = in.w; + } + + in_c = in_c4 * 4 + 3; + out_c = in_c / (downscale_factor * downscale_factor); + offset = in_c % (downscale_factor * downscale_factor); + offset_h = offset / downscale_factor; + offset_w = offset % downscale_factor; + + out_w = in_w * downscale_factor + offset_w; + out_h = in_h * downscale_factor + offset_h; + out_nh = in_n * out_H + out_h; + + out_pos.x = out_w + (out_c / 4) * in_W; + out_pos.y = out_nh; + + in = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, SAMPLER, out_pos); + if (out_c % 4 == 0) { + res.w = in.x; + } else if (out_c % 4 == 1) { + res.w = in.y; + } else if (out_c % 4 == 2) { + res.w = in.z; + } else if (out_c % 4 == 3) { + res.w = in.w; + } + + in_pos.x = in_c4 * (in_W / downscale_factor) + in_w; + in_pos.y = in_nh; + if (in_pos.x < out_W * ((out_C + 3) / 4) && in_pos.y < out_H * out_N) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, in_pos, res); + } +} \ No newline at end of file diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 857e08ca6d7..c752b5ad060 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -42,6 +42,7 @@ add_kernel(dropout_opencl_image OPENCL basic SRCS dropout_image_compute.cc) add_kernel(pad2d_opencl_image OPENCL basic SRCS pad2d_image_compute.cc) add_kernel(box_coder_opencl_image OPENCL basic SRCS box_coder_image_compute.cc) add_kernel(pixel_shuffle_opencl_image OPENCL basic SRCS pixel_shuffle_image_compute.cc) +add_kernel(pixel_unshuffle_opencl_image OPENCL basic SRCS pixel_unshuffle_image_compute.cc) add_kernel(expand_opencl_image OPENCL basic SRCS expand_image_compute.cc) add_kernel(shuffle_channel_opencl_image OPENCL basic SRCS shuffle_channel_image_compute.cc) add_kernel(trigonometric_opencl_image OPENCL basic SRCS trigonometric_image_compute.cc) diff --git a/lite/kernels/opencl/pixel_unshuffle_image_compute.cc b/lite/kernels/opencl/pixel_unshuffle_image_compute.cc new file mode 100644 index 00000000000..2c8fd66714b --- /dev/null +++ b/lite/kernels/opencl/pixel_unshuffle_image_compute.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/profiler.h" +#endif +#include "lite/backends/opencl/cl_utility.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class PixelUnShuffleComputeImage2D + : public KernelLite { + public: + using param_t = operators::PixelUnShuffleParam; + + std::string doc() const override { + return "PixelUnShuffle using cl::Image2D, kFP16"; + } + + void PrepareForRun() override { + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + + auto& context = ctx_->As(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/pixel_unshuffle_kernel.cl", + build_options_, + time_stamp_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } + + void ReInitWhenNeeded() override { + VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_; + pixel_unshuffle_param_ = param_.get_mutable(); + auto x_dims = pixel_unshuffle_param_->x->dims(); + auto out_dims = pixel_unshuffle_param_->output->dims(); + VLOG(1) << "x_dims: " << x_dims; + VLOG(1) << "out_dims: " << out_dims; + VLOG(1) << "downscale_factor: " << pixel_unshuffle_param_->downscale_factor; + + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + // compute image shape + paddle::lite::CLImageConverterDefault default_convertor; + out_img_shape_ = default_convertor.InitImageDimInfoWith( + pixel_unshuffle_param_->output->dims()); + VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " " + << out_img_shape_[1]; + + // compute global work size + auto image_width = out_dims[3] * ((out_dims[1] + 3) / 4); + size_t work_size_0 = image_width / out_dims[3]; + size_t work_size_1 = out_dims[3]; + size_t work_size_2 = out_dims[0] * out_dims[2]; + global_work_size_ = cl::NDRange{work_size_0, work_size_1, work_size_2}; + VLOG(1) << "global_work_size_: " << global_work_size_[0] << " " + << global_work_size_[1] << " " << global_work_size_[2]; + } + } + + void Run() override { + auto* x_img = GET_DATA_GPU(pixel_unshuffle_param_->x); + auto* out_img = MUTABLE_DATA_GPU(pixel_unshuffle_param_->output, + out_img_shape_[0], + out_img_shape_[1], + nullptr); + + auto x_dims = pixel_unshuffle_param_->x->dims(); + + int in_n = x_dims[0]; + int in_c = x_dims[1]; + int in_h = x_dims[2]; + int in_w = x_dims[3]; + + auto out_dims = pixel_unshuffle_param_->output->dims(); + + int out_n = out_dims[0]; + int out_c = out_dims[1]; + int out_h = out_dims[2]; + int out_w = out_dims[3]; + + const int downscale_factor = pixel_unshuffle_param_->downscale_factor; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto kernel = kernel_; + cl_int status; + status = kernel.setArg(0, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(1, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(2, in_n); + CL_CHECK_FATAL(status); + status = kernel.setArg(3, in_c); + CL_CHECK_FATAL(status); + status = kernel.setArg(4, in_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(5, in_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(6, out_n); + CL_CHECK_FATAL(status); + status = kernel.setArg(7, out_c); + CL_CHECK_FATAL(status); + status = kernel.setArg(8, out_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(9, out_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(10, downscale_factor); + CL_CHECK_FATAL(status); + + + status = EnqueueNDRangeKernel(context, + kernel, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status); + } + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + ch->cl_event = + event_; // `event_` defined in `kernel.h`, valid after kernel::Run + } +#endif + private: + std::string kernel_func_name_{"pixel_unshuffle"}; + std::string build_options_{""}; + std::string time_stamp_{GetTimeStamp()}; + + param_t* pixel_unshuffle_param_{nullptr}; + cl::Kernel kernel_; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + DDim out_img_shape_ = DDim(std::vector( + {static_cast(1), static_cast(1)})); + cl::NDRange global_work_size_ = cl::NDRange{ + static_cast(1), static_cast(1), static_cast(1)}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pixel_unshuffle, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::PixelUnShuffleComputeImage2D, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); \ No newline at end of file diff --git a/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc b/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc new file mode 100644 index 00000000000..51821d38244 --- /dev/null +++ b/lite/kernels/opencl/pixel_unshuffle_image_compute_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (5e-1) + +namespace paddle { +namespace lite { + +TEST(pixel_unshuffle_image2d, compute) { + LOG(INFO) << "create kernel ..."; + auto kernels = KernelRegistry::Global().Create("pixel_unshuffle", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + const int INPUT_N = 1; + const int INPUT_C = 1; + const int INPUT_H = 4; + const int INPUT_W = 4; + const int DOWNSCALE_FACTOR = 2; + + auto kernel = std::move(kernels.front()); + + LOG(INFO) << "prepare to test kernel ====> " << kernel->doc(); + + lite::Tensor x, out; + operators::PixelUnShuffleParam param; + param.x = &x; + param.output = &out; + param.downscale_factor = DOWNSCALE_FACTOR; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr pixel_unshuffle_context(new KernelContext); + context->As().CopySharedTo( + &(pixel_unshuffle_context->As())); + + kernel->SetContext(std::move(pixel_unshuffle_context)); + + const DDim in_dim = + DDim(std::vector{INPUT_N, INPUT_C, INPUT_H, INPUT_W}); + const DDim out_dim = DDim( + std::vector{INPUT_N, + INPUT_C * DOWNSCALE_FACTOR * DOWNSCALE_FACTOR, + INPUT_H / DOWNSCALE_FACTOR, + INPUT_W / DOWNSCALE_FACTOR}); + LOG(INFO) << "in_dim: " << in_dim; + LOG(INFO) << "DOWNSCALE_FACTOR: " << DOWNSCALE_FACTOR; + LOG(INFO) << "out_dim: " << out_dim; + + x.Resize(in_dim); + out.Resize(out_dim); + + std::default_random_engine engine; + std::uniform_real_distribution dist(-2, 2); + std::vector input_v(INPUT_N * INPUT_C * INPUT_H * INPUT_W); + + int index = 0; + for (auto& i : input_v) { + i = index++; + } + VLOG(1) << "input_v ..... "; + for (size_t i = 0; i < input_v.size(); i++) { + VLOG(10) << input_v[i]; + } + + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " " + << x_image_shape[1]; + std::vector x_image_data(x_image_shape.production() * 4); // 4 : RGBA + default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + x_image_shape[0], x_image_shape[1], x_image_data.data()); + VLOG(1) << "x_image_data ..... "; + for (size_t i = 0; i < x_image_data.size(); i++) { + VLOG(10) << Half2Float(x_image_data[i]); + } + DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dim); + LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + auto* out_image = out.mutable_data(out_image_shape[0], + out_image_shape[1]); + kernel->Launch(); + CLRuntime::Global()->command_queue().finish(); + std::vector out_data_v{ + 0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + half_t* out_image_data = new half_t[out_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + VLOG(1) << "out_image_data ..... "; + for (size_t i = 0; i < out_image_shape.production() * 4; i++) { + VLOG(10) << Half2Float(out_image_data[i]); + } + float* out_data = new float[out_image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, out_image_shape, out_dim); + + VLOG(1) << "out_data ..... "; + for (int i = 0; i < out_dim.production(); i++) { + VLOG(10) << out_data[i]; + } + + for (int i = 0; i < out_dim.production(); i++) { + auto abs_diff = abs(out_data[i] - out_data_v[i]); + auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_data_v[i]); + EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << " out_data[" << i + << "]:" << out_data[i] << " " + "out_ref[" + << i << "]:" << out_data_v[i] << " abs_diff:" << abs_diff + << " relative_diff:" << relative_diff + << " FP16_MAX_DIFF:" << FP16_MAX_DIFF; + } + } +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(pixel_unshuffle, kOpenCL, kFP16, kImageDefault, image2d); \ No newline at end of file From 7c7b2e05ac18733509cf3666df8ce611937a45e7 Mon Sep 17 00:00:00 2001 From: chengxiangming <781326019@qq.com> Date: Thu, 21 Nov 2024 09:33:10 +0000 Subject: [PATCH 3/3] [Metal] Add pixel_unshuffle & pixel_shuffle metal support test=develop --- .../metal/metal_kernel/texture/Common.metal | 4 + .../metal_kernel/texture/PixelShuffle.metal | 32 +----- .../metal_kernel/texture/PixelUnShuffle.metal | 47 ++++++++ lite/kernels/metal/CMakeLists.txt | 3 +- lite/kernels/metal/image_op/metal_params.h | 4 + .../image_op/pixel_shuffle_image_compute.h | 21 ++-- .../image_op/pixel_shuffle_image_compute.mm | 89 ++++++++------- .../image_op/pixel_unshuffle_image_compute.h | 71 ++++++++++++ .../image_op/pixel_unshuffle_image_compute.mm | 108 ++++++++++++++++++ 9 files changed, 299 insertions(+), 80 deletions(-) create mode 100644 lite/backends/metal/metal_kernel/texture/PixelUnShuffle.metal create mode 100644 lite/kernels/metal/image_op/pixel_unshuffle_image_compute.h create mode 100644 lite/kernels/metal/image_op/pixel_unshuffle_image_compute.mm diff --git a/lite/backends/metal/metal_kernel/texture/Common.metal b/lite/backends/metal/metal_kernel/texture/Common.metal index b00b048bbcc..e0c3b17fa0c 100644 --- a/lite/backends/metal/metal_kernel/texture/Common.metal +++ b/lite/backends/metal/metal_kernel/texture/Common.metal @@ -163,6 +163,10 @@ struct PixelShuffleParam { int32_t upscale_factor; }; +struct PixelUnShuffleParam { + int32_t downscale_factor; +}; + struct ExpandParam { ushort fast; ushort c; diff --git a/lite/backends/metal/metal_kernel/texture/PixelShuffle.metal b/lite/backends/metal/metal_kernel/texture/PixelShuffle.metal index e2f05d526b2..66971b3acea 100644 --- a/lite/backends/metal/metal_kernel/texture/PixelShuffle.metal +++ b/lite/backends/metal/metal_kernel/texture/PixelShuffle.metal @@ -17,8 +17,8 @@ #include "Common.metal" using namespace metal; -kernel void pixel_shuffle(texture2d_array inTexture[[texture(0)]], - texture2d_array outTexture[[texture(1)]], +kernel void pixel_shuffle(texture2d_array inTexture[[texture(0)]], + texture2d_array outTexture[[texture(1)]], constant PixelShuffleParam& param[[buffer(0)]], uint3 gid[[thread_position_in_grid]]) { if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || @@ -26,41 +26,17 @@ kernel void pixel_shuffle(texture2d_array inTexture[[text return; constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); - int upscale_factor = param.upscale_factor; - int inX = gid.x / upscale_factor; - int inY = gid.y / upscale_factor; - - float4 res; - for (int i = 0; i < 4; i++) { - int c = gid.z * 4 + i; - int inC = c * upscale_factor * upscale_factor + (gid.y % upscale_factor) * upscale_factor + - gid.x % upscale_factor; - float4 input = inTexture.read(uint2(inX, inY), inC / 4); - res[i] = input[inC % 4]; - } - - outTexture.write(res, gid.xy, gid.z); -} - -kernel void pixel_shuffle_half(texture2d_array inTexture[[texture(0)]], - texture2d_array outTexture[[texture(1)]], - constant PixelShuffleParam& param[[buffer(0)]], - uint3 gid[[thread_position_in_grid]]) { - if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || - gid.z >= outTexture.get_array_size()) - return; - constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); int upscale_factor = param.upscale_factor; int inX = gid.x / upscale_factor; int inY = gid.y / upscale_factor; - half4 res; + ftype4 res; for (int i = 0; i < 4; i++) { int c = gid.z * 4 + i; int inC = c * upscale_factor * upscale_factor + (gid.y % upscale_factor) * upscale_factor + gid.x % upscale_factor; - half4 input = inTexture.read(uint2(inX, inY), inC / 4); + ftype4 input = inTexture.read(uint2(inX, inY), inC / 4); res[i] = input[inC % 4]; } diff --git a/lite/backends/metal/metal_kernel/texture/PixelUnShuffle.metal b/lite/backends/metal/metal_kernel/texture/PixelUnShuffle.metal new file mode 100644 index 00000000000..067a71f916b --- /dev/null +++ b/lite/backends/metal/metal_kernel/texture/PixelUnShuffle.metal @@ -0,0 +1,47 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "Common.metal" +#include +using namespace metal; + +kernel void pixel_unshuffle(texture2d_array inTexture[[texture(0)]], + texture2d_array outTexture[[texture(1)]], + constant PixelUnShuffleParam& param[[buffer(0)]], + uint3 gid[[thread_position_in_grid]]) { + if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() || + gid.z >= outTexture.get_array_size()) + return; + + int downscale_factor = param.downscale_factor; + int outX = gid.x * downscale_factor; + int outY = gid.y * downscale_factor; + + ftype4 res = ftype4(0.0); + + for (int i = 0; i < 4; i++) { + int c = gid.z * 4 + i; + int outC = c / (downscale_factor * downscale_factor); + int offset = c % (downscale_factor * downscale_factor); + int offset_h = offset / downscale_factor; + int offset_w = offset % downscale_factor; + + int readX = outX + offset_w; + int readY = outY + offset_h; + + ftype4 input = inTexture.read(uint2(readX, readY), outC / 4); + res[i] = input[outC % 4]; + } + outTexture.write(res, gid.xy, gid.z); +} diff --git a/lite/kernels/metal/CMakeLists.txt b/lite/kernels/metal/CMakeLists.txt index 2662810ba35..a7f6b1dda62 100644 --- a/lite/kernels/metal/CMakeLists.txt +++ b/lite/kernels/metal/CMakeLists.txt @@ -53,7 +53,8 @@ add_kernel(yolo_box_metal_image METAL basic SRCS image_op/yolo_box_image_compute # add_kernel(grid_sampler_metal_image METAL basic SRCS image_op/grid_sampler_image_compute.mm) # add_kernel(instance_norm_metal_image METAL basic SRCS image_op/instance_norm_image_compute.mm) # add_kernel(lrn_metal_image METAL basic SRCS image_op/lrn_image_compute.mm) -# add_kernel(pixel_shuffle_metal_image METAL basic SRCS image_op/pixel_shuffle_image_compute.mm) +add_kernel(pixel_shuffle_metal_image METAL basic SRCS image_op/pixel_shuffle_image_compute.mm) +add_kernel(pixel_unshuffle_metal_image METAL basic SRCS image_op/pixel_unshuffle_image_compute.mm) # add_kernel(rsqrt_metal_image METAL basic SRCS image_op/rsqrt_image_compute.mm) # add_kernel(tanh_metal_image METAL basic SRCS image_op/tanh_image_compute.mm) # add_kernel(prior_box_metal_image METAL basic SRCS image_op/prior_box_image_compute.mm) diff --git a/lite/kernels/metal/image_op/metal_params.h b/lite/kernels/metal/image_op/metal_params.h index a14dda44dad..dc8c6ccea84 100644 --- a/lite/kernels/metal/image_op/metal_params.h +++ b/lite/kernels/metal/image_op/metal_params.h @@ -163,6 +163,10 @@ struct PixelShuffleMetalParam { int upscale_factor; }; +struct PixelUnShuffleMetalParam { + int downscale_factor; +}; + struct ShuffleChannelMetalParam { uint32_t group; uint32_t channel_per_group; diff --git a/lite/kernels/metal/image_op/pixel_shuffle_image_compute.h b/lite/kernels/metal/image_op/pixel_shuffle_image_compute.h index eef98266a26..20a4cdb06ed 100644 --- a/lite/kernels/metal/image_op/pixel_shuffle_image_compute.h +++ b/lite/kernels/metal/image_op/pixel_shuffle_image_compute.h @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once +#ifndef LITE_KERNELS_METAL_IMAGE_OP_PIXEL_SHUFFLE_IMAGE_COMPUTE_H_ +#define LITE_KERNELS_METAL_IMAGE_OP_PIXEL_SHUFFLE_IMAGE_COMPUTE_H_ #include @@ -32,9 +33,8 @@ namespace lite { namespace kernels { namespace metal { -template class PixelShuffleImageCompute - : public KernelLite { + : public KernelLite { using param_t = operators::PixelShuffleParam; public: @@ -43,14 +43,18 @@ class PixelShuffleImageCompute void SaveOutput() override { MetalDebug::SaveOutput("pixel_shuffle", output_buffer_); }; + virtual ~PixelShuffleImageCompute(); private: + void run_without_mps(); + void setup_without_mps(); + const MetalImage* input_buffer_; - MetalImage* output_buffer_; - std::shared_ptr param_buffer_; - std::shared_ptr kernel_; - std::shared_ptr queue_; - std::shared_ptr encoder_; + MetalImage* output_buffer_{nullptr}; + std::shared_ptr params_buffer_; + + id pipline_; + std::string function_name_; MetalContext* metal_context_; }; @@ -58,3 +62,4 @@ class PixelShuffleImageCompute } // namespace kernels } // namespace lite } // namespace paddle +#endif LITE_KERNELS_METAL_IMAGE_OP_PIXEL_SHUFFLE_IMAGE_COMPUTE_H_ diff --git a/lite/kernels/metal/image_op/pixel_shuffle_image_compute.mm b/lite/kernels/metal/image_op/pixel_shuffle_image_compute.mm index 0e352421ec6..e211535a140 100644 --- a/lite/kernels/metal/image_op/pixel_shuffle_image_compute.mm +++ b/lite/kernels/metal/image_op/pixel_shuffle_image_compute.mm @@ -13,7 +13,9 @@ // limitations under the License. #include "lite/kernels/metal/image_op/pixel_shuffle_image_compute.h" +#include "lite/backends/metal/metal_context_imp.h" #include "lite/core/op_registry.h" +#include "lite/core/tensor.h" #include "lite/kernels/metal/image_op/metal_params.h" namespace paddle { @@ -21,51 +23,54 @@ namespace kernels { namespace metal { -template -void PixelShuffleImageCompute::PrepareForRun() { - auto& context = this->ctx_->template As(); +void PixelShuffleImageCompute::PrepareForRun() { + auto& context = ctx_->As(); metal_context_ = (MetalContext*)context.context(); - auto device = metal_context_->GetDefaultDevice(); - const auto& param = this->template Param(); + const auto& param = this->Param(); auto output_dims = param.output->dims(); - input_buffer_ = param.x->template data(); - output_buffer_ = param.output->template mutable_data(output_dims); - - PixelShuffleMetalParam metal_param{param.upscale_factor}; - param_buffer_ = metal_context_->CreateBuffer( - *device, &metal_param, sizeof(metal_param), METAL_ACCESS_FLAG::CPUWriteOnly); +#ifdef LITE_WITH_METAL_FULL +#else + output_buffer_ = param.output->mutable_data(metal_context_, output_dims); + input_buffer_ = param.x->data(); +#endif + setup_without_mps(); +} - std::string function_name = ""; - if (std::is_same::value) { - function_name = "pixel_shuffle"; - } else if (std::is_same::value) { - function_name = "pixel_shuffle_half"; +void PixelShuffleImageCompute::Run() { + @autoreleasepool { + run_without_mps(); } - assert(!function_name.empty()); - - kernel_ = metal_context_->GetKernel(*device, function_name); - queue_ = metal_context_->GetDefaultQueue(*device); } -template -void PixelShuffleImageCompute::Run() { - auto output_width = output_buffer_->texture_width_; - auto output_height = output_buffer_->texture_height_; - auto output_array_length = output_buffer_->array_length_; +void PixelShuffleImageCompute::run_without_mps() { + auto pipline = pipline_; + auto outTexture = output_buffer_->image(); + auto backend = (__bridge MetalContextImp*)metal_context_->backend(); + + auto encoder = [backend commandEncoder]; + [encoder setTexture:input_buffer_->image() atIndex:(0)]; + [encoder setTexture:output_buffer_->image() atIndex:(1)]; + [encoder setBuffer:params_buffer_->buffer() offset:(0) atIndex:(0)]; + + [backend dispatchEncoder:encoder pipline:pipline outTexture:outTexture]; + [backend commit]; +} - auto encoder = - std::make_shared(metal_context_->cmd_buf_.get(), &kernel_->program_); - MetalUint3 global_work_size = {static_cast(output_width), - static_cast(output_height), - static_cast(output_array_length)}; +void PixelShuffleImageCompute::setup_without_mps() { + const auto& param = this->Param(); - [encoder->metal_command_encoder_ setTexture:(input_buffer_->image()) atIndex:(0)]; - [encoder->metal_command_encoder_ setTexture:(output_buffer_->image()) atIndex:(1)]; - [encoder->metal_command_encoder_ setBuffer:(param_buffer_->buffer()) offset:(0) atIndex:(0)]; + PixelShuffleMetalParam params{param.upscale_factor}; + params_buffer_ = std::make_shared(metal_context_, sizeof(params), ¶ms); + function_name_ = "pixel_shuffle"; + // pipline + auto backend = (__bridge MetalContextImp*)metal_context_->backend(); + pipline_ = [backend pipline:function_name_]; +} - kernel_->Execute(*encoder, global_work_size, false); +PixelShuffleImageCompute::~PixelShuffleImageCompute() { + TargetWrapperMetal::FreeImage(output_buffer_); } } // namespace metal @@ -73,18 +78,11 @@ } // namespace lite } // namespace paddle -template class paddle::lite::kernels::metal::PixelShuffleImageCompute; -template class paddle::lite::kernels::metal::PixelShuffleImageCompute; -typedef paddle::lite::kernels::metal::PixelShuffleImageCompute - MetalPixelShuffleFp32; -typedef paddle::lite::kernels::metal::PixelShuffleImageCompute - MetalPixelShuffleFp16; - REGISTER_LITE_KERNEL(pixel_shuffle, kMetal, kFloat, kMetalTexture2DArray, - MetalPixelShuffleFp32, + paddle::lite::kernels::metal::PixelShuffleImageCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kMetal), @@ -96,7 +94,12 @@ DATALAYOUT(kMetalTexture2DArray))}) .Finalize(); -REGISTER_LITE_KERNEL(pixel_shuffle, kMetal, kFP16, kMetalTexture2DArray, MetalPixelShuffleFp16, def) +REGISTER_LITE_KERNEL(pixel_shuffle, + kMetal, + kFP16, + kMetalTexture2DArray, + paddle::lite::kernels::metal::PixelShuffleImageCompute, + def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) .BindOutput("Out", diff --git a/lite/kernels/metal/image_op/pixel_unshuffle_image_compute.h b/lite/kernels/metal/image_op/pixel_unshuffle_image_compute.h new file mode 100644 index 00000000000..b123d99a15e --- /dev/null +++ b/lite/kernels/metal/image_op/pixel_unshuffle_image_compute.h @@ -0,0 +1,71 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LITE_KERNELS_METAL_IMAGE_OP_PIXEL_UNSHUFFLE_IMAGE_COMPUTE_H_ +#define LITE_KERNELS_METAL_IMAGE_OP_PIXEL_UNSHUFFLE_IMAGE_COMPUTE_H_ + +#include + +#include "lite/core/kernel.h" +#include "lite/core/tensor.h" +#include "lite/operators/op_params.h" + +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/profiler.h" +#endif + +#include "lite/backends/metal/metal_context.h" +#include "lite/backends/metal/metal_debug.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace metal { + +class PixelUnShuffleImageCompute + : public KernelLite { + using param_t = operators::PixelUnShuffleParam; + + public: + void PrepareForRun() override; + void Run() override; + void SaveOutput() override { + MetalDebug::SaveOutput("pixel_unshuffle", output_buffer_); + }; + virtual ~PixelUnShuffleImageCompute(); + + private: + bool use_mps_{false}; + void* mps_pool_op_{nullptr}; + void* mps_input_image_{nullptr}; + void* mps_output_image_{nullptr}; + + void run_without_mps(); + void setup_without_mps(); + + const MetalImage* input_buffer_; + MetalImage* output_buffer_{nullptr}; + std::shared_ptr params_buffer_; + + id pipline_; + std::string function_name_; + MetalContext* metal_context_; +}; + +} // namespace metal +} // namespace kernels +} // namespace lite +} // namespace paddle + +#endif LITE_KERNELS_METAL_IMAGE_OP_PIXEL_UNSHUFFLE_IMAGE_COMPUTE_H_ diff --git a/lite/kernels/metal/image_op/pixel_unshuffle_image_compute.mm b/lite/kernels/metal/image_op/pixel_unshuffle_image_compute.mm new file mode 100644 index 00000000000..4263193a61f --- /dev/null +++ b/lite/kernels/metal/image_op/pixel_unshuffle_image_compute.mm @@ -0,0 +1,108 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/metal/image_op/pixel_unshuffle_image_compute.h" +#include "lite/backends/metal/metal_context_imp.h" +#include "lite/backends/metal/metal_debug.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/metal/image_op/metal_params.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace metal { + +void PixelUnShuffleImageCompute::PrepareForRun() { + auto& context = ctx_->As(); + metal_context_ = (MetalContext*)context.context(); + + const auto& param = this->Param(); + auto output_dims = param.output->dims(); + +#ifdef LITE_WITH_METAL_FULL +#else + output_buffer_ = param.output->mutable_data(metal_context_, output_dims); + input_buffer_ = param.x->data(); +#endif + setup_without_mps(); +} + +void PixelUnShuffleImageCompute::Run() { + @autoreleasepool { + run_without_mps(); + } +} + +void PixelUnShuffleImageCompute::run_without_mps() { + auto pipline = pipline_; + auto outTexture = output_buffer_->image(); + auto backend = (__bridge MetalContextImp*)metal_context_->backend(); + + auto encoder = [backend commandEncoder]; + [encoder setTexture:input_buffer_->image() atIndex:(0)]; + [encoder setTexture:output_buffer_->image() atIndex:(1)]; + [encoder setBuffer:params_buffer_->buffer() offset:(0) atIndex:(0)]; + + [backend dispatchEncoder:encoder pipline:pipline outTexture:outTexture]; + [backend commit]; +} + +void PixelUnShuffleImageCompute::setup_without_mps() { + const auto& param = this->Param(); + + PixelShuffleMetalParam params{param.downscale_factor}; + params_buffer_ = std::make_shared(metal_context_, sizeof(params), ¶ms); + function_name_ = "pixel_unshuffle"; + // pipline + auto backend = (__bridge MetalContextImp*)metal_context_->backend(); + pipline_ = [backend pipline:function_name_]; +} + +PixelUnShuffleImageCompute::~PixelUnShuffleImageCompute() { + TargetWrapperMetal::FreeImage(output_buffer_); +} + +} // namespace metal +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pixel_unshuffle, + kMetal, + kFloat, + kMetalTexture2DArray, + paddle::lite::kernels::metal::PixelUnShuffleImageCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kMetal), + PRECISION(kFloat), + DATALAYOUT(kMetalTexture2DArray))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kMetal), + PRECISION(kFloat), + DATALAYOUT(kMetalTexture2DArray))}) + .Finalize(); + +REGISTER_LITE_KERNEL(pixel_unshuffle, + kMetal, + kFP16, + kMetalTexture2DArray, + paddle::lite::kernels::metal::PixelUnShuffleImageCompute, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kMetal), PRECISION(kFP16), DATALAYOUT(kMetalTexture2DArray))}) + .Finalize();