forked from fairy-stockfish/variant-nnue-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
feature_transformer.py
605 lines (500 loc) · 23.6 KB
/
feature_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
import torch
from torch import nn
from torch import autograd
import cupy as cp
import math
def _find_nearest_divisor(value, target):
    """Return the divisor of ``value`` that lies closest to ``target``.

    Every integer in ``1..value`` is tried as a candidate.  When two
    divisors are equally far from ``target`` the smaller one wins, because
    the candidates are produced in ascending order and the sort is stable.

    An ``IndexError`` is raised when ``value`` has no divisor in that range
    (i.e. ``value < 1``).
    """
    factors = [d for d in range(1, value + 1) if value % d == 0]
    ranked = sorted(factors, key=lambda d: abs(target - d))
    return ranked[0]
# Memo of output_size -> threads per block chosen for the forward kernel.
_num_threads_forward_cache = {}


def _get_num_threads_for_forward(output_size):
    """Return the number of threads per block to use for the forward kernel.

    The result is the divisor of ``output_size`` closest to 512, so that the
    outputs can be split evenly between the threads.  Results are memoised
    per ``output_size`` in ``_num_threads_forward_cache``.
    """
    try:
        return _num_threads_forward_cache[output_size]
    except KeyError:
        pass

    preferred_threads = 512
    chosen = _find_nearest_divisor(output_size, preferred_threads)
    _num_threads_forward_cache[output_size] = chosen
    return chosen
# Memo of output_size -> threads per block chosen for the backward kernel.
_num_threads_backward_cache = {}


def _get_num_threads_for_backward(output_size):
    """Return the number of threads per block to use for the backward kernel.

    The result is the divisor of ``output_size`` closest to 512, so that the
    outputs can be split evenly between the threads.  Results are memoised
    per ``output_size`` in ``_num_threads_backward_cache``.
    """
    try:
        return _num_threads_backward_cache[output_size]
    except KeyError:
        pass

    preferred_threads = 512
    chosen = _find_nearest_divisor(output_size, preferred_threads)
    _num_threads_backward_cache[output_size] = chosen
    return chosen
def _kernel_with_threads(kernel, threads):
    """Bind a fixed block size to ``kernel``.

    Returns a callable taking ``grid`` and ``args`` which forwards them to
    ``kernel`` as keyword arguments together with ``block=threads``.  The
    value returned by ``kernel`` is discarded; the callable returns ``None``.
    """
    def launch(grid, args):
        kernel(block=threads, grid=grid, args=args)

    return launch
# Compiled forward launchers keyed by (max_active_features, output_size, num_threads).
_feature_transformer_slice_forward_kernel_cache = {}


