|
10 | 10 | #include <dgl/bcast.h>
|
11 | 11 | #include <dgl/runtime/parallel_for.h>
|
12 | 12 |
|
| 13 | +#ifdef __ARM_FEATURE_SVE |
| 14 | +#include <arm_sve.h> // to leverage sve intrinsics |
| 15 | +#endif |
| 16 | + |
13 | 17 | #include "../selector.h"
|
14 | 18 |
|
15 | 19 | namespace dgl {
|
@@ -222,6 +226,54 @@ struct Dot {
|
222 | 226 |
|
223 | 227 | } // namespace op
|
224 | 228 |
|
| 229 | +// SDDMMCoo Specialization |
| 230 | +#ifdef __ARM_FEATURE_SVE |
| 231 | +template <> |
| 232 | +void SDDMMCoo <int32_t, float, dgl::aten::cpu::op::CopyRhs<float>, 0, 2> ( |
| 233 | + const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out) { |
| 234 | + const bool has_idx = !IsNullArray(coo.data); |
| 235 | + const int32_t* row = coo.row.Ptr<int32_t>(); |
| 236 | + const int32_t* col = coo.col.Ptr<int32_t>(); |
| 237 | + const int32_t* edges = coo.data.Ptr<int32_t>(); |
| 238 | + const float* X = lhs.Ptr<float>(); |
| 239 | + const float* Y = rhs.Ptr<float>(); |
| 240 | + const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, |
| 241 | + rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size; |
| 242 | + float* O = out.Ptr<float>(); |
| 243 | +#pragma omp parallel for |
| 244 | + for (int64_t i = 0; i < coo.row->shape[0]; ++i) { |
| 245 | + const int32_t rid = row[i]; |
| 246 | + const int32_t cid = col[i]; |
| 247 | + const int32_t eid = has_idx ? edges[i] : i; |
| 248 | + float* out_off = O + eid * dim; |
| 249 | + if (!bcast.use_bcast && reduce_size == 1) { |
| 250 | + for (int64_t k = 0; k < dim; k += svcntw()) { |
| 251 | + svbool_t pgk = svwhilelt_b32(k, dim); |
| 252 | + int64_t rhs_base1 = cid * rhs_dim; |
| 253 | + svfloat32_t rhs_off_vector = svld1_f32(pgk, &Y[rhs_base1 + k]); |
| 254 | + svst1_f32(pgk, &out_off[k], rhs_off_vector); |
| 255 | + } |
| 256 | + } else { |
| 257 | + //with bcast.use_bcast == true, Op::use_lhs == false, and Op::Call |
| 258 | + for (int64_t k = 0; k < dim; ++k) { |
| 259 | + const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k; |
| 260 | + const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k; |
| 261 | + const float* lhs_off = |
| 262 | + dgl::aten::cpu::op::CopyRhs<float>::use_lhs ? X + rid * lhs_dim + |
| 263 | + lhs_add * reduce_size |
| 264 | + : nullptr; |
| 265 | + |
| 266 | + const float* rhs_off = |
| 267 | + dgl::aten::cpu::op::CopyRhs<float>::use_rhs ? Y + cid * rhs_dim + |
| 268 | + rhs_add * reduce_size |
| 269 | + : nullptr; |
| 270 | + out_off[k] = dgl::aten::cpu::op::CopyRhs<float>::Call(lhs_off, rhs_off, bcast.reduce_size); |
| 271 | + } |
| 272 | + } |
| 273 | + } |
| 274 | +} |
| 275 | +#endif |
| 276 | + |
225 | 277 | } // namespace cpu
|
226 | 278 | } // namespace aten
|
227 | 279 | } // namespace dgl
|
|
0 commit comments