-
Notifications
You must be signed in to change notification settings - Fork 526
/
Copy pathop_sdpa.cpp
627 lines (564 loc) · 19.3 KB
/
op_sdpa.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
#include <executorch/extension/llm/custom_ops/op_sdpa_impl.h>
#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#ifdef ET_USE_THREADPOOL
#include <executorch/extension/threadpool/threadpool.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>
#endif
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
namespace torch {
namespace executor {
namespace native {
namespace {
/**
 * Validates the query/key/value (and optional attention mask) arguments of
 * the flash attention kernels in this file.
 *
 * Requirements checked, in order:
 *  - query, key and value are 4D and share the same head size (dim 3);
 *  - query is Float or Char (int8), and key/value have query's dtype;
 *  - attn_mask, if present, is 2D and has query's dtype;
 *  - every tensor is in contiguous dim order.
 *
 * @returns true if all checks pass, false at the first failing check.
 */
bool validate_flash_attention_args(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const optional<Tensor>& attn_mask) {
  ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor");
  ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor");
  ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");

  // Sizes
  ET_CHECK_OR_RETURN_FALSE(
      (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
      "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");

  // Dtypes. Char (int8) is accepted for the quantized path.
  ET_CHECK_OR_RETURN_FALSE(
      (query.scalar_type() == ScalarType::Float) ||
          (query.scalar_type() == ScalarType::Char),
      "Query must be Float or Char type");
  ET_CHECK_OR_RETURN_FALSE(
      (query.scalar_type() == key.scalar_type()) &&
          (query.scalar_type() == value.scalar_type()),
      "Key and Value must have the same data type as Query");

  // Optional mask
  ET_CHECK_OR_RETURN_FALSE(
      !attn_mask.has_value() || attn_mask.value().dim() == 2,
      "Attention mask must be a 2D tensor");
  ET_CHECK_OR_RETURN_FALSE(
      !attn_mask.has_value() ||
          attn_mask.value().scalar_type() == query.scalar_type(),
      "Attention mask must have the same data type as Query");

  // Layout: the kernels assume contiguous dim order for every input.
  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(query.dim_order().data(), query.dim()),
      "query must be in contiguous dim order");
  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(key.dim_order().data(), key.dim()),
      "key must be in contiguous dim order");
  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
      "value must be in contiguous dim order");
  if (attn_mask.has_value()) {
    ET_CHECK_OR_RETURN_FALSE(
        is_contiguous_dim_order(
            attn_mask.value().dim_order().data(), attn_mask.value().dim()),
        "attn_mask must be in contiguous dim order");
  }
  return true;
}
/**
 * Validates a quantized tensor `t` against its quantization parameters.
 *
 * Requirements:
 *  - `t`, `t_scales` and `t_zero_points` have the same number of dims;
 *  - `t` is Char (int8), `t_scales` is Float, `t_zero_points` is Char;
 *  - all three agree in size on every dimension except the last.
 *
 * @returns true if all checks pass, false at the first failing check.
 */
bool validate_cache_quant_params_args(
    const Tensor& t,
    const Tensor& t_zero_points,
    const Tensor& t_scales) {
  ET_CHECK_OR_RETURN_FALSE(
      t.dim() == t_scales.dim(),
      "Quantized tensor and scales must have the same number of dimensions");
  ET_CHECK_OR_RETURN_FALSE(
      t.dim() == t_zero_points.dim(),
      "Quantized tensor and zero points must have the same number of dimensions");
  ET_CHECK_OR_RETURN_FALSE(
      (t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");
  ET_CHECK_OR_RETURN_FALSE(
      (t_scales.scalar_type() == ScalarType::Float),
      "Scales tensor must be of float type");
  ET_CHECK_OR_RETURN_FALSE(
      (t_zero_points.scalar_type() == ScalarType::Char),
      "Zero points tensor must be of int8_t type");

  // Sizes: every dim but the last must match.
  for (int64_t i = 0; i < t.dim() - 1; i++) {
    ET_CHECK_OR_RETURN_FALSE(
        (t.size(i) == t_scales.size(i)),
        "Quantized tensor and scales have different shape "
        "at dim: %" PRId64 ", t: %zd, t_scales: %zd",
        i,
        t.size(i),
        t_scales.size(i));
    ET_CHECK_OR_RETURN_FALSE(
        (t.size(i) == t_zero_points.size(i)),
        "Quantized tensor and zero points have different shape "
        "at dim: %" PRId64 ", t: %zd, t_zero_points: %zd",
        i,
        t.size(i),
        t_zero_points.size(i));
  }
  return true;
}
/**
 * Validates the key/value caches and the write window
 * [start_pos, start_pos + seq_length) along the cache sequence dim (dim 1).
 *
 * Requirements:
 *  - both caches are 4D and in contiguous dim order;
 *  - 0 <= start_pos < cache.size(1) for both caches;
 *  - start_pos + seq_length <= cache.size(1) for both caches.
 *
 * @returns true if all checks pass, false at the first failing check.
 */
bool validate_cache_params(
    const Tensor& k_cache,
    const Tensor& v_cache,
    int64_t start_pos,
    int64_t seq_length) {
  ET_CHECK_OR_RETURN_FALSE(k_cache.dim() == 4, "kcache must be a 4D tensor");
  ET_CHECK_OR_RETURN_FALSE(v_cache.dim() == 4, "v_cache must be a 4D tensor");

  // A negative start_pos would make update_cache write before the cache.
  ET_CHECK_OR_RETURN_FALSE(
      start_pos >= 0,
      "start_pos must be non-negative, got %" PRId64,
      start_pos);
  ET_CHECK_OR_RETURN_FALSE(
      start_pos < k_cache.size(1),
      "start_pos must be less than key cache at dim 1");
  ET_CHECK_OR_RETURN_FALSE(
      start_pos < v_cache.size(1),
      "start_pos must be less than value cache at dim 1");
  ET_CHECK_OR_RETURN_FALSE(
      (start_pos + seq_length) <= k_cache.size(1),
      "start_pos + seq_length must not exceed max seq length supported by key cache. "
      "start pos: %" PRId64 ", seq_length: %" PRId64 ". "
      "key cache size: %zd",
      start_pos,
      seq_length,
      k_cache.size(1));
  ET_CHECK_OR_RETURN_FALSE(
      (start_pos + seq_length) <= v_cache.size(1),
      "start_pos + seq_length must not exceed max seq length supported by value cache. "
      "start pos: %" PRId64 ", seq_length: %" PRId64 ". "
      "value cache size: %zd",
      start_pos,
      seq_length,
      v_cache.size(1));

  // Make sure they are in contiguous dim order
  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
      "key cache must be in contiguous dim order");
  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
      "value cache must be in contiguous dim order");
  return true;
}
// TODO: seq_length is not yet used for copy
void update_cache(
const Tensor& projected_value,
const Tensor& cache,
int64_t start_pos,
int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
// 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
// 2) projected_value shape should be [bs, seq_len, num heads, head dim]
// 3) We're updating the cache with projected_value, at position start_pos
ET_CHECK_MSG(
projected_value.size(0) == cache.size(0),
"projected_value batch size should be equal to the cache batch size.");
ET_CHECK_MSG(
projected_value.size(2) == cache.size(2),
"projected_value number of heads should be equal to the cache number of heads.");
ET_CHECK_MSG(
projected_value.size(3) == cache.size(3),
"projected_value embedding dimension should be equal to the cache embedding dimension.");
ET_CHECK_MSG(
projected_value.element_size() == cache.element_size(),
"projected_value data type size should be equal to the cache data type size.");
ET_CHECK_MSG(
is_contiguous_dim_order(
projected_value.dim_order().data(), projected_value.dim()),
"projected value must be in contiguous dim order");
const void* projected_value_data = projected_value.const_data_ptr();
void* cache_data = cache.mutable_data_ptr();
ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
ET_CHECK_MSG(cache_data, "cache data is null");
auto cache_strides = cache.strides();
::executorch::aten::StridesType cache_batch_dim_stride = cache_strides[0];
::executorch::aten::StridesType cache_seq_dim_stride = cache_strides[1];
auto value_strides = projected_value.strides();
::executorch::aten::StridesType value_batch_dim_stride = value_strides[0];
::executorch::aten::SizesType num_bytes_to_copy =
(projected_value.numel() / projected_value.size(0)) *
projected_value.element_size();
for (int64_t batch_line = 0; batch_line < projected_value.size(0);
++batch_line) {
::executorch::aten::SizesType cache_pos_offset =
(batch_line * cache_batch_dim_stride +
start_pos * cache_seq_dim_stride) *
cache.element_size();
::executorch::aten::SizesType value_pos_offset =
(batch_line * value_batch_dim_stride) * cache.element_size();
std::memcpy(
(uint8_t*)cache_data + cache_pos_offset,
(uint8_t*)projected_value_data + value_pos_offset,
num_bytes_to_copy);
}
}
} // anonymous namespace
/**
 * Out variant of flash attention over query/key/value.
 *
 * Validates the inputs, resizes `output` to the query's shape and dispatches
 * to sdpa::impl::cpu_flash_attention. The query/key-value block sizes are
 * picked from the query sequence length, which is read from dim 2.
 * No quantization parameters are passed (all nullopt).
 *
 * @returns `output`.
 */
Tensor& flash_attention_kernel_out(
    RuntimeContext& ctx,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  ET_KERNEL_CHECK(
      ctx,
      validate_flash_attention_args(query, key, value, attn_mask),
      InvalidArgument,
      output);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(output, query.sizes()) == Error::Ok,
      InvalidArgument,
      output);