def make_feature_transformer_slice_forward_kernel(max_active_features, output_size):
    '''
    Build (or fetch from the cache) the launcher for the forward CUDA kernel
    specialised for the given sizes.

    The sizes are baked into the kernel source as literals.  The kernel is
    compiled eagerly and wrapped so that callers only pass `grid` and `args`;
    the block size is the thread count picked by
    _get_num_threads_for_forward, and every thread handles
    output_size // num_threads consecutive outputs.

    @param: max_active_features
        The maximum number of features that are active
        (non-zero) for a single position. This value determines
        the shape of the inputs.
        This value is of type uint32_t.

    @param: output_size
        The number of outputs. Must match the shape of weights
        and biases.
        This value is of type uint32.

    @return
        A callable f(grid, args) that launches the compiled kernel.
    '''
    num_threads = _get_num_threads_for_forward(output_size)
    key = (max_active_features, output_size, num_threads)

    cached = _feature_transformer_slice_forward_kernel_cache.get(key)
    if cached is not None:
        return cached

    # Doubled braces are literal braces of the CUDA code; single-braced
    # names are filled in by str.format below.
    source = r'''
typedef unsigned int uint32_t;
typedef int int32_t;
extern "C" __global__
/*
@assumptions:
The blocks must have dimensionality (BATCH_SIZE,)
The threads must have dimensionality (N,), where
N * output_thread_slice_size == output_size.
@param: feature_indices
A matrix of shape (BATCH_SIZE, max_active_features)
containing indices of active features for each position
in a batch. Feature index of -1 means that the slot is empty
and the weights will not be accumulated for it. Moreover
no further indices from this block will be considered.
The indices form an implicit matrix of shape
(BATCH_SIZE, NUM_INPUTS), where the first dimension index is
inferred from the memory location (BATCH_SIZE), and the
second dimension index is stored in the feature_indices matrix.
The type for feature indices is int32_t.
@param: feature_values
A matrix of shape (BATCH_SIZE, max_active_features)
containing the values (arity) of the corresponding
feature index in feature_indices.
The type for the feature value (arity) is float32.
@param: weight
The weight matrix of shape (NUM_INPUTS, output_size).
Weights must be of type float32.
@param: bias
The bias vector of shape (output_size,).
Bias values must be of type float32.
@param: output
An output matrix of shape (BATCH_SIZE, output_size).
It may not be initialized, bias is always copied
to the output first.
Output values must have type float32.
*/
void feature_transformer_slice_forward(
const int32_t* const feature_indices,
const float* const feature_values,
const float* const weight,
const float* const bias,
float* const output
) {{
__shared__
float shared_output[{output_size}];
const uint32_t block_idx = blockIdx.x;
const uint32_t slice_offset = threadIdx.x * {output_thread_slice_size};
float* const output_slice = output + block_idx * {output_size} + slice_offset;
const float* const bias_slice = bias + slice_offset;
float* shared_output_slice = shared_output + slice_offset;
const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features};
const float* const feature_value_row = feature_values + block_idx * {max_active_features};
#pragma unroll
for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
{{
shared_output_slice[s] = bias_slice[s];
}}
for (uint32_t k = 0; k < {max_active_features}; ++k)
{{
const int32_t feature_index = feature_index_row[k];
const float feature_value = feature_value_row[k];
if (feature_index != -1)
{{
const float* const weight_slice = weight + feature_index * {output_size} + slice_offset;
#pragma unroll
for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
{{
shared_output_slice[s] += weight_slice[s] * feature_value;
}}
}} else break;
}}
#pragma unroll
for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
{{
output_slice[s] = shared_output_slice[s];
}}
}}
'''.format(
        max_active_features=max_active_features,
        output_thread_slice_size=output_size // num_threads,
        output_size=output_size)

    kernel = cp.RawKernel(source, 'feature_transformer_slice_forward')
    kernel.compile()

    launcher = _kernel_with_threads(kernel, (num_threads,))
    _feature_transformer_slice_forward_kernel_cache[key] = launcher
    return launcher
# Compiled backward launchers keyed by (max_active_features, output_size, num_threads).
_feature_transformer_slice_backward_kernel_cache = {}


