From 8c21bf6b77aef79a1b479cea5192ffa54d142b3f Mon Sep 17 00:00:00 2001 From: liuduanhui <103939338+DanieeelLiu@users.noreply.github.com> Date: Fri, 2 Aug 2024 16:04:17 +0800 Subject: [PATCH] [Feature] Support MsDeformAttnForward with fast kernel (#3157) --- .../mlu/ms_deform_attn_fast_mlu_kernel.hpp | 23 + .../mlu/ms_deform_attn_fast_mlu_kernel.mlu | 823 ++++++++++++++++++ .../csrc/pytorch/mlu/ms_deform_attn_mlu.cpp | 28 +- 3 files changed, 873 insertions(+), 1 deletion(-) create mode 100644 mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp create mode 100644 mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp new file mode 100644 index 0000000000..45da263fea --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.hpp @@ -0,0 +1,23 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#ifndef MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_ +#define MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_ +void KernelMsDeformAttnForwardFast( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char *data_value_gdram, + const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram); +#endif // MS_DEFORM_ATTN_FORWARD_FAST_MLU_KERNEL_HPP_ diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu new file mode 100644 index 0000000000..20cac71672 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_fast_mlu_kernel.mlu @@ -0,0 +1,823 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "common_mlu_helper.hpp" +#include "ms_deform_attn_fast_mlu_kernel.hpp" + +#define NRAM_REMAIN_SIZE (48 * 1024) +#define NRAM_AVALIABLE_SIZE (__MLU_NRAM_SIZE__ * 1024 - NRAM_REMAIN_SIZE) + +#define SRAM_REMAIN_SIZE (32 * 1024) +#define SRAM_AVALIABLE_SIZE (__MLU_SRAM_SIZE__ * 1024 - SRAM_REMAIN_SIZE) +#define SRAM_FOR_VALUE_SIZE (SRAM_AVALIABLE_SIZE - 128) + +#define MAX_MEMCPY_SEGNUM 65536 + +__nram__ char nram_buffer[NRAM_AVALIABLE_SIZE]; +__mlu_shared__ char sram_buffer[SRAM_AVALIABLE_SIZE]; + +template +__mlu_func__ inline T __mluop_min(T a, T b) { + return a < b ? a : b; +} + +template +__mlu_func__ inline T __mluop_max(T a, T b) { + return a > b ? a : b; +} + +template +__mlu_func__ void __mluop_floor(T* dst_ram, T* src_ram, int size) { + if (sizeof(T) == sizeof(float)) { + int16* mid = (int16*)(dst_ram + size / 2); + __bang_float2int16_dn(mid, (float*)src_ram, size, 0); + __bang_int162float((float*)dst_ram, (int16_t*)mid, size, 0); + } else { + __bang_half2int16_dn((int16_t*)dst_ram, (half*)src_ram, size, 0); + __bang_int162half((half*)dst_ram, (int16_t*)dst_ram, size, 0); + } +} + +__mlu_func__ void broadcastSpatialHW( + float* spatial_offset_bd_nram, // (num_levels, num_points) + float* spatial_h_bd_nram, // (num_levels, num_points) + float* spatial_w_bd_nram, // (num_levels, num_points) + int32_t* spatial_shapes_nram, // (num_levels, 2) + int32_t* spatial_offset_nram, // (num_levels) + const int32_t num_levels, const int32_t num_points) { + for (int i = 0; i < num_levels * 2; i++) { + ((float*)spatial_shapes_nram)[i] = (float)spatial_shapes_nram[i]; + } + + for (int i = 0; i < num_levels; i++) { + ((float*)spatial_offset_nram)[i] = (float)spatial_offset_nram[i]; + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_h_bd_nram + i * num_points, spatial_shapes_nram + i * 2, + sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_w_bd_nram + i * num_points, + spatial_shapes_nram + 1 + i * 2, sizeof(float), NRAM2NRAM, + sizeof(float), 0, num_points - 1); + } + + for (int i = 0; i < num_levels; i++) { + __memcpy(spatial_offset_bd_nram + i * num_points, spatial_offset_nram + i, + sizeof(float), NRAM2NRAM, sizeof(float), 0, num_points - 1); + } +} + +template +__mlu_func__ void getConditionCoordWeight( + int32_t* data_offset_nram, T* weight_polation_nram, + T* cond_point_polation_nram, T* cond_point_valid_nram, T* loc_nram, + T* weight_attn_nram, T* spatial_offset_bd_nram, T* spatial_w_bd_nram, + T* spatial_h_bd_nram, T* buf_nram, const int32_t deal_n, + const int32_t num_levels, const int32_t num_points, const int32_t num_heads, + int32_t pad_num_levels_points) { + int32_t pad_total_points = deal_n * pad_num_levels_points; + int32_t pad_block_points = pad_num_levels_points; + T* buf_x_nram = buf_nram; + T* buf_y_nram = buf_nram + pad_total_points; + T* buf_cond_nram = buf_nram + 2 * pad_total_points; + T* buf_x_floor = buf_nram + 2 * pad_total_points; + T* buf_y_floor = buf_nram + 3 * pad_total_points; + T* buf_x_ceil = buf_nram + 4 * pad_total_points; + T* buf_y_ceil = buf_nram + 5 * pad_total_points; + + __bang_write_value(buf_x_nram, pad_total_points, 0); + __bang_write_value(buf_y_nram, pad_total_points, 0); + __bang_write_value(buf_x_floor, pad_total_points, 0); + __bang_write_value(buf_x_ceil, pad_total_points, 0); + __bang_write_value(buf_y_floor, pad_total_points, 0); + __bang_write_value(buf_y_ceil, pad_total_points, 0); + + //================================================================================================ + __memcpy(buf_x_nram, loc_nram, sizeof(T), NRAM2NRAM, sizeof(T), 2 * sizeof(T), + pad_total_points - 1); + __memcpy(buf_y_nram, loc_nram + 1, sizeof(T), NRAM2NRAM, sizeof(T), + 2 * sizeof(T), pad_total_points - 1); + + // x = loc_x * spatial_w - 0.5; y = loc_y * spatial_h - 0.5; + __bang_cycle_mul(buf_x_nram, buf_x_nram, spatial_w_bd_nram, pad_total_points, + pad_block_points); + __bang_sub_scalar(buf_x_nram, buf_x_nram, (T)0.5, pad_total_points); + __bang_cycle_mul(buf_y_nram, buf_y_nram, spatial_h_bd_nram, pad_total_points, + pad_block_points); + __bang_sub_scalar(buf_y_nram, buf_y_nram, (T)0.5, pad_total_points); + + //================================================================================================ + // get point condition. use buf0, buf1, buf2 + // (x > -1 && y > -1 && y < spatial_h && x < spatial_w) + __bang_write_value(cond_point_valid_nram, pad_total_points, (T)-1.0); + __bang_gt(cond_point_valid_nram, buf_x_nram, cond_point_valid_nram, + pad_total_points); + __bang_write_value(buf_cond_nram, pad_total_points, (T)-1.0); + __bang_gt(buf_cond_nram, buf_y_nram, buf_cond_nram, pad_total_points); + + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + __bang_cycle_lt(buf_cond_nram, buf_x_nram, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + __bang_cycle_lt(buf_cond_nram, buf_y_nram, spatial_h_bd_nram, + pad_total_points, pad_block_points); + __bang_and(cond_point_valid_nram, cond_point_valid_nram, buf_cond_nram, + pad_total_points); + //================================================================================================ + __mluop_floor(buf_x_floor, buf_x_nram, 2 * pad_total_points); + __bang_add_scalar(buf_x_ceil, buf_x_floor, 1.0, pad_total_points); + __bang_add_scalar(buf_y_ceil, buf_y_floor, 1.0, pad_total_points); + + T* cond_point_polation_nram_tl = cond_point_polation_nram; + T* cond_point_polation_nram_bl = cond_point_polation_nram + pad_total_points; + T* cond_point_polation_nram_tr = + cond_point_polation_nram + 2 * pad_total_points; + T* cond_point_polation_nram_br = + cond_point_polation_nram + 3 * pad_total_points; + T* cond_point_polation_nram_cond1 = weight_polation_nram; + T* cond_point_polation_nram_cond2 = weight_polation_nram + pad_total_points; + T* cond_point_polation_nram_cond3 = + weight_polation_nram + 2 * pad_total_points; + T* cond_point_polation_nram_cond4 = + weight_polation_nram + 3 * pad_total_points; + __bang_ge_scalar(cond_point_polation_nram_cond1, buf_x_floor, (T)0, + pad_total_points); + __bang_cycle_lt(cond_point_polation_nram_cond2, buf_x_ceil, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_ge_scalar(cond_point_polation_nram_cond3, buf_y_floor, (T)0, + pad_total_points); + __bang_cycle_lt(cond_point_polation_nram_cond4, buf_y_ceil, spatial_h_bd_nram, + pad_total_points, pad_block_points); + __bang_and(cond_point_polation_nram_tl, cond_point_polation_nram_cond1, + cond_point_polation_nram_cond4, pad_total_points); + __bang_and(cond_point_polation_nram_bl, cond_point_polation_nram_cond1, + cond_point_polation_nram_cond3, pad_total_points); + __bang_and(cond_point_polation_nram_tr, cond_point_polation_nram_cond2, + cond_point_polation_nram_cond4, pad_total_points); + __bang_and(cond_point_polation_nram_br, cond_point_polation_nram_cond2, + cond_point_polation_nram_cond3, pad_total_points); + //================================================================================================ + // get polation weight. + T* buf_dx = (T*)data_offset_nram; + T* buf_dy = buf_dx + pad_total_points; + T* buf_dx_1 = buf_dy + pad_total_points; + T* buf_dy_1 = buf_dx_1 + pad_total_points; + // -dx = x_floor-x + // -dy = y_floor-y + // w1 = (1-dx)*dy = (dx-1)*(-dy) + // w2 = (1-dx)*(1-dy) = (dx-1)*(dy-1) + // w3 = dx*dy = (-dx)*(-dy) + // w4 = dx*(1-dy) = (-dx)*(dy-1) + T* weight_polation_nram_1 = weight_polation_nram; + T* weight_polation_nram_2 = weight_polation_nram + 1 * pad_total_points; + T* weight_polation_nram_3 = weight_polation_nram + 2 * pad_total_points; + T* weight_polation_nram_4 = weight_polation_nram + 3 * pad_total_points; + // T* weight_polation_nram_buf = buf_nram + 4 * total_points; + __bang_sub(buf_dx, buf_x_floor, buf_x_nram, pad_total_points); + __bang_sub(buf_dy, buf_y_floor, buf_y_nram, pad_total_points); + + __bang_sub(buf_dx_1, buf_x_nram, buf_x_floor, pad_total_points); + __bang_sub_scalar(buf_dx_1, buf_dx_1, (T)1.0, pad_total_points); + + __bang_sub(buf_dy_1, buf_y_nram, buf_y_floor, pad_total_points); + __bang_sub_scalar(buf_dy_1, buf_dy_1, (T)1.0, pad_total_points); + + __bang_mul(weight_polation_nram_1, buf_dx_1, buf_dy, pad_total_points); + __bang_mul(weight_polation_nram_2, buf_dx_1, buf_dy_1, pad_total_points); + __bang_mul(weight_polation_nram_3, buf_dx, buf_dy, pad_total_points); + __bang_mul(weight_polation_nram_4, buf_dx, buf_dy_1, pad_total_points); + //================================================================================================ + // correct the x,y in [0, w-1] and [0, h-1] + T* spatial_w1_bd_nram = buf_nram; + T* spatial_h1_bd_nram = buf_nram + pad_total_points; + __bang_sub_scalar(spatial_w1_bd_nram, spatial_w_bd_nram, (T)1, + pad_total_points); + __bang_sub_scalar(spatial_h1_bd_nram, spatial_h_bd_nram, (T)1, + pad_total_points); + T* maxtemp = (T*)data_offset_nram; + __bang_write_value(maxtemp, pad_total_points, (T)0); + __bang_maxequal(buf_x_floor, buf_x_floor, maxtemp, pad_total_points); + __bang_maxequal(buf_x_ceil, buf_x_ceil, maxtemp, pad_total_points); + __bang_cycle_minequal(buf_x_floor, buf_x_floor, spatial_w1_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_minequal(buf_x_ceil, buf_x_ceil, spatial_w1_bd_nram, + pad_total_points, pad_block_points); + __bang_maxequal(buf_y_floor, buf_y_floor, maxtemp, pad_total_points); + __bang_maxequal(buf_y_ceil, buf_y_ceil, maxtemp, pad_total_points); + __bang_cycle_minequal(buf_y_floor, buf_y_floor, spatial_h1_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_minequal(buf_y_ceil, buf_y_ceil, spatial_h1_bd_nram, + pad_total_points, pad_block_points); + //================================================================================================ + // offset = y*w + x + T* buf_hw_offset = buf_nram; + T* data_offset_nram_tl = (T*)data_offset_nram; + T* data_offset_nram_bl = data_offset_nram_tl + pad_total_points; + T* data_offset_nram_tr = data_offset_nram_bl + pad_total_points; + T* data_offset_nram_br = data_offset_nram_tr + pad_total_points; + // y_ceil*w + offset + x_floor + __bang_cycle_mul(buf_hw_offset, buf_y_ceil, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + pad_total_points, pad_block_points); + __bang_add(data_offset_nram_tl, buf_hw_offset, buf_x_floor, pad_total_points); + // y_ceil*w + offset + x_ceil + __bang_add(data_offset_nram_tr, buf_hw_offset, buf_x_ceil, pad_total_points); + // y_floor*w + offset + x_foor + __bang_cycle_mul(buf_hw_offset, buf_y_floor, spatial_w_bd_nram, + pad_total_points, pad_block_points); + __bang_cycle_add(buf_hw_offset, buf_hw_offset, spatial_offset_bd_nram, + pad_total_points, pad_block_points); + + __bang_add(data_offset_nram_bl, buf_hw_offset, buf_x_floor, pad_total_points); + // y_floor*w + offset + x_ceil + __bang_add(data_offset_nram_br, buf_hw_offset, buf_x_ceil, pad_total_points); + __bang_cycle_and(cond_point_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, 4 * pad_total_points, + pad_total_points); + __bang_cycle_mul(weight_polation_nram, weight_polation_nram, weight_attn_nram, + 4 * pad_total_points, pad_total_points); + __bang_mul(weight_polation_nram, weight_polation_nram, + cond_point_polation_nram, pad_total_points * 4); + __bang_sub((float*)data_offset_nram_bl, (float*)data_offset_nram_bl, + (float*)data_offset_nram_tl, pad_total_points); + __bang_sub((float*)data_offset_nram_tr, (float*)data_offset_nram_tr, + (float*)data_offset_nram_tl, pad_total_points); +} + +/* + shape of each tensor: + output_nram: (channels) + input_nram: (4, valid_num, channels) + input_trans: (channels, 4, valid_num) + weight_selected_base: (4, deal_n, num_levels, num_points) + weight_compute: (4, valid_num) +*/ +template +__mlu_func__ void reduceLevel(T* output_nram, T* input_nram, T* input_trans, + T* weight_selected_base, T* weight_compute, + const int32_t pad_num_levels_points, + const int32_t pad_channels, + const int32_t pad_sample_stride_3) { + int32_t ci = 4 * pad_num_levels_points; + int32_t co = pad_channels; + __bang_write_value(weight_compute, 4 * pad_num_levels_points, 0); + __memcpy(weight_compute, weight_selected_base, + pad_num_levels_points * sizeof(T), NRAM2NRAM, + pad_num_levels_points * sizeof(T), pad_sample_stride_3 * sizeof(T), + 3); + __bang_transpose(input_trans, input_nram, ci, co); + __bang_cycle_mul(input_trans, input_trans, weight_compute, co * ci, ci); + __bang_sumpool(input_nram, input_trans, pad_num_levels_points, pad_channels, + 4, 1, 4, 1, 1); + __bang_transpose(input_trans, input_nram, pad_channels, + pad_num_levels_points); + __bang_sumpool(output_nram, input_trans, pad_channels, pad_num_levels_points, + 1, pad_num_levels_points, 1, 1, 1); +} + +__mlu_func__ void loadNram2Gpr(int32_t& v1, int32_t& v2, int32_t& v3, + int32_t* p1, int32_t* p2, int32_t* p3, + int32_t num_heads, int32_t channels_size, + bool sram_stay, int32_t sram_level_start_index) { + v1 = (int32_t)(*(float*)p1); + v2 = (int32_t)(*(float*)p2); + v3 = (int32_t)(*(float*)p3); + int32_t stride = sram_stay ? channels_size : num_heads * channels_size; + if (sram_stay) { + v1 = (v1 - sram_level_start_index) * stride; + v2 = v2 * stride; + v3 = v3 * stride; + } else { + v1 = v1 * stride; + v2 = v2 * stride; + v3 = v3 * stride; + } +} + +/* + Load 4 neighbors use 2 2D-memcpy, just use offset of N1, stride_3_1 + and + stride_2_1. + |<- stride_3_1 ->| + N1 N3 + ^ + | + stride_2_1 + | + v + N2 N4 + + Trickly fold the loop as 2. +*/ +template +__mlu_func__ void loadDataValueXram2NramAsync( + T* buf_value_nram_1, int32_t* offset_1, int32_t* stride_2_1, + int32_t* stride_3_1, T* value_src, const int32_t pad_num_levels_points, + const int32_t deal_points, const int32_t start_points_index, + const int32_t channel_size, const int32_t num_heads, bool sram_stay, + const int32_t sram_level_start_offset) { + int32_t offset_1_a, stride_2_1_a, stride_3_1_a; + int32_t offset_1_b, stride_2_1_b, stride_3_1_b; + loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, + offset_1 + start_points_index, stride_2_1 + start_points_index, + stride_3_1 + start_points_index, num_heads, channel_size, + sram_stay, sram_level_start_offset); + loadNram2Gpr( + offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + start_points_index + 1, + stride_2_1 + start_points_index + 1, stride_3_1 + start_points_index + 1, + num_heads, channel_size, sram_stay, sram_level_start_offset); + + int32_t value_offset = 0; + int32_t next = 0; + int32_t loop_num = deal_points / 2; + int32_t remain = deal_points % 2; + int32_t pad_channels = + PAD_UP(channel_size / sizeof(T), NFU_ALIGN_SIZE / sizeof(T)); + int32_t pad_channels_size = pad_channels * sizeof(T); + int32_t pad_data_value_stride = pad_num_levels_points * pad_channels_size; + for (int32_t j = start_points_index; j < start_points_index + loop_num * 2; + j += 2) { + value_offset = j * pad_channels_size; + next = j + 2; + for (int i = 0; i < 2; i++) { + __memcpy_async( + (int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, DIR, + 2 * pad_data_value_stride, stride_3_1_a, 1); + } + + loadNram2Gpr(offset_1_a, stride_2_1_a, stride_3_1_a, offset_1 + next, + stride_2_1 + next, stride_3_1 + next, num_heads, channel_size, + sram_stay, sram_level_start_offset); + + for (int i = 0; i < 2; i++) { + __memcpy_async((int8_t*)buf_value_nram_1 + value_offset + + pad_channels_size + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_b + i * stride_2_1_b, + channel_size, DIR, 2 * pad_data_value_stride, stride_3_1_b, + 1); + } + + loadNram2Gpr(offset_1_b, stride_2_1_b, stride_3_1_b, offset_1 + next + 1, + stride_2_1 + next + 1, stride_3_1 + next + 1, num_heads, + channel_size, sram_stay, sram_level_start_offset); + } + + if (remain > 0) { + value_offset = (start_points_index + loop_num * 2) * pad_channels_size; + for (int i = 0; i < 2; i++) { + __memcpy_async( + (int8_t*)buf_value_nram_1 + value_offset + pad_data_value_stride * i, + (int8_t*)value_src + offset_1_a + i * stride_2_1_a, channel_size, DIR, + 2 * pad_data_value_stride, stride_3_1_a, 1); + } + } +} + +template +__mlu_func__ void loadNeighborPolationAttn( + T* value_output_nram, T* value_gdram, int32_t* data_offset_nram, + T* weight_polation_nram, T* cond_point_polation_nram, + T* cond_point_valid_nram, T* weight_attn_nram, T* buf_nram, + T* compute_buf_nram, const int32_t deal_n, const int32_t num_levels, + const int32_t num_points, const int32_t num_keys, const int32_t channels, + const int32_t num_heads, const int32_t pad_channels, + const int32_t pad_num_levels_points, T* value_sram, + const int32_t sram_level_start_index, + const int32_t sram_level_start_offset) { + int32_t channel_size = channels * sizeof(T); + int32_t pad_sample_stride_3 = deal_n * pad_num_levels_points; + + T* buf_value_nram = buf_nram; // (4, num_levels, num_points, channels) + T* buf_value_nram_trans = buf_nram + 4 * pad_num_levels_points * pad_channels; + T* weight_compute_nram = compute_buf_nram; // (4, num_levels, num_points) + + int32_t* offset = data_offset_nram; + int32_t* stride_2_1 = offset + pad_sample_stride_3; + int32_t* stride_3_1 = stride_2_1 + pad_sample_stride_3; + T* output_nram = value_output_nram; + int32_t step_offset = 0; + for (int32_t i = 0; i < deal_n; i++) { + __bang_write_value(buf_value_nram, 4 * pad_num_levels_points * pad_channels, + 0); + __sync_compute(); + if (sram_level_start_index > 0) { + loadDataValueXram2NramAsync( + buf_value_nram, offset, stride_2_1, stride_3_1, value_gdram, + pad_num_levels_points, num_points * sram_level_start_index, 0, + channel_size, num_heads, false, sram_level_start_offset); + } + if (sram_level_start_index < num_levels) { + loadDataValueXram2NramAsync( + buf_value_nram, offset, stride_2_1, stride_3_1, value_sram, + pad_num_levels_points, + num_points * (num_levels - sram_level_start_index), + num_points * sram_level_start_index, channel_size, num_heads, true, + sram_level_start_offset); + } + __sync_io_move_compute(); + reduceLevel(output_nram, buf_value_nram, buf_value_nram_trans, + weight_polation_nram + step_offset, weight_compute_nram, + pad_num_levels_points, pad_channels, pad_sample_stride_3); + step_offset += pad_num_levels_points; + offset = data_offset_nram + step_offset; + stride_2_1 = offset + pad_sample_stride_3; + stride_3_1 = stride_2_1 + pad_sample_stride_3; + output_nram += pad_channels; + } +} + +template +__mlu_func__ void prepareLoop( + int32_t* spatial_offset_nram, int32_t* spatial_hw_nram, + T* spatial_offset_bd_nram, T* spatial_h_bd_nram, T* spatial_w_bd_nram, + const char* data_level_start_index_gdram, + const char* data_spatial_shapes_gdram, const int32_t num_keys, + const int32_t num_levels, const int32_t num_points, + const int32_t max_deal_n, const int32_t channels) { + __memcpy(spatial_offset_nram, data_level_start_index_gdram, + num_levels * sizeof(int32_t), GDRAM2NRAM); + __memcpy(spatial_hw_nram, data_spatial_shapes_gdram, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + broadcastSpatialHW(spatial_offset_bd_nram, spatial_h_bd_nram, + spatial_w_bd_nram, spatial_hw_nram, spatial_offset_nram, + num_levels, num_points); +} + +/* + The shape of each tensor: + buf_compute_nram: (8, num_levels, num_points) + spatial_offset_nram: (num_levels) + spatial_hw_nram: (num_levels, 2) + spatial_offset_bd_nram: (num_levels, num_points) + spatial_w_bd_nram: (num_levels, num_points) + spatial_h_bd_nram: (num_levels, num_points) + value_output_nram: (deal_n, channels) + data_offset_nram: (4, deal_n, num_levels, num_points) + weight_polation_nram: (4, deal_n, num_levels, num_points) + cond_point_polation_nram: (4, deal_n, num_levels, num_points) + cond_point_valid_nram: (deal_n, num_levels, num_points) + loc_nram: (deal_n, num_levels, num_points, 2) + weight_attn_nram: (deal_n, num_levels, num_points) + buf_nram: (6, deal_n, num_levels, num_points) + + Note: buf_nram is reused in polation computing. +*/ +template +__mlu_func__ void memPolicyCommon( + T*& buf_compute_nram, T*& value_output_nram, int32_t*& data_offset_nram, + T*& weight_polation_nram, T*& cond_point_polation_nram, + T*& cond_point_valid_nram, T*& loc_nram, T*& weight_attn_nram, T*& buf_nram, + T*& buf_nram_end, T*& spatial_offset_bd_nram, T*& spatial_w_bd_nram, + T*& spatial_h_bd_nram, T*& value_sram, int32_t*& spatial_offset_nram, + int32_t*& spatial_hw_nram, int32_t& max_deal_n, int32_t& pad_channels, + int32_t& pad_num_levels_points, int32_t& pad_total_points, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points) { + pad_channels = PAD_UP(channels, NFU_ALIGN_SIZE / sizeof(T)); + int32_t num_levels_points = num_levels * num_points; + pad_num_levels_points = PAD_UP(num_levels_points, NFU_ALIGN_SIZE / sizeof(T)); + int32_t pad_num_levels_points_8 = 8 * pad_num_levels_points; + int32_t spatial_info_size = + PAD_UP(3 * num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); + int32_t fix_space_size = + spatial_info_size + + (3 * pad_num_levels_points + pad_num_levels_points) * sizeof(T); + int32_t left_space_size = NRAM_AVALIABLE_SIZE - fix_space_size; + int32_t common_buffer_size_each = 6 * pad_num_levels_points * sizeof(T); + int32_t inter_result_size_each = + 17 * pad_num_levels_points * sizeof(T) + pad_channels * sizeof(T); + + max_deal_n = + left_space_size / (common_buffer_size_each + inter_result_size_each); + + int32_t compute_buffer_size = + (9 * pad_num_levels_points * pad_channels) * sizeof(T); + int32_t common_buffer_size = max_deal_n * common_buffer_size_each; + // make sure buf_nram is large enough for compute + if (compute_buffer_size > common_buffer_size) { + int32_t tmp_deal_n = + (left_space_size - compute_buffer_size) / inter_result_size_each; + max_deal_n = __mluop_min(max_deal_n, tmp_deal_n); + } + + pad_total_points = max_deal_n * pad_num_levels_points; + buf_compute_nram = (T*)nram_buffer; + spatial_offset_nram = (int32_t*)(buf_compute_nram + pad_num_levels_points_8); + int32_t pad_3_levels = PAD_UP(3 * num_levels, NFU_ALIGN_SIZE / sizeof(T)); + spatial_hw_nram = spatial_offset_nram + num_levels; + spatial_offset_bd_nram = (T*)(spatial_offset_nram + pad_3_levels); + spatial_w_bd_nram = spatial_offset_bd_nram + pad_num_levels_points; + spatial_h_bd_nram = spatial_w_bd_nram + pad_num_levels_points; + value_output_nram = spatial_h_bd_nram + pad_num_levels_points; + data_offset_nram = (int32_t*)(value_output_nram + max_deal_n * pad_channels); + weight_polation_nram = (T*)(data_offset_nram + 4 * pad_total_points); + cond_point_polation_nram = weight_polation_nram + 4 * pad_total_points; + cond_point_valid_nram = cond_point_polation_nram + 4 * pad_total_points; + loc_nram = cond_point_valid_nram + pad_total_points; + weight_attn_nram = + loc_nram + + 2 * pad_total_points; // total_coord_pad = 2 * pad_total_points + buf_nram = weight_attn_nram + pad_total_points; + buf_nram_end = buf_nram + 6 * max_deal_n * pad_num_levels_points; + value_sram = (T*)sram_buffer; +} + +template +__mlu_func__ void loadDataValueGdram2Sram(T* value_sram, T* data_value_gdram, + const int32_t batch_idx, + const int32_t head_idx, + const int32_t sram_num_keys, + const int32_t num_heads, + const int32_t channels, + const int32_t skip_num_key) { + int32_t loop_num = + (sram_num_keys + MAX_MEMCPY_SEGNUM - 1) / MAX_MEMCPY_SEGNUM; + int32_t num_heads_channels = num_heads * channels; + for (int32_t i = 0; i < loop_num; i++) { + int32_t load_num = + __mluop_min(MAX_MEMCPY_SEGNUM, sram_num_keys - i * MAX_MEMCPY_SEGNUM); + size_t src_offset = ((size_t)batch_idx * sram_num_keys + skip_num_key + + i * MAX_MEMCPY_SEGNUM) * + num_heads_channels + + head_idx * channels; + int32_t dst_offset = i * MAX_MEMCPY_SEGNUM * channels; + __memcpy(value_sram + dst_offset, (T*)data_value_gdram + src_offset, + channels * sizeof(T), GDRAM2SRAM, channels * sizeof(T), + num_heads_channels * sizeof(T), load_num - 1); + } +} + +template +__mlu_func__ void computeSramCacheSizeAndOffset( + int32_t* sram_level_cache_size, int32_t* sram_level_start_index, + int32_t* sram_level_start_offset, const int32_t num_levels, + const int32_t num_keys, const int32_t channels, + const T* data_level_start_index_gdram, const int32_t sram_size) { + for (int32_t level_id = num_levels; level_id > 0; level_id--) { + int current_level_end_index = + level_id == num_levels + ? num_keys + : ((int32_t*)data_level_start_index_gdram)[level_id]; + int32_t current_level_size = + current_level_end_index - + ((int32_t*)data_level_start_index_gdram)[level_id - 1]; + if ((*sram_level_cache_size + current_level_size) * channels * sizeof(T) > + sram_size) { + break; + } + *sram_level_cache_size += current_level_size; + *sram_level_start_index = level_id - 1; + *sram_level_start_offset = num_keys - *sram_level_cache_size; + } +} + +template +__mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + int32_t input_stride_4 = num_queries * num_heads * num_levels * num_points; + int32_t input_stride_3 = num_heads * num_levels * num_points; + int32_t input_stride_2 = num_levels * num_points; + int32_t output_stride_3 = num_queries * num_heads * channels; + int32_t output_stride_2 = num_heads * channels; + int32_t data_value_stride_3 = num_keys * num_heads * channels; + + T* value_output_nram = nullptr; // (deal_n, channels) + int32_t* data_offset_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* weight_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* cond_point_polation_nram = nullptr; // (4, deal_n, num_levels, num_points) + T* cond_point_valid_nram = nullptr; // (deal_n, num_levels, num_points) + T* loc_nram = nullptr; // (deal_n, num_levels, num_points, 2) + T* weight_attn_nram = nullptr; // (deal_n, num_levels, num_points) + T* buf_nram = nullptr; // (6, deal_n, num_levels, num_points) + T* buf_nram_end = nullptr; + T* spatial_offset_bd_nram = nullptr; // (num_levels, num_points) + T* spatial_w_bd_nram = nullptr; // (num_levels, num_points) + T* spatial_h_bd_nram = nullptr; // (num_levels, num_points) + int32_t* spatial_offset_nram = nullptr; // (num_levels) + int32_t* spatial_hw_nram = nullptr; // (num_levels, 2) + T* buf_compute_nram = nullptr; // (8, num_levels, num_points) + int32_t max_deal_n = 0; + int32_t pad_channels = 0; + int32_t pad_num_levels_points = 0; + int32_t pad_total_points = 0; + T* value_sram = nullptr; + + memPolicyCommon(buf_compute_nram, value_output_nram, data_offset_nram, + weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, buf_nram, + buf_nram_end, spatial_offset_bd_nram, spatial_w_bd_nram, + spatial_h_bd_nram, value_sram, spatial_offset_nram, + spatial_hw_nram, max_deal_n, pad_channels, + pad_num_levels_points, pad_total_points, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points); + if (max_deal_n <= 0) { + return; + } + + // split batch*head into taskDimY + int32_t batch_head = batch_size * num_heads; + int32_t cluster_avg_batch_head = (batch_head + taskDimY - 1) / taskDimY; + int32_t cluster_begin_batch_head = taskIdY * cluster_avg_batch_head; + int32_t cluster_act_batch_head = __mluop_min( + cluster_avg_batch_head, batch_head - cluster_begin_batch_head); + int32_t cluster_end_batch_head = + cluster_begin_batch_head + cluster_act_batch_head; + // split query into coreDim + int32_t core_avg_query = (num_queries + coreDim - 1) / coreDim; + int32_t core_begin_query = coreId * core_avg_query; + int32_t core_act_query = + __mluop_min(num_queries - core_begin_query, core_avg_query); + int32_t core_loop_num = (core_act_query + max_deal_n - 1) / max_deal_n; + int32_t core_step_query = + core_loop_num > 0 ? (core_act_query + core_loop_num - 1) / core_loop_num + : 0; + int32_t core_remain_query = + core_act_query - (core_loop_num - 1) * core_step_query; + int32_t first_deal_query = + (int)(core_loop_num > 0) * + (core_loop_num > 1 ? core_step_query : core_remain_query); + + prepareLoop(spatial_offset_nram, spatial_hw_nram, spatial_offset_bd_nram, + spatial_h_bd_nram, spatial_w_bd_nram, + data_level_start_index_gdram, data_spatial_shapes_gdram, num_keys, + num_levels, num_points, max_deal_n, channels); + + int sram_total_size = 0; + int sram_level_start_index = num_levels; + int sram_level_start_offset = 0; + computeSramCacheSizeAndOffset( + &sram_total_size, &sram_level_start_index, &sram_level_start_offset, + num_levels, num_keys, channels, (int32_t*)data_level_start_index_gdram, + SRAM_FOR_VALUE_SIZE); + + for (int32_t bh_idx = cluster_begin_batch_head; + bh_idx < cluster_end_batch_head; bh_idx++) { + int32_t b = bh_idx / num_heads; + int32_t head_idx = bh_idx % num_heads; + + size_t output_base_offset = + (size_t)b * output_stride_3 + head_idx * channels; + int32_t attn_weight_base_offset = + b * input_stride_4 + head_idx * input_stride_2; + + if (__is_mpu() && (sram_level_start_index != num_levels)) { + loadDataValueGdram2Sram(value_sram, (T*)data_value_gdram, b, head_idx, + sram_total_size, num_heads, channels, + sram_level_start_offset); + } + + __sync_cluster(); + + if (__is_ipu()) { + // compute weight, offset and condition + int32_t attn_weight_offset = + attn_weight_base_offset + core_begin_query * input_stride_3; + int32_t loc_offset = attn_weight_offset * 2; + if (first_deal_query > 0) { + __bang_write_value(loc_nram, 2 * pad_total_points, 0); + __bang_write_value(weight_attn_nram, pad_total_points, 0); + __sync_compute(); + __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, + input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * 2 * sizeof(T), + input_stride_3 * 2 * sizeof(T), first_deal_query - 1); + __memcpy_async(weight_attn_nram, + (T*)data_attn_weight_gdram + attn_weight_offset, + input_stride_2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * sizeof(T), + input_stride_3 * sizeof(T), first_deal_query - 1); + getConditionCoordWeight( + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, + spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, + buf_nram, first_deal_query, num_levels, num_points, num_heads, + pad_num_levels_points); + } + } + + for (int32_t i = 0; __is_ipu() && i < core_loop_num; i++) { + __bang_write_value(loc_nram, 2 * pad_total_points, 0); + __bang_write_value(weight_attn_nram, pad_total_points, 0); + int32_t deal_n = + i < core_loop_num - 1 ? core_step_query : core_remain_query; + int32_t load_n = + i < core_loop_num - 2 ? core_step_query : core_remain_query; + // load value and polation + loadNeighborPolationAttn( + value_output_nram, + (T*)data_value_gdram + b * data_value_stride_3 + head_idx * channels, + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, weight_attn_nram, buf_nram, buf_compute_nram, + deal_n, num_levels, num_points, num_keys, channels, num_heads, + pad_channels, pad_num_levels_points, value_sram, + sram_level_start_index, sram_level_start_offset); + __sync_io_move_compute(); + // load next weight and loc + if (i < core_loop_num - 1) { + int32_t core_query_offset = (i + 1) * core_step_query; + int32_t attn_weight_offset = + attn_weight_base_offset + + (core_begin_query + core_query_offset) * input_stride_3; + int32_t loc_offset = attn_weight_offset * 2; + __memcpy_async(loc_nram, (T*)data_sampling_loc_gdram + loc_offset, + input_stride_2 * 2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * 2 * sizeof(T), + input_stride_3 * 2 * sizeof(T), load_n - 1); + __memcpy_async(weight_attn_nram, + (T*)data_attn_weight_gdram + attn_weight_offset, + input_stride_2 * sizeof(T), GDRAM2NRAM, + pad_num_levels_points * sizeof(T), + input_stride_3 * sizeof(T), load_n - 1); + __sync_io_move_compute(); + } + // store result + size_t output_offset = + ((size_t)core_begin_query + i * core_step_query) * output_stride_2; + __memcpy_async((T*)data_col_gdram + output_base_offset + output_offset, + value_output_nram, channels * sizeof(T), NRAM2GDRAM, + output_stride_2 * sizeof(T), pad_channels * sizeof(T), + deal_n - 1); + + // compute cond/weight/offset + if (i < core_loop_num - 1) { + getConditionCoordWeight( + data_offset_nram, weight_polation_nram, cond_point_polation_nram, + cond_point_valid_nram, loc_nram, weight_attn_nram, + spatial_offset_bd_nram, spatial_w_bd_nram, spatial_h_bd_nram, + buf_nram, load_n, num_levels, num_points, num_heads, + pad_num_levels_points); + } + __sync_io_move_compute(); + } + __sync_cluster(); + } +} + +template +__mlu_global__ void MLUKernelMsDeformAttnForwardFast( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + MLUKernelMsDeformAttnForwardFastImpl( + data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, + data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); +} + +template __mlu_global__ void MLUKernelMsDeformAttnForwardFast( + const char* data_value_gdram, const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram); + +void KernelMsDeformAttnForwardFast( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char* data_value_gdram, + const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram) { + MLUKernelMsDeformAttnForwardFast<<>>( + data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, + data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index 845465ae4b..e118f632cd 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -11,6 +11,7 @@ *************************************************************************/ #include "pytorch_device_registry.hpp" #include "pytorch_mlu_helper.hpp" +#include "ms_deform_attn_fast_mlu_kernel.hpp" #define MIN(a, b) (((a) < (b)) ? (a) : (b)) @@ -20,6 +21,8 @@ typedef enum { 1, /*!< MLUKernelMsDeformAttnForwardDefault */ MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL = 2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */ + MS_DEFORM_ATTN_FORWARD_FAST = + 3, /*!< MLUKernelMsDeformAttnForwardFast */ } MsDeformAttnForwardPolicy; void KernelMsDeformAttnForwardDefault( @@ -40,6 +43,15 @@ void KernelMsDeformAttnForwardSmallChannel( const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, char* data_col_gdram); +void KernelMsDeformAttnForwardFast( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char* data_value_gdram, + const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram); typedef enum { MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0, @@ -99,7 +111,9 @@ MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc( #endif int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) { + if (num_levels * num_points <= 128 && num_levels * num_points * channels <= 8192) { + return MS_DEFORM_ATTN_FORWARD_FAST; + } else if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) { return MS_DEFORM_ATTN_FORWARD_DEFAULT; } else if (channels > nram_size / 12 / sizeof(float)) { return MS_DEFORM_ATTN_FORWARD_DEFAULT; @@ -310,6 +324,18 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value, (char*)output_ptr); break; } + case MS_DEFORM_ATTN_FORWARD_FAST: { + CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardFast<<<" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelMsDeformAttnForwardFast( + k_dim, k_type, queue, data_type, (char*)value_ptr, + (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, + (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, + (char*)output_ptr); + break; + } + } output = output.view({batch_size, num_queries, num_heads * channels});