forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReducedPrecisionFloatGemvFastPathKernel.cpp
481 lines (440 loc) · 19.1 KB
/
ReducedPrecisionFloatGemvFastPathKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/Unroll.h>
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <arm_neon.h>
#include <cpuinfo.h>
#endif
namespace at::native {
inline namespace CPU_CAPABILITY {
#if !defined(C10_MOBILE)
constexpr auto kF32RegisterPairsPerIteration = 4;
constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
namespace {
// Compile-time floor(log2(n)) + p for n >= 1; returns p unchanged when
// n <= 1. Used to size the tree reductions below (number of halving steps).
template <typename T>
constexpr int IntegerLog2(T n, int p = 0) {
  while (n > 1) {
    n /= 2;
    ++p;
  }
  return p;
}
} // namespace
/*
* NOTE [ GGML Copyright Notice ]
* The reduce overload and fp16_dot_with_fp16_arith function below are
* adapted from llama.cpp's ggml_vec_dot_f16 and surrounding utility
* functions, so here is the required copyright notice:
*
* MIT License
*
* Copyright (c) 2023-2024 The ggml authors
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#if !defined(__aarch64__) || defined( __ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
// Shape of the fp16-arithmetic dot product: kF16RegistersPerIteration
// independent fp16 accumulator registers, each holding
// kF16ElementsPerRegister Half values, so one unrolled iteration consumes
// kF16ElementsPerIteration inputs.
constexpr auto kF16RegistersPerIteration{16};
constexpr auto kF16ElementsPerRegister{vec::Vectorized<Half>::size()};
constexpr auto kF16ElementsPerIteration{kF16ElementsPerRegister * kF16RegistersPerIteration};
// Horizontally sums all kF16RegistersPerIteration fp16 accumulators in x
// into one float. x is used as scratch space and is clobbered.
float reduce(vec::VectorizedN<Half, kF16RegistersPerIteration>& x) {
  // Pairwise tree reduction: every unrolled step folds the upper half of
  // the still-live registers into the lower half, until only x[0] is left.
  int half_width = kF16RegistersPerIteration;
  c10::ForcedUnroll<IntegerLog2(kF16RegistersPerIteration)>{}([&](auto /*step*/) {
    half_width /= 2;
    for (int lane = 0; lane < half_width; ++lane) {
      x[lane] = x[lane] + x[half_width + lane];
    }
  });
  // Widen the surviving fp16 register into two fp32 registers and combine
  // them before the final horizontal float sum.
  const auto [lo, hi] = vec::convert_half_float(x[0]);
  const auto widened = lo + hi;
#if !defined(__aarch64__) || defined(CPU_CAPABILITY_SVE)
  return vec::vec_reduce_all<float>(
      std::plus<vec::Vectorized<float>>(),
      widened);
#else
  return vaddvq_f32(widened);