def make_feature_transformer_slice_backward_kernel(max_active_features, output_size):
    '''
    Build (or fetch from the cache) the launcher for the backward CUDA kernel
    specialised for the given sizes.

    The sizes are baked into the kernel source as literals.  The kernel is
    compiled eagerly and wrapped so that callers only pass `grid` and `args`;
    the block size is the thread count picked by
    _get_num_threads_for_backward, and every thread handles
    output_size // num_threads consecutive outputs.

    @param: max_active_features
        The maximum number of features that are active
        (non-zero) for a single position. This value determines
        the shape of the inputs.
        This value is of type uint32_t.

    @param: output_size
        The number of outputs. Must match the shape of weights
        and biases.
        This value is of type uint32.

    @return
        A callable f(grid, args) that launches the compiled kernel.
    '''
    num_threads = _get_num_threads_for_backward(output_size)
    key = (max_active_features, output_size, num_threads)

    cached = _feature_transformer_slice_backward_kernel_cache.get(key)
    if cached is not None:
        return cached

    # Doubled braces are literal braces of the CUDA code; single-braced
    # names are filled in by str.format below.
    source = r'''
typedef unsigned int uint32_t;
typedef int int32_t;
extern "C" __global__
/*
@assumptions:
The blocks must have dimensionality (BATCH_SIZE,)
The threads must have dimensionality (N,), where
N * output_thread_slice_size == output_size.
@param: feature_indices
A matrix of shape (BATCH_SIZE, max_active_features)
containing indices of active features for each position
in a batch. Feature index of -1 means that the slot is empty
and the weights will not be accumulated for it. Moreover
no further indices from this block will be considered.
The indices form an implicit matrix of shape
(BATCH_SIZE, NUM_INPUTS), where the first dimension index is
inferred from the memory location (BATCH_SIZE), and the
second dimension index is stored in the feature_indices matrix.
The type for feature indices is int32_t.
@param: feature_values
A matrix of shape (BATCH_SIZE, max_active_features)
containing the values (arity) of the corresponding
feature index in feature_indices.
The type for the feature value (arity) is float32.
@param: weight_grad
The weight gradient matrix of shape (NUM_INPUTS, output_size).
The gradient is accumulated, i.e. it must be zero initialized
on the first call.
Weights must be of type float32.
@param: bias_grad
The bias gradient vector of shape (output_size,).
The gradient is accumulated, i.e. it must be zero initialized
on the first call.
Bias values must be of type float32.
@param: output_grad
An output gradient matrix of shape (BATCH_SIZE, output_size).
Output values must have type float32.
*/
void feature_transformer_slice_backward(
const int32_t* const feature_indices,
const float* const feature_values,
float* const weight_grad,
float* const bias_grad,
const float* const output_grad
) {{
__shared__
float shared_output_grad[{output_size}];
const uint32_t block_idx = blockIdx.x;
const uint32_t slice_offset = threadIdx.x * {output_thread_slice_size};
const float* const output_grad_slice = output_grad + block_idx * {output_size} + slice_offset;
float* const bias_grad_slice = bias_grad + slice_offset;
float* shared_output_grad_slice = shared_output_grad + slice_offset;
const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features};
const float* const feature_value_row = feature_values + block_idx * {max_active_features};
#pragma unroll
for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
{{
shared_output_grad_slice[s] = output_grad_slice[s];
}}
#pragma unroll
for (uint32_t s = 0; s < {output_thread_slice_size}; ++s)
{{
const float sog = shared_output_grad_slice[s];
if (sog != 0.0f)
{{
atomicAdd(&bias_grad_slice[s], sog);
}}
}}
for (uint32_t k = 0; k < {max_active_features}; ++k)
{{
const int32_t feature_index = feature_index_row[k];
const float feature_value = feature_value_row[k];
if (feature_index != -1)
{{
float* const weight_grad_slice = weight_grad + feature_index * {output_size} + slice_offset;
#pragma unroll
for (int s = 0; s < {output_thread_slice_size}; ++s)
{{
const float sog = shared_output_grad_slice[s];
if (sog != 0.0f)
{{
atomicAdd(&weight_grad_slice[s], sog * feature_value);
}}
}}
}} else break;
}}
}}
'''.format(
        max_active_features=max_active_features,
        output_thread_slice_size=output_size // num_threads,
        output_size=output_size)

    kernel = cp.RawKernel(source, 'feature_transformer_slice_backward')
    kernel.compile()

    launcher = _kernel_with_threads(kernel, (num_threads,))
    _feature_transformer_slice_backward_kernel_cache[key] = launcher
    return launcher