  const auto query_seq_len = query.size(2);

  ET_SWITCH_FLOAT_TYPES(
      query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
        // Block-size selection is a heuristic and may need re-tuning for
        // other CPUs (e.g. ARM); shortest sequences use the smallest blocks.
        if (query_seq_len < 192) {
          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
              output,
              query,
              key,
              value,
              dropout_p,
              is_causal,
              attn_mask,
              scale,
              nullopt,
              nullopt,
              nullopt,
              nullopt,
              nullopt,
              nullopt);
        } else if (query_seq_len < 768) {
          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
              output,
              query,
              key,
              value,
              dropout_p,
              is_causal,
              attn_mask,
              scale,
              nullopt,
              nullopt,
              nullopt,
              nullopt,
              nullopt,
              nullopt);
        } else {
          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
              output,
              query,
              key,
              value,
              dropout_p,
              is_causal,
              attn_mask,
              scale,
              nullopt,
              nullopt,
              nullopt,
              nullopt,
              nullopt,
              nullopt);
        }
      });
  return output;
}
/**
 * Shared implementation behind custom_sdpa_out and custom_quantized_sdpa_out.
 *
 * Validates the arguments, resizes `output` to q's shape and dispatches to
 * sdpa::impl::cpu_flash_attention. Block sizes are chosen from the query
 * sequence length.
 *
 * @param start_pos Sequence position of q's first token. start_pos + seq_len
 *     is passed to the kernel as num_keys_for_causal_attention.
 * @param attn_mask Optional 2D mask; must not be combined with is_causal.
 * @param q_zero_points, q_scales, k_zero_points, k_scales, v_zero_points,
 *     v_scales Quantization parameters. All six are required when q is Char
 *     (int8), and each must validate against its tensor.
 * @param is_seq_at_dim_2 When true the sequence dimension of q/k/v is dim 2,
 *     otherwise dim 1.
 * @returns `output`.
 */