#endif
}
float fp16_dot_with_fp16_arith(const Half* x, const Half* a, int len) {
vec::VectorizedN<Half, kF16RegistersPerIteration> sum(0);
const auto len_aligned = len & ~(kF16ElementsPerIteration - 1);
for (int j = 0; j < len_aligned ; j += kF16ElementsPerIteration) {
for (int k = 0; k < kF16RegistersPerIteration; ++k) {
const auto temp_x = vec::Vectorized<Half>::loadu(x + j + k * vec::Vectorized<Half>::size());
const auto temp_a = vec::Vectorized<Half>::loadu(a + j + k * vec::Vectorized<Half>::size());
sum[k] = vec::fmadd(temp_x, temp_a, sum[k]);
}
}
auto reduced_sum = reduce(sum);
for (int j = len_aligned; j < len; ++j) {
reduced_sum += x[j] * a[j];
}
return reduced_sum;
}
// Rather than unrolling to process multiple rows (transposed columns)
// of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
// along an individual dot product.
//
// For each i in [0, n): y[i * incy] = beta * y[i * incy] + dot(x, row_i),
// where row_i is the m Half values starting at a + lda * i and the dot
// product uses fp16 arithmetic (fp16_dot_with_fp16_arith). Rows are split
// across threads with parallel_for. beta == 0 never reads y, and beta == 1
// skips the multiply.
//
// Offsets are formed in int64_t: lda * i and i * incy can exceed INT_MAX
// for large matrices even though every factor fits in an int.
static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const Half* a, const int lda, const Half *x, const float beta, Half* y, int incy) {
  if (beta == 0.0f) {
    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        y[i * incy] = fp16_dot_with_fp16_arith(x, a + lda * i, m);
      }
    });
  } else if (beta == 1.0f) {
    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        y[i * incy] += fp16_dot_with_fp16_arith(x, a + lda * i, m);
      }
    });
  } else {
    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        y[i * incy] = beta * y[i * incy] + fp16_dot_with_fp16_arith(x, a + lda * i, m);
      }
    });
  }
}
#endif // !defined(__aarch64__) || defined( __ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
// Horizontal sum of all lanes of a single fp32 vector register.
float reduce(vec::Vectorized<float> x) {
#if !defined(__aarch64__) || defined(CPU_CAPABILITY_SVE)
  return vec::vec_reduce_all<float>(std::plus<vec::Vectorized<float>>(), x);
#else
  // Plain NEON: one across-vector add.
  return vaddvq_f32(x);
#endif
}
// This reduce overload and fp16_dot_with_fp32_arith (further below) are
// adapted from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.
//
// Horizontally sums all kF32RegistersPerIteration fp32 accumulators in x
// into one float. x is used as scratch space and is clobbered.
float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
  // Tree reduction: each unrolled step halves the number of live registers.
  int live = kF32RegistersPerIteration;
  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}([&](auto /*step*/) {
    live /= 2;
    for (int r = 0; r < live; ++r) {
      x[r] = x[r] + x[live + r];
    }
  });
  return reduce(x[0]);
}
// We would have to write a separate SVE-specific path to use SVE
// BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
// working.
//
// COMPILER_SUPPORTS_BF16_TARGET is 1 on non-SVE AArch64 builds whose
// compiler accepts the armv8.2-a+bf16 target attribute, and 0 otherwise:
//   * clang newer than 15: https://godbolt.org/z/z8P4Yncra
//   * GCC 10 or newer: https://gcc.gnu.org/gcc-10/changes.html
//     and https://godbolt.org/z/cdGG7vn8o
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&            \
    ((defined(__clang__) && __clang_major__ > 15) ||                    \
     (!defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10))
#define COMPILER_SUPPORTS_BF16_TARGET 1
#else
#define COMPILER_SUPPORTS_BF16_TARGET 0
#endif // BF16 target-attribute support detection
#if COMPILER_SUPPORTS_BF16_TARGET
#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
// BFDOT step of the unrolled main loop: loads the
// vec::Vectorized<BFloat16>::size() elements of vec1 and vec2 starting at
// registerPairIndex * vec::Vectorized<BFloat16>::size() and accumulates
// them with vbfdotq_f32 into sum[registerPairIndex]. (Only that one
// accumulator is written, whereas the no_bfdot variant writes
// sum[2 * k] and sum[2 * k + 1].)
TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
    int registerPairIndex) {
  // NOTE[Intrinsics in bfdot variant]: We can't use
  // vec::Vectorized<BFloat16>::loadu here because linux-aarch64 GCC
  // inexplicably can't convert Vectorized<BFloat16> to
  // bfloat16x8_t. I suspect a bug or incomplete
  // __attribute__((target)) implementation. Intrinsics should be fine
  // because we're using vbfdotq_f32 below anyway.
  const auto temp_vec1 = vld1q_bf16(
      reinterpret_cast<const bfloat16_t*>(
          &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
  const auto temp_vec2 = vld1q_bf16(
      reinterpret_cast<const bfloat16_t*>(
          &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
  sum[registerPairIndex] =
      vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
}
// Vectorized-tail BFDOT step: folds the chunk of vec1/vec2 starting at
// element idx into *tail_sum. idx is int64_t (callers iterate an int64_t
// length), so offsets beyond INT_MAX are not truncated.
TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
    const at::BFloat16* vec1,
    const at::BFloat16* vec2,
    vec::Vectorized<float>* tail_sum,
    int64_t idx) {
  // See NOTE[Intrinsics in bfdot variant] above.
  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
  *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2);
}
#else
// Without BF16 target support the attribute expands to nothing, so code
// annotated with it still compiles.
#define TARGET_ARM_BF16_ATTRIBUTE
#endif // COMPILER_SUPPORTS_BF16_TARGET
namespace {
// Widening FMA on a Half register pair, accumulating in fp32.
// Returns (acc_low + a_low_half * b_low_half, acc_high + a_high_half * b_high_half)
std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
    const vec::Vectorized<c10::Half>& a,
    const vec::Vectorized<c10::Half>& b,
    const vec::Vectorized<float>& acc_low,
    const vec::Vectorized<float>& acc_high) {
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
  // FMLAL / FMLAL2: multiply fp16 lanes and accumulate into fp32 directly.
  return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b));
#else
  // Portable path: widen both inputs to fp32, then use fp32 FMAs.
  const auto [a_float_low, a_float_high] = convert_half_float(a);
  const auto [b_float_low, b_float_high] = convert_half_float(b);
  return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high));