class FeatureTransformerSliceFunction(autograd.Function):
    """Autograd function evaluating a sparse feature transformer on CUDA.

    ``forward`` launches the custom forward kernel with one block per batch
    row; ``backward`` launches the custom backward kernel, which accumulates
    gradients for ``weight`` and ``bias`` only.  No gradient is produced for
    the feature indices or feature values.
    """

    @staticmethod
    def forward(ctx, feature_indices, feature_values, weight, bias):
        """Compute the layer output for a batch of sparse inputs.

        @param feature_indices: int32 CUDA tensor, shape (batch, max_active_features).
        @param feature_values: float32 CUDA tensor with the same shape.
        @param weight: float32 CUDA tensor, shape (num_inputs, output_size).
        @param bias: float32 CUDA tensor, shape (output_size,).
        @return: float32 tensor of shape (batch, output_size).
        @raise AssertionError: if any of the layout requirements is violated.
        """
        ctx.save_for_backward(feature_indices, feature_values, weight, bias)

        assert len(feature_indices.shape) == 2
        assert len(feature_values.shape) == 2
        assert feature_indices.shape[0] == feature_values.shape[0]
        assert feature_indices.shape[1] == feature_values.shape[1]
        assert feature_indices.dtype == torch.int32
        assert feature_values.dtype == torch.float32

        assert len(weight.shape) == 2
        assert weight.dtype == torch.float32

        assert len(bias.shape) == 1
        assert bias.dtype == torch.float32
        # The kernel reads output_size (= weight.shape[1]) bias entries, so a
        # shorter bias would be read out of bounds on the device.
        assert bias.shape[0] == weight.shape[1], \
            'bias length must equal weight.shape[1]'

        assert feature_indices.is_cuda
        assert feature_values.is_cuda
        assert weight.is_cuda
        assert bias.is_cuda

        assert feature_values.device == feature_indices.device
        assert weight.device == feature_indices.device
        assert bias.device == feature_indices.device

        # The kernel works on raw pointers and assumes dense row-major data.
        assert feature_indices.is_contiguous()
        assert feature_values.is_contiguous()
        assert weight.is_contiguous()
        assert bias.is_contiguous()

        device = feature_indices.device
        batch_size = feature_indices.shape[0]
        max_active_features = feature_indices.shape[1]
        output_size = weight.shape[1]

        # No initialisation needed: the kernel writes every output element.
        output = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True)

        kernel = make_feature_transformer_slice_forward_kernel(max_active_features, output_size)
        kernel(
            grid=(batch_size,),
            args=(
                feature_indices.data_ptr(),
                feature_values.data_ptr(),
                weight.data_ptr(),
                bias.data_ptr(),
                output.data_ptr()
            )
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients with respect to ``weight`` and ``bias``.

        @param grad_output: gradient of the loss w.r.t. the forward output.
        @return: (None, None, weight_grad, bias_grad).
        @raise AssertionError: if a gradient is requested for the feature
            indices or feature values, which is not supported.
        """
        assert not ctx.needs_input_grad[0]
        assert not ctx.needs_input_grad[1]

        grad_output = grad_output.contiguous()

        feature_indices, feature_values, weight, bias = ctx.saved_tensors

        device = feature_indices.device
        batch_size = feature_indices.shape[0]
        max_active_features = feature_indices.shape[1]
        output_size = weight.shape[1]

        # The kernel accumulates with atomicAdd, so both buffers start at zero.
        weight_grad = torch.zeros(weight.shape[0], weight.shape[1], dtype=torch.float32, device=device)
        bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device)

        kernel = make_feature_transformer_slice_backward_kernel(max_active_features, output_size)
        kernel(
            grid=(batch_size,),
            args=(
                feature_indices.data_ptr(),
                feature_values.data_ptr(),
                weight_grad.data_ptr(),
                bias_grad.data_ptr(),
                grad_output.data_ptr()
            )
        )

        return None, None, weight_grad, bias_grad
class DoubleFeatureTransformerSliceFunction(autograd.Function):
    """Autograd function evaluating one feature transformer on two sparse inputs.

    The same ``weight`` and ``bias`` are applied to two independent sets of
    feature indices/values.  ``forward`` launches the forward kernel once per
    set; ``backward`` launches the backward kernel once per set, accumulating
    both contributions into a single pair of weight/bias gradients.  No
    gradient is produced for any of the feature indices or feature values.
    """

    @staticmethod
    def forward(ctx, feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, weight, bias):
        """Compute the layer output for both input sets.

        @param feature_indices_0: int32 CUDA tensor, shape (batch, max_active_features).
        @param feature_values_0: float32 CUDA tensor with the same shape.
        @param feature_indices_1: int32 CUDA tensor with the same shape as
            feature_indices_0.
        @param feature_values_1: float32 CUDA tensor with the same shape.
        @param weight: float32 CUDA tensor, shape (num_inputs, output_size).
        @param bias: float32 CUDA tensor, shape (output_size,).
        @return: two float32 tensors of shape (batch, output_size).
        @raise AssertionError: if any of the layout requirements is violated.
        """
        ctx.save_for_backward(feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, weight, bias)

        assert len(feature_indices_0.shape) == 2
        assert len(feature_values_0.shape) == 2
        assert feature_indices_0.shape[0] == feature_values_0.shape[0]
        assert feature_indices_0.shape[1] == feature_values_0.shape[1]
        assert feature_indices_0.dtype == torch.int32
        assert feature_values_0.dtype == torch.float32

        assert len(feature_indices_1.shape) == 2
        assert len(feature_values_1.shape) == 2
        assert feature_indices_1.shape[0] == feature_values_1.shape[0]
        assert feature_indices_1.shape[1] == feature_values_1.shape[1]
        assert feature_indices_1.dtype == torch.int32
        assert feature_values_1.dtype == torch.float32

        # Both launches use the grid size and the kernel specialisation
        # derived from feature_indices_0, so the second input set must have
        # exactly the same layout.
        assert feature_indices_1.shape == feature_indices_0.shape, \
            'feature_indices_1 must have the same shape as feature_indices_0'

        assert len(weight.shape) == 2
        assert weight.dtype == torch.float32

        assert len(bias.shape) == 1
        assert bias.dtype == torch.float32
        # The kernel reads output_size (= weight.shape[1]) bias entries, so a
        # shorter bias would be read out of bounds on the device.
        assert bias.shape[0] == weight.shape[1], \
            'bias length must equal weight.shape[1]'

        assert feature_indices_0.is_cuda
        assert feature_values_0.is_cuda
        assert feature_indices_1.is_cuda
        assert feature_values_1.is_cuda
        assert weight.is_cuda
        assert bias.is_cuda

        assert feature_values_0.device == feature_indices_0.device
        assert feature_values_1.device == feature_indices_1.device
        assert feature_indices_0.device == feature_indices_1.device
        assert weight.device == feature_indices_0.device
        assert bias.device == feature_indices_0.device

        # The kernel works on raw pointers and assumes dense row-major data.
        assert feature_indices_0.is_contiguous()
        assert feature_values_0.is_contiguous()
        assert feature_indices_1.is_contiguous()
        assert feature_values_1.is_contiguous()
        assert weight.is_contiguous()
        assert bias.is_contiguous()

        device = feature_indices_0.device
        batch_size = feature_indices_0.shape[0]
        max_active_features = feature_indices_0.shape[1]
        output_size = weight.shape[1]

        # No initialisation needed: the kernel writes every output element.
        output0 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True)
        output1 = torch.empty(batch_size, output_size, dtype=torch.float32, device=device, requires_grad=True)

        kernel = make_feature_transformer_slice_forward_kernel(max_active_features, output_size)

        launches = (
            (feature_indices_0, feature_values_0, output0),
            (feature_indices_1, feature_values_1, output1),
        )
        for indices, values, output in launches:
            kernel(
                grid=(batch_size,),
                args=(
                    indices.data_ptr(),
                    values.data_ptr(),
                    weight.data_ptr(),
                    bias.data_ptr(),
                    output.data_ptr()
                )
            )

        return output0, output1

    @staticmethod
    def backward(ctx, grad_output_0, grad_output_1):
        """Compute gradients with respect to ``weight`` and ``bias``.

        @param grad_output_0: gradient of the loss w.r.t. the first output.
        @param grad_output_1: gradient of the loss w.r.t. the second output.
        @return: (None, None, None, None, weight_grad, bias_grad).
        @raise AssertionError: if a gradient is requested for any of the four
            feature index/value inputs, which is not supported.
        """
        # Inputs 0..3 are the index/value tensors of both input sets; all of
        # them get a None gradient below, so none of them may require one.
        assert not any(ctx.needs_input_grad[:4])

        grad_output_0 = grad_output_0.contiguous()
        grad_output_1 = grad_output_1.contiguous()

        feature_indices_0, feature_values_0, feature_indices_1, feature_values_1, weight, bias = ctx.saved_tensors

        device = feature_indices_0.device
        batch_size = feature_indices_0.shape[0]
        max_active_features = feature_indices_0.shape[1]
        output_size = weight.shape[1]

        # The kernel accumulates with atomicAdd, so both buffers start at
        # zero and are shared by the two launches.
        weight_grad = torch.zeros(weight.shape[0], weight.shape[1], dtype=torch.float32, device=device)
        bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device)

        kernel = make_feature_transformer_slice_backward_kernel(max_active_features, output_size)

        launches = (
            (feature_indices_0, feature_values_0, grad_output_0),
            (feature_indices_1, feature_values_1, grad_output_1),
        )
        for indices, values, grad_output in launches:
            kernel(
                grid=(batch_size,),
                args=(
                    indices.data_ptr(),
                    values.data_ptr(),
                    weight_grad.data_ptr(),
                    bias_grad.data_ptr(),
                    grad_output.data_ptr()
                )
            )

        return None, None, None, None, weight_grad, bias_grad
class FeatureTransformerSlice(nn.Module):
    """Module wrapper around FeatureTransformerSliceFunction.

    Owns a ``(num_inputs, num_outputs)`` weight matrix and a
    ``(num_outputs,)`` bias vector, both float32 and initialised uniformly
    in ``[-sigma, sigma)`` with ``sigma = sqrt(1 / num_inputs)``.
    """

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        bound = math.sqrt(1 / num_inputs)
        span = 2 * bound
        # Weight is drawn before bias so a seeded RNG gives the same values.
        weight_init = torch.rand(num_inputs, num_outputs, dtype=torch.float32) * span - bound
        bias_init = torch.rand(num_outputs, dtype=torch.float32) * span - bound
        self.weight = nn.Parameter(weight_init)
        self.bias = nn.Parameter(bias_init)

    def forward(self, feature_indices, feature_values):
        """Apply the transformer to one set of sparse features."""
        return FeatureTransformerSliceFunction.apply(
            feature_indices, feature_values, self.weight, self.bias)
class DoubleFeatureTransformerSlice(nn.Module):
    """Module wrapper around DoubleFeatureTransformerSliceFunction.

    Owns a ``(num_inputs, num_outputs)`` weight matrix and a
    ``(num_outputs,)`` bias vector, both float32 and initialised uniformly
    in ``[-sigma, sigma)`` with ``sigma = sqrt(1 / num_inputs)``.  The same
    parameters are applied to both sets of features passed to ``forward``.
    """

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        bound = math.sqrt(1 / num_inputs)
        span = 2 * bound
        # Weight is drawn before bias so a seeded RNG gives the same values.
        weight_init = torch.rand(num_inputs, num_outputs, dtype=torch.float32) * span - bound
        bias_init = torch.rand(num_outputs, dtype=torch.float32) * span - bound
        self.weight = nn.Parameter(weight_init)
        self.bias = nn.Parameter(bias_init)

    def forward(self, feature_indices_0, feature_values_0, feature_indices_1, feature_values_1):
        """Apply the transformer to two sets of sparse features at once."""
        return DoubleFeatureTransformerSliceFunction.apply(
            feature_indices_0, feature_values_0,
            feature_indices_1, feature_values_1,
            self.weight, self.bias)
if __name__ == '__main__':
    # Self-test and micro-benchmark; both need a CUDA device.
    import time

    def FeatureTransformerSliceFunctionEmulate(feature_indices, feature_values, weight, bias):
        """Dense reference implementation used to validate the CUDA kernels.

        Scatters the sparse (index, value) pairs into a dense
        (batch, num_inputs) matrix and applies ``inputs @ weight + bias``.
        As in the kernels, an index of -1 ends the row: that slot and all
        later slots of the row are ignored.
        """
        batch_size = feature_indices.shape[0]
        num_inputs = weight.shape[0]
        max_active_features = feature_indices.shape[1]
        inputs = torch.zeros(batch_size, num_inputs, dtype=torch.float32, device=weight.device)
        for i in range(batch_size):
            for j in range(max_active_features):
                feature = int(feature_indices[i, j])
                if feature == -1:
                    break
                inputs[i, feature] += feature_values[i, j]

        return torch.mm(inputs, weight) + bias

    def _max_abs_diff(a, b):
        """Largest absolute element-wise difference between two tensors."""
        return torch.max(torch.abs(a.cpu() - b.cpu()))

    def test():
        """Compare the CUDA outputs and gradients against the dense reference."""
        BATCH_SIZE = 16
        INPUT_SIZE = 10
        MAX_ACTIVE_FEATURES = 32
        STRIDE = 128
        MAX_ERROR = 1e-4

        # Seeding twice gives both implementations identical parameters.
        torch.manual_seed(0)
        weight0 = torch.rand(INPUT_SIZE, STRIDE, dtype=torch.float32, requires_grad=True)
        bias0 = torch.rand(STRIDE, dtype=torch.float32, requires_grad=True)
        torch.manual_seed(0)
        weight1 = torch.rand(INPUT_SIZE, STRIDE, dtype=torch.float32, requires_grad=True)
        bias1 = torch.rand(STRIDE, dtype=torch.float32, requires_grad=True)

        indices0 = (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES) * INPUT_SIZE).to(dtype=torch.int32)
        indices1 = (torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES) * INPUT_SIZE).to(dtype=torch.int32)
        values0 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32)
        values1 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32)

        output00 = FeatureTransformerSliceFunctionEmulate(indices0.clone(), values0.clone(), weight0, bias0)
        output01 = FeatureTransformerSliceFunctionEmulate(indices1.clone(), values1.clone(), weight0, bias0)
        output10, output11 = DoubleFeatureTransformerSliceFunction.apply(
            indices0.clone().cuda(), values0.clone().cuda(),
            indices1.clone().cuda(), values1.clone().cuda(),
            weight1.cuda(), bias1.cuda())

        # Compare magnitudes: a signed maximum would let arbitrarily large
        # negative errors pass unnoticed.
        assert _max_abs_diff(output00, output10) < MAX_ERROR
        assert _max_abs_diff(output01, output11) < MAX_ERROR

        (output00 - output01).sum().backward()
        (output10 - output11).sum().backward()

        assert _max_abs_diff(weight0.grad, weight1.grad) < MAX_ERROR
        assert _max_abs_diff(bias0.grad, bias1.grad) < MAX_ERROR
        print('Tests passed.')

    def bench():
        """Time forward + backward passes of DoubleFeatureTransformerSlice."""
        INPUT_SIZE = 40960
        BATCH_SIZE = 8192
        ITERS = 64
        STRIDE = 264
        MAX_ACTIVE_FEATURES = 64

        def random_indices():
            # Three quarters of each row hold sorted random indices, the
            # remaining quarter is padded with -1 (empty slots).
            active = torch.sort(torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES * 3 // 4) * INPUT_SIZE, dim=1)[0]
            padding = torch.full((BATCH_SIZE, MAX_ACTIVE_FEATURES // 4), -1, dtype=torch.int32)
            return torch.cat([active.to(dtype=torch.int32), padding], dim=1).cuda()

        layer = DoubleFeatureTransformerSlice(INPUT_SIZE, STRIDE).cuda()
        indices0 = random_indices()
        values0 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32).cuda()
        indices1 = random_indices()
        values1 = torch.rand(BATCH_SIZE, MAX_ACTIVE_FEATURES, dtype=torch.float32).cuda()

        # Warm-up call so kernel compilation is not part of the measurement.
        output0, output1 = layer(indices0, values0, indices1, values1)
        # Kernel launches are asynchronous: wait for the warm-up to finish
        # before starting the clock.
        torch.cuda.synchronize()

        start = time.perf_counter()

        for i in range(ITERS):
            output0, output1 = layer(indices0, values0, indices1, values1)
            output0 = torch.clamp(output0, 0.0, 1.0)
            output1 = torch.clamp(output1, 0.0, 1.0)

            g = ((output0 - output1)**2).mean()
            g.backward()

            torch.cuda.synchronize()

        end = time.perf_counter()

        print('{} pos/s'.format((ITERS * BATCH_SIZE) / (end - start)))

    test()
    bench()