Tensor& custom_sdpa_out_impl(
    RuntimeContext& ctx,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output,
    const optional<Tensor>& q_zero_points = nullopt,
    const optional<Tensor>& q_scales = nullopt,
    const optional<Tensor>& k_zero_points = nullopt,
    const optional<Tensor>& k_scales = nullopt,
    const optional<Tensor>& v_zero_points = nullopt,
    const optional<Tensor>& v_scales = nullopt,
    bool is_seq_at_dim_2 = false) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      !attn_mask.has_value() || !is_causal,
      InvalidArgument,
      output,
      "attn_mask and is_causal cannot be set at the same time");
  ET_KERNEL_CHECK_MSG(
      ctx,
      validate_flash_attention_args(q, k, v, attn_mask),
      InvalidArgument,
      output,
      "Invalid arguments");

  // The sequence length must come from whichever dimension holds the
  // sequence, for float and quantized inputs alike. It drives both the causal
  // key count and the block-size selection below.
  const SeqDim seq_dim = is_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
  const int64_t seq_len = is_seq_at_dim_2 ? q.size(2) : q.size(1);

  if (q.scalar_type() == ScalarType::Char) {
    // All six quantization params are dereferenced below, so every one of
    // them must be present.
    ET_KERNEL_CHECK_MSG(
        ctx,
        q_scales.has_value() && q_zero_points.has_value() &&
            k_scales.has_value() && k_zero_points.has_value() &&
            v_scales.has_value() && v_zero_points.has_value(),
        InvalidArgument,
        output,
        "If q is quantized, k and v must be quantized as well");
    ET_KERNEL_CHECK_MSG(
        ctx,
        validate_cache_quant_params_args(
            q, q_zero_points.value(), q_scales.value()),
        InvalidArgument,
        output,
        "Invalid arguments for quantized query");
    ET_KERNEL_CHECK_MSG(
        ctx,
        validate_cache_quant_params_args(
            k, k_zero_points.value(), k_scales.value()),
        InvalidArgument,
        output,
        "Invalid arguments for quantized key");
    ET_KERNEL_CHECK_MSG(
        ctx,
        validate_cache_quant_params_args(
            v, v_zero_points.value(), v_scales.value()),
        InvalidArgument,
        output,
        "Invalid arguments for quantized value");
  }

  const int64_t num_keys_for_causal_attention = start_pos + seq_len;

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(output, q.sizes()) == Error::Ok,
      InvalidArgument,
      output);

  // TODO(task): replace the template param selection logic with whatever
  // appropriately makes more sense for the target CPU.
  ET_SWITCH_FLOAT_TYPES(
      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
        // TODO we need to re-evaluate this for ARM CPUs.
        // And there can be many, so instead of templatizing
        // we might consider another approach.
        if (seq_len >= 768) {
          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
              output,
              q,
              k,
              v,
              dropout_p,
              is_causal,
              attn_mask,
              scale,
              q_zero_points, // q_zero_points
              q_scales, // q_scales
              k_zero_points, // k_zero_points
              k_scales, // k_scales
              v_zero_points, // v_zero_points
              v_scales, // v_scales
              seq_dim, /* seq_dim */
              start_pos,
              num_keys_for_causal_attention);
        } else if (seq_len >= 192) {
          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
              output,
              q,
              k,
              v,
              dropout_p,
              is_causal,
              attn_mask,
              scale,
              q_zero_points, // q_zero_points
              q_scales, // q_scales
              k_zero_points, // k_zero_points
              k_scales, // k_scales
              v_zero_points, // v_zero_points
              v_scales, // v_scales
              seq_dim, /* seq_dim */
              start_pos,
              num_keys_for_causal_attention);
        } else {
          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
              output,
              q,
              k,
              v,
              dropout_p,
              is_causal,
              attn_mask,
              scale,
              q_zero_points, // q_zero_points
              q_scales, // q_scales
              k_zero_points, // k_zero_points
              k_scales, // k_scales
              v_zero_points, // v_zero_points
              v_scales, // v_scales
              seq_dim, /* seq_dim */
              start_pos,
              num_keys_for_causal_attention);
        }
      });
  return output;
}
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
/**
 * Out variant of SDPA that accepts per-tensor quantization parameters and a
 * configurable sequence dimension.
 *
 * This is an adapter only. The shared implementation takes `output` before
 * the optional quantization parameters, so the arguments are reordered and
 * forwarded unchanged.
 *
 * @returns `output`, as returned by custom_sdpa_out_impl.
 */