#endif
}

// BFloat16 counterpart of the pair overload above; always widens to fp32.
[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
    const vec::Vectorized<c10::BFloat16>& a,
    const vec::Vectorized<c10::BFloat16>& b,
    const vec::Vectorized<float>& acc_low,
    const vec::Vectorized<float>& acc_high) {
  const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
  const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
  return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high));
}

// Return a + b_low * c_low + b_high * c_high
vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
  // The guard mirrors the pair overload above, which already keeps
  // CPU_CAPABILITY_SVE builds off the FMLAL intrinsics. Both overloads
  // therefore take the same (portable) path under SVE.
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
  // NOTE: this instruction is an optional instruction in ARM v8.2 and
  // v8.3, but mandatory in v8.4 per
  // https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en
  // I'm not certain that I have the right feature test macro.
  vec::Vectorized<float> first = vfmlalq_low_f16(a, b, c);
  return vfmlalq_high_f16(first, b, c);
#else
  const auto [b_float_low, b_float_high] = convert_half_float(b);
  const auto [c_float_low, c_float_high] = convert_half_float(c);
  const auto first = vec::fmadd(b_float_low, c_float_low, a);
  return vec::fmadd(b_float_high, c_float_high, first);
#endif
}

// Return acc + a_low * b_low + a_high * b_high, computed in fp32.
[[maybe_unused]] vec::Vectorized<float> fmadd(
    const vec::Vectorized<float>& acc,
    const vec::Vectorized<c10::BFloat16>& a,
    const vec::Vectorized<c10::BFloat16>& b) {
  const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
  const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
  return fmadd(a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
}
} // namespace
// One "register pair" step of the portable unrolled main loop: loads
// vec::Vectorized<T>::size() elements of each input starting at
// registerPairIndex * vec::Vectorized<T>::size(). Their products are
// widened to fp32 and accumulated into sum[2 * registerPairIndex] (low
// half) and sum[2 * registerPairIndex + 1] (high half).
template <typename T>
C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const T* vec1,
    const T* vec2,
    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
    int registerPairIndex) {
  static_assert(std::is_same_v<T, Half> || std::is_same_v<T, BFloat16>);
  const auto elem_offset = registerPairIndex * vec::Vectorized<T>::size();
  const auto lhs = vec::Vectorized<T>::loadu(vec1 + elem_offset);
  const auto rhs = vec::Vectorized<T>::loadu(vec2 + elem_offset);
  auto& acc_lo = sum[2 * registerPairIndex];
  auto& acc_hi = sum[2 * registerPairIndex + 1];
  const auto [new_lo, new_hi] = fmadd(lhs, rhs, acc_lo, acc_hi);
  acc_lo = new_lo;
  acc_hi = new_hi;
}
// Portable vectorized-tail step: accumulates the vec::Vectorized<T>-sized
// chunk of vec1 * vec2 starting at element idx into *tail_sum in fp32
// (via the fmadd(acc, a, b) overloads). idx is int64_t because callers
// iterate an int64_t length; an int would truncate offsets past INT_MAX.
template <typename T>
C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
    const T* vec1,
    const T* vec2,
    vec::Vectorized<float>* tail_sum,
    int64_t idx) {
  const auto temp_vec1 = vec::Vectorized<T>::loadu(&vec1[idx]);
  const auto temp_vec2 = vec::Vectorized<T>::loadu(&vec2[idx]);
  *tail_sum = fmadd(*tail_sum, temp_vec1, temp_vec2);
}
// Unrolled main loop of the portable fp32-accumulating dot product. It walks
// vec1/vec2 in chunks of kF32ElementsPerIteration elements. Each chunk goes
// through kF32RegisterPairsPerIteration calls of
// dot_with_fp32_arith_main_inner_loop_no_bfdot. The function returns the
// horizontal sum of all accumulators. The trailing
// len % kF32ElementsPerIteration elements are NOT processed here; the
// caller handles them.
template <typename T>
C10_ALWAYS_INLINE auto
dot_with_fp32_arith_main_loop_no_bfdot(
    const T* vec1,
    const T* vec2,
    int64_t len) {
  vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
  // Round len down to a multiple of kF32ElementsPerIteration (power of two).
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  // int64_t index: len is int64_t, and an int counter would overflow
  // (undefined behavior) once len exceeds INT_MAX.
  for (int64_t j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
      dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
    });
  }
  return reduce(sum);
}
#if COMPILER_SUPPORTS_BF16_TARGET
// Compile-time unroller that calls f(0), f(1), ..., f(n - 1) in order,
// passing plain ints. Unlike c10::ForcedUnroll, its operator() carries
// TARGET_ARM_BF16_ATTRIBUTE. Per NOTE [GCC code duplication], GCC will not
// inline non-bf16 functions into bf16-targeted ones, so this presumably
// exists to keep the unrolled loop inlinable.
template <int n>
struct ForcedUnrollTargetBFloat16 {
  template <typename Func>
  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
    ForcedUnrollTargetBFloat16<n - 1>{}(f);
    f(n - 1);
  }
};

