Skip to content

Commit 42d9c66

Browse files
committed
SVE Implementation for SDDMMCOO with copyrhs op
SVE intrinsic code is added to improve the performance of SDDMMCOO Op when bacst is disabled and reduce_size=1
1 parent 743e65f commit 42d9c66

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

CMakeLists.txt

+11
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ if (${BUILD_TYPE} STREQUAL "dev")
122122
if (MSVC)
123123
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Od")
124124
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Od")
125+
elseif ( CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(AARCH64)")
126+
# Check if the compiler supports ARMv8.2-A or later with SVE
127+
include(CheckCCompilerFlag)
128+
# Try to detect whether the system supports SVE
129+
check_c_compiler_flag("-march=armv8.2-a+sve" SUPPORTS_SVE)
130+
# Output the result
131+
if(SUPPORTS_SVE)
132+
message(STATUS "Hardware supports SVE")
133+
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
134+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb -march=armv8.2-a+sve")
135+
endif()
125136
else()
126137
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0 -g3 -ggdb")
127138
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3 -ggdb")

src/array/cpu/sddmm.h

+52
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#include <dgl/bcast.h>
1111
#include <dgl/runtime/parallel_for.h>
1212

13+
#ifdef __ARM_FEATURE_SVE
14+
#include <arm_sve.h> // to leverage sve intrinsics
15+
#endif
16+
1317
#include "../selector.h"
1418

1519
namespace dgl {
@@ -222,6 +226,54 @@ struct Dot {
222226

223227
} // namespace op
224228

229+
// SDDMMCoo Specialization
230+
#ifdef __ARM_FEATURE_SVE
231+
template <>
232+
void SDDMMCoo <int32_t, float, dgl::aten::cpu::op::CopyRhs<float>, 0, 2> (
233+
const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out) {
234+
const bool has_idx = !IsNullArray(coo.data);
235+
const int32_t* row = coo.row.Ptr<int32_t>();
236+
const int32_t* col = coo.col.Ptr<int32_t>();
237+
const int32_t* edges = coo.data.Ptr<int32_t>();
238+
const float* X = lhs.Ptr<float>();
239+
const float* Y = rhs.Ptr<float>();
240+
const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
241+
rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
242+
float* O = out.Ptr<float>();
243+
#pragma omp parallel for
244+
for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
245+
const int32_t rid = row[i];
246+
const int32_t cid = col[i];
247+
const int32_t eid = has_idx ? edges[i] : i;
248+
float* out_off = O + eid * dim;
249+
if (!bcast.use_bcast && reduce_size == 1) {
250+
for (int64_t k = 0; k < dim; k += svcntw()) {
251+
svbool_t pgk = svwhilelt_b32(k, dim);
252+
int64_t rhs_base1 = cid * rhs_dim;
253+
svfloat32_t rhs_off_vector = svld1_f32(pgk, &Y[rhs_base1 + k]);
254+
svst1_f32(pgk, &out_off[k], rhs_off_vector);
255+
}
256+
} else {
257+
//with bcast.use_bcast == true, Op::use_lhs == false, and Op::Call
258+
for (int64_t k = 0; k < dim; ++k) {
259+
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
260+
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
261+
const float* lhs_off =
262+
dgl::aten::cpu::op::CopyRhs<float>::use_lhs ? X + rid * lhs_dim +
263+
lhs_add * reduce_size
264+
: nullptr;
265+
266+
const float* rhs_off =
267+
dgl::aten::cpu::op::CopyRhs<float>::use_rhs ? Y + cid * rhs_dim +
268+
rhs_add * reduce_size
269+
: nullptr;
270+
out_off[k] = dgl::aten::cpu::op::CopyRhs<float>::Call(lhs_off, rhs_off, bcast.reduce_size);
271+
}
272+
}
273+
}
274+
}
275+
#endif
276+
225277
} // namespace cpu
226278
} // namespace aten
227279
} // namespace dgl

0 commit comments

Comments
 (0)