Tensor& custom_quantized_sdpa_out(
    RuntimeContext& ctx,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    const optional<Tensor>& q_zero_points,
    const optional<Tensor>& q_scales,
    const optional<Tensor>& k_zero_points,
    const optional<Tensor>& k_scales,
    const optional<Tensor>& v_zero_points,
    const optional<Tensor>& v_scales,
    const bool is_seq_at_dim_2,
    Tensor& output) {
  Tensor& result = custom_sdpa_out_impl(
      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output,
      /* q */ q_zero_points, q_scales,
      /* k */ k_zero_points, k_scales,
      /* v */ v_zero_points, v_scales,
      is_seq_at_dim_2);
  return result;
}
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA
/*
Scaled dot product attention of q against k and v.
sdpa_with_kv_cache_out calls this with the key/value caches as k and v.
The sequence dimension is dim 1 for all inputs.

Input params
@param[in] q Query.
  Format [batch size, seq_len, num heads, head dim]
@param[in] k Keys (e.g. the key cache).
  Format [batch size, max_seq_len, num heads, head dim]
@param[in] v Values (e.g. the value cache).
  Format [batch size, max_seq_len, num heads, head dim]
@param[in] start_pos: sequence position of q's first token. start_pos + seq_len
  is passed to the kernel as the number of keys for causal attention.
@param[in] attn_mask Optional 2D mask. Cannot be combined with is_causal.
@param[in] dropout_p, is_causal, scale: forwarded to the flash attention
  implementation.
@param[out] output Attention result, resized to q's shape.
*/
Tensor& custom_sdpa_out(
    RuntimeContext& ctx,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  return custom_sdpa_out_impl(
      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
}
/*
Writes k_projected / v_projected into key_cache / value_cache at start_pos,
then runs custom_sdpa_out of q_projected against the updated caches.

Input params
@param[in] q_projected Projected query.
  Format [batch size, seq_len, num heads, head dim]
@param[in] k_projected Projected key, written into key_cache.
  Format [batch size, seq_len, num heads, head dim]
@param[in] v_projected Projected value, written into value_cache.
  Format [batch size, seq_len, num heads, head dim]
@param[in,out] key_cache Cache of previous k_projected.
  Format [batch size, max_seq_len, num heads, head dim]
@param[in,out] value_cache Cache of previous v_projected.
  Format [batch size, max_seq_len, num heads, head dim]
....
@param[in] start_pos: sequence position at which the new keys/values are written
@param[in] seq_len: Seq length, e.g. the seq_len dim of q_projected.
@param[out] output Attention result, resized to q_projected's shape.
*/
Tensor& sdpa_with_kv_cache_out(
    KernelRuntimeContext& ctx,
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  ET_KERNEL_CHECK(
      ctx,
      validate_cache_params(key_cache, value_cache, start_pos, seq_len),
      InvalidArgument,
      output);
  // Fail the kernel (instead of aborting) on malformed inputs.
  ET_KERNEL_CHECK(ctx, q_projected.dim() == 4, InvalidArgument, output);
  ET_KERNEL_CHECK(
      ctx,
      k_projected.dim() == 4 && v_projected.dim() == 4,
      InvalidArgument,
      output);
  // update_cache copies every row of the projected tensors' seq dim (dim 1),
  // not just seq_len rows, so bound-check against the actual sizes.
  ET_KERNEL_CHECK(
      ctx,
      start_pos + k_projected.size(1) <= key_cache.size(1),
      InvalidArgument,
      output);
  ET_KERNEL_CHECK(
      ctx,
      start_pos + v_projected.size(1) <= value_cache.size(1),
      InvalidArgument,
      output);

  update_cache(k_projected, key_cache, start_pos, seq_len);
  update_cache(v_projected, value_cache, start_pos, seq_len);

  custom_sdpa_out(
      ctx,
      q_projected,
      key_cache,
      value_cache,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
  return output;
}
} // namespace native
} // namespace executor
} // namespace torch
// Register the out-variant kernels above as operators in the `llama`
// library via EXECUTORCH_LIBRARY, keyed by the given operator names.
EXECUTORCH_LIBRARY(
llama,
"sdpa_with_kv_cache.out",
torch::executor::native::sdpa_with_kv_cache_out);
EXECUTORCH_LIBRARY(
llama,
"custom_sdpa.out",
torch::executor::native::custom_sdpa_out);
// The quantized variant is only registered when ENABLE_CUSTOM_QUANTIZED_SDPA
// is defined, matching the guard around custom_quantized_sdpa_out's definition.
#ifdef ENABLE_CUSTOM_QUANTIZED_SDPA
EXECUTORCH_LIBRARY(
llama,
"custom_quantized_sdpa.out",
torch::executor::native::custom_quantized_sdpa_out);
#endif // ENABLE_CUSTOM_QUANTIZED_SDPA