template <>
struct ForcedUnrollTargetBFloat16<1> {
  template <typename Func>
  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
    f(0);
  }
};

// BFDOT version of the unrolled main loop. It walks the inputs in
// kF32ElementsPerIteration-sized chunks and returns the horizontal sum of
// the accumulators. The trailing len % kF32ElementsPerIteration elements
// are left to the caller.
C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto
dot_with_fp32_arith_main_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
  vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  // int64_t index: an int counter would overflow once len exceeds INT_MAX.
  for (int64_t j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
        C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
      dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
    });
  }
  return reduce(sum);
}
#endif // COMPILER_SUPPORTS_BF16_TARGET
static_assert(
(vec::Vectorized<Half>::size() & (vec::Vectorized<Half>::size() - 1)) == 0,
"Below code expects power-of-2 vector register size!");
// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not
// allow inlining a non-bf16-specific function into a bf16-specific
// function. We can work around this by duplicating the code into the
// bfdot and non-bfdot callsites. The code is in this macro to avoid
// actual copy/paste.
//
// The macro expands to the end of a dot-product function body. It expects
// the locals `vec1`, `vec2`, `len` (int64_t) and `reduced_sum` (the float
// result of the unrolled main loop) to be in scope. It adds the elements
// the main loop skipped and returns reduced_sum. bfdot_suffix (`_bfdot` or
// `_no_bfdot`) selects the vectorized tail helper.
// NOTE(review): the vectorized tail strides by vec::Vectorized<Half>::size()
// for BFloat16 inputs too, which presumably assumes that
// Vectorized<BFloat16>::size() == Vectorized<Half>::size(). Confirm.
#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
  /* First-tier tail fixup: make sure we handle workloads that can */ \
  /* benefit from vectorization, but don't fit into our fully unrolled */ \
  /* loop above. */ \
  vec::Vectorized<float> tail_sum(0); \
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
  const auto len_aligned_vec = len & ~(vec::Vectorized<Half>::size() - 1); \
  /* int64_t indices: len may exceed INT_MAX. */ \
  for (int64_t j = len_aligned; j < len_aligned_vec; j += vec::Vectorized<Half>::size()) { \
    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
  } \
  reduced_sum += reduce(tail_sum); \
  \
  /* Second-tier tail fixup: handle all workloads. */ \
  for (int64_t j = len_aligned_vec; j < len; ++j) { \
    /* Attempting to use Half here caused multiple test failures; */ \
    /* using float to unbreak. (Suspect we need a scalar FMA.) */ \
    float x1 = vec1[j]; \
    float x2 = vec2[j]; \
    reduced_sum += x1 * x2; \
  } \
  return reduced_sum
#if COMPILER_SUPPORTS_BF16_TARGET
// BFloat16 dot product with fp32 accumulation. The unrolled main loop and
// the vectorized tail use BFDOT; the scalar tail is accumulated in float.
// Only reached when cpuinfo_has_arm_bf16() is true (see
// bf16_dot_with_fp32_arith).
// Everything after the main loop comes from
// DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY, which relies on the local
// names reduced_sum, vec1, vec2 and len and ends with the return statement.
TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
}
#endif // COMPILER_SUPPORTS_BF16_TARGET
// Portable (no BFDOT) dot product of two Half or BFloat16 vectors with fp32
// accumulation. It runs the unrolled main loop, then the shared tail handling
// from DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY, which relies on the
// local names reduced_sum, vec1, vec2 and len and ends with the return statement.
template <typename T>
C10_ALWAYS_INLINE float
dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len);
DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot);
}
// The tail macro is only meant for the dot-product definitions above.
#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY
float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len) {
return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
}
// For each i in [0, n): y[i * incy] = beta * y[i * incy] + dot(x, row_i),
// where row_i is the m Half values starting at a + lda * i. Each dot product
// is accumulated in fp32 (fp16_dot_with_fp32_arith). Rows are split across
// threads with parallel_for. beta == 0 never reads y.
//
// Offsets are formed in int64_t: lda * i and i * incy can exceed INT_MAX
// for large matrices even though every factor fits in an int.
void fp16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const Half* a, const int lda, const Half *x, const float beta, Half* y, int incy) {
  if (beta == 0.0f) {
    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        y[i * incy] = fp16_dot_with_fp32_arith(x, a + lda * i, m);
      }
    });
  } else if (beta == 1.0f) {
    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        // We need to accumulate in fp32; y[i * incy] += ... gets wrong results.
        y[i * incy] = static_cast<float>(y[i * incy]) + fp16_dot_with_fp32_arith(x, a + lda * i, m);
      }
    });
  } else {
    parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        y[i * incy] = beta * y[i * incy] + fp16_dot_with_fp32_arith(x, a + lda * i, m);
      }
    });
  }
}
// Transposed fp16 GEMV kernel registered for fp16_gemv_trans_stub.
// For each i in [0, n) it sets y[i * incy] = beta * y[i * incy] + dot(x, row i of a).
// Only alpha == 1 and incx == 1 are supported; this is checked in debug builds only.
// When the fp16-arithmetic kernel is compiled in (non-AArch64 targets, or
// AArch64 with __ARM_FEATURE_FP16_SCALAR_ARITHMETIC) and
// allowFP16ReductionCPU() is set, that kernel is used. Otherwise the dot
// products are accumulated in fp32.
void fp16_gemv_trans(
    const int m,
    const int n,
    const float alpha,
    const Half* a,
    const int lda,
    const Half* x,
    const int incx,
    const float beta,
    Half* y,
    const int incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0);
#if !defined(__aarch64__) || defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
  const bool use_fp16_accumulation = at::globalContext().allowFP16ReductionCPU();
  if (use_fp16_accumulation) {
    fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, beta, y, incy);
    return;
  }
#endif
  fp16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, beta, y, incy);
}
// Dot product of two BFloat16 vectors of length len, accumulated in fp32.
// When BFDOT support was compiled in and cpuinfo reports BF16 support at
// runtime, the BFDOT kernel is used. Otherwise the portable kernel is used.
float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
#if COMPILER_SUPPORTS_BF16_TARGET
  if (cpuinfo_has_arm_bf16()) {
    return dot_with_fp32_arith_bfdot(vec1, vec2, len);
  }
#endif // COMPILER_SUPPORTS_BF16_TARGET
  return dot_with_fp32_arith_no_bfdot(vec1, vec2, len);
}
// For each i in [0, n): y[i * incy] = dot(x, row_i), where row_i is the m
// BFloat16 values starting at a + lda * i. Each dot product is accumulated
// in fp32 (bf16_dot_with_fp32_arith). This is the alpha == 1, beta == 0 case.
// Rows are split across threads with parallel_for.
//
// Offsets are formed in int64_t: lda * i and i * incy can exceed INT_MAX
// for large matrices even though every factor fits in an int.
void bf16_gemv_trans_fp32_arith_by_dot_products(const int m, const int n, const at::BFloat16* a, const int lda, const at::BFloat16 *x, at::BFloat16* y, int incy) {
  parallel_for(0, n, 1, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      y[i * incy] = bf16_dot_with_fp32_arith(x, a + lda * i, m);
    }
  });
}
// Transposed bf16 GEMV kernel registered for bf16_gemv_trans_stub.
// Only alpha == 1, beta == 0 and incx == 1 are supported; this is checked
// in debug builds only. Under those constraints, y[i * incy] becomes the
// fp32-accumulated dot product of x with row i of a.
void bf16_gemv_trans(
    const int m,
    const int n,
    const at::BFloat16 alpha,
    const at::BFloat16* a,
    const int lda,
    const at::BFloat16* x,
    const int incx,
    const at::BFloat16 beta,
    at::BFloat16* y,
    const int incy) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
  bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
}
#endif // !defined(C10_MOBILE)
} // namespace CPU_CAPABILITY
#if !defined(C10_MOBILE)
// Register this CPU capability's kernels with the dispatch stubs. The
// kernels live in the CPU_CAPABILITY inline namespace, so these names refer
// to the build of this translation unit. The stubs are presumably declared
// in ReducedPrecisionFloatGemvFastPathKernel.h; confirm there. Mobile builds
// compile none of the kernels above, so nothing is registered for them.
REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
#endif //!defined(C10_MOBILE)
} // namespace at::native