-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathutils.cu
695 lines (538 loc) · 31 KB
/
utils.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
//
// lmbspecialops - a collection of tensorflow ops
// Copyright (C) 2017 Benjamin Ummenhofer, Huizhong Zhou
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
#define EIGEN_USE_GPU
#include "/usr/local/cuda/include/vector_types.h"
#include "/usr/local/cuda/include/cuda.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/register_types.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h"
using namespace tensorflow;
// Integer ceiling division: for positive x and y, the smallest q with q*y >= x.
// Quotient and remainder follow C++ truncating division, so for negative x the
// result is the truncated quotient plus one whenever the division is inexact.
inline int divup(int x,int y){
  const int whole = x / y;
  const bool has_remainder = (x % y) != 0;
  return has_remainder ? whole + 1 : whole;
}
// Launch geometry filled in by GetCudaLaunchConfig. Every field defaults to
// -1 until it is assigned.
struct CudaLaunchConfig {
  int virtual_thread_count{-1};  // number of logical work items to cover
  int thread_per_block{-1};      // threads in each block
  int block_count{-1};           // blocks in the grid
};
// Hard-coded device limits consumed by GetCudaLaunchConfig. They are fixed
// numbers, not queried from the actual GPU.
constexpr int d_getNumGpuMultiProcessors = 24;          // streaming multiprocessors
constexpr int d_maxGpuThreadsPerMultiProcessor = 2048;  // resident threads per SM
constexpr int d_maxGpuThreadsPerBlock = 1024;           // threads per block
// Picks a launch geometry for `work_element_count` items using the
// hard-coded device limits above. The block count is capped at the SM count,
// so the grid may be smaller than the work. Kernels launched with it must
// loop (e.g. CUDA_KERNEL_LOOP) to cover every element. CHECK-fails on a
// non-positive count. No caller appears in this file.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count){
  CHECK_GT(work_element_count, 0);
  const int resident_thread_limit =
      d_getNumGpuMultiProcessors * d_maxGpuThreadsPerMultiProcessor;
  const int busy_threads = std::min(resident_thread_limit, work_element_count);
  const int threads = std::min(1024, d_maxGpuThreadsPerBlock);
  const int blocks =
      std::min(divup(busy_threads, threads), d_getNumGpuMultiProcessors);

  CudaLaunchConfig result;
  result.virtual_thread_count = work_element_count;
  result.thread_per_block = threads;
  result.block_count = blocks;
  return result;
}
// Grid-stride loop over [0, n) along x. Each thread starts at its global x
// index and advances by the total number of x threads in the grid
// (blockDim.x * gridDim.x). Correctness therefore does not depend on the grid
// covering all n items. No use of this macro appears in this file.
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// Returns the raw CUDA stream behind the op's stream-executor stream
// (context->op_device_context()->stream()). Hand-written kernels launched on
// it are ordered with the batched GEMMs this file enqueues through that same
// stream object. CHECK-fails if the implementation hands back a null pointer.
// NOTE(review): relies on GpuStreamMemberHack() pointing at a cudaStream_t.
inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
const cudaStream_t* ptr = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack()));
return *ptr;
}
// Wraps a raw device pointer as a typed DeviceMemory<float> view for the
// stream-executor BLAS entry points. Only the pointer is passed to
// DeviceMemoryBase (no explicit size). The const is cast away because
// DeviceMemoryBase stores a mutable pointer.
perftools::gputools::DeviceMemory<float> AsDeviceMemory(const float* cuda_memory) {
  float* mutable_memory = const_cast<float*>(cuda_memory);
  return perftools::gputools::DeviceMemory<float>(
      perftools::gputools::DeviceMemoryBase(mutable_memory));
}
// ScratchAllocator that serves cuBLAS workspace requests from the op's
// temporary-tensor allocator. Every allocated tensor is kept in
// allocated_tensors_ so the memory stays referenced for as long as this
// allocator object lives.
class CublasScratchAllocator : public perftools::gputools::ScratchAllocator {
 public:
  using Stream = ::perftools::gputools::Stream;
  using DeviceMemoryBytes = ::perftools::gputools::DeviceMemory<uint8>;

  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}

  // Returns -1; this allocator imposes no limit of its own.
  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }

  // Allocates `byte_size` bytes as a DT_UINT8 temp tensor. If the allocation
  // fails, an OK result holding an empty (nullptr, 0) region is returned
  // instead of an error status.
  perftools::gputools::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      Stream* stream, int64 byte_size) override {
    using Result = perftools::gputools::port::StatusOr<DeviceMemoryBytes>;
    Tensor scratch;
    const Status status = context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &scratch);
    if (status.ok()) {
      // Keep a reference so the buffer outlives this call.
      allocated_tensors_.push_back(scratch);
      auto bytes = scratch.flat<uint8>();
      return Result(DeviceMemoryBytes::MakeFromByteSize(bytes.data(), bytes.size()));
    }
    return Result(DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
// Op registration for "WarpComputation": float images, int32 indices and
// float weights in, one float output.
// NOTE(review): no SetShapeFn is given, and no REGISTER_KERNEL_BUILDER for
// this op appears in this file. warp_kernel below looks like its intended
// implementation — TODO confirm.
REGISTER_OP("WarpComputation")
.Input("imgs_flat:float")
.Input("index:int32")
.Input("w:float")
.Output("output:float");
// Per-pixel weighted gather.
// out[pixel, channel] = sum_{k<4} in[index[4*pixel+k], channel] * weights[4*pixel+k],
// with `in` and `out` stored channel-interleaved (nchannels floats per pixel).
// Presumably the four taps are bilinear-interpolation neighbours — TODO confirm.
// Launch: 2-D grid, x covers pixels and y covers channels. Threads outside
// [0, npixels) x [0, nchannels) return without writing.
__global__ void warp_kernel(float* out, const float* in, const int* index,const float* weights,const int npixels,const int nchannels){
  const int channel = blockIdx.y * blockDim.y + threadIdx.y;
  const int pixel = blockIdx.x * blockDim.x + threadIdx.x;
  if (!(pixel < npixels && channel < nchannels)) {
    return;
  }
  // The four (source pixel, weight) pairs of this output pixel are contiguous.
  const int* src = index + 4 * pixel;
  const float* w = weights + 4 * pixel;
  out[nchannels * pixel + channel] = in[nchannels * src[0] + channel] * w[0]
                                   + in[nchannels * src[1] + channel] * w[1]
                                   + in[nchannels * src[2] + channel] * w[2]
                                   + in[nchannels * src[3] + channel] * w[3];
}
// Op registration for "EquationConstruction".
// Inputs jacobian, gradient and difference are rank-4 tensors
// [batch, npixels, rows, cols].
// Outputs:
//   left  : [batch, jacobian_cols, jacobian_cols]
//   right : [batch, jacobian_cols, difference_cols]
// The shape function enforces rank 4 before indexing dimension 3. Calling
// InferenceContext::Dim on a known shape of lower rank would index past its
// dimensions.
REGISTER_OP("EquationConstruction")
    .Input("jacobian:float")
    .Input("gradient:float")
    .Input("difference:float")
    .Output("left:float")
    .Output("right:float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle jacobian;
      shape_inference::ShapeHandle gradient;
      shape_inference::ShapeHandle difference;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &jacobian));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &gradient));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 4, &difference));

      shape_inference::ShapeHandle batch_size = c->Vector(c->Dim(jacobian, 0));
      shape_inference::DimensionHandle jacobian_cols = c->Dim(jacobian, 3);
      shape_inference::DimensionHandle difference_cols = c->Dim(difference, 3);

      shape_inference::ShapeHandle left_shape;
      shape_inference::ShapeHandle right_shape;
      TF_RETURN_IF_ERROR(c->Concatenate(
          batch_size, c->Matrix(jacobian_cols, jacobian_cols), &left_shape));
      TF_RETURN_IF_ERROR(c->Concatenate(
          batch_size, c->Matrix(jacobian_cols, difference_cols), &right_shape));
      c->set_output(0, left_shape);
      c->set_output(1, right_shape);
      return Status::OK();
    });
// Binary float addition functor, callable from host and device code.
// No use of it appears in this file.
struct Sum {
  __host__ __device__ float operator()(const float& lhs, const float& rhs) const {
    const float total = lhs + rhs;
    return total;
  }
};
// Sums column `col` of one row-major num_rows x num_cols plane, adding rows
// in increasing order. An empty plane (num_rows <= 0) sums to 0.
__host__ __device__ inline float ColumnReduceSimpleSum(const float* plane_in,
                                                      int num_rows, int num_cols,
                                                      int col) {
  if (num_rows <= 0) {
    return 0.0f;
  }
  float sum = plane_in[col];
  for (int row = 1; row < num_rows; ++row) {
    sum = sum + plane_in[row * num_cols + col];
  }
  return sum;
}

// Reduces `in`, viewed as [num_planes, num_rows, num_cols] row-major, over
// its middle axis into `out` ([num_planes, num_cols]).
// Launch: 1-D, one thread per (plane, col) output. Threads beyond
// num_planes * num_cols return without writing.
// Unlike a version that always pre-adds rows 0 and 1, this handles
// num_rows == 1 (no read of row 1, which would belong to the next plane) and
// num_rows == 0 (writes 0).
__global__ void ColumnReduceSimpleKernel(const float* in,float* out, int num_planes,
                                         int num_rows, int num_cols) {
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int plane = gid / num_cols;
  const int col = gid % num_cols;
  if (plane >= num_planes) {
    return;
  }
  // 64-bit plane offset so large batches do not overflow int.
  const float* plane_in = in + static_cast<long long>(plane) * num_rows * num_cols;
  out[plane * num_cols + col] =
      ColumnReduceSimpleSum(plane_in, num_rows, num_cols, col);
}
// Enqueues ColumnReduceSimpleKernel on `cu_stream`.
// `in` is a row-major [extent_x, extent_y, extent_z] float tensor; its middle
// (y) axis is summed into `out`, which is [extent_x, extent_z].
// Uses 128 threads per block and one thread per output element. Launch
// errors are not checked here.
void Launch3DYReduction(const cudaStream_t& cu_stream,float* out,const float* in, int extent_x,
                        int extent_y, int extent_z) {
  const int threads_per_block = 128;
  const int num_outputs = extent_x * extent_z;
  // Nothing to write. A zero-block grid would also be an invalid launch
  // configuration.
  if (num_outputs <= 0) {
    return;
  }
  const int num_blocks = (num_outputs + threads_per_block - 1) / threads_per_block;
  ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
      in, out, extent_x, extent_y, extent_z);
}
// Process-wide scratch state for EquationConstruction, indexed by GPU id.
// NOTE(review): nothing here synchronizes concurrent Compute calls on the same
// GPU, and gpu_id is used as an index without a bound check in this file.
// Number of DeviceMemory tables (batched-GEMM operand slots) kept per GPU.
#define MAX_BUFFERS 5
// Number of GPU ids the arrays below can hold.
#define MAX_GPUS 4
typedef perftools::gputools::DeviceMemory<float> DeviceMemoryType;
// buffer[gpu_id]: scratch tensor allocated via allocate_persistent when
// EquationConstruction first runs on that GPU (while it is still nullptr).
// EquationConstructionGrad also uses it.
static Tensor* buffer[MAX_GPUS]={nullptr};
// memory[MAX_BUFFERS*gpu_id + k] holds per-matrix views into buffer[gpu_id]
// for slot k. memory_ptr holds pointers to those views, the form passed to
// ThenBlasGemmBatchedWithScratch.
static std::vector<DeviceMemoryType> memory[MAX_BUFFERS*MAX_GPUS];
static std::vector<DeviceMemoryType*> memory_ptr[MAX_BUFFERS*MAX_GPUS];
// GPU kernel for the EquationConstruction op.
//
// Inputs are row-major per-pixel matrices, shape [batch, npixels, rows, cols]:
//   jacobian   J_p : jacobian_rows   x jacobian_cols
//   gradient   G_p : gradient_rows   x gradient_cols   (gradient_cols == jacobian_rows)
//   difference D_p : difference_rows x difference_cols (difference_rows == gradient_rows)
// For every pixel the batched GEMMs below form
//   left_p  = J_p^T G_p^T G_p J_p   (jacobian_cols x jacobian_cols)
//   right_p = J_p^T G_p^T D_p       (jacobian_cols x difference_cols)
// Launch3DYReduction then sums them over the pixel axis.
//
// Intermediates live in the per-GPU scratch tensor buffer[gpu_id]. It is
// allocated on the first call on a GPU and reused afterwards;
// EquationConstructionGrad also reads it. Later calls whose inputs need more
// scratch than the first call are rejected rather than overflowing it.
// NOTE(review): the static per-GPU state is not guarded against concurrent
// Compute calls on the same GPU.
class EquationConstruction: public OpKernel
{
private:
  int gpu_id;

  // Fills `views`/`ptrs` with `count` views of consecutive matrices of
  // `stride` floats, starting at `base`. This is the form expected by
  // ThenBlasGemmBatchedWithScratch.
  static void BuildBatch(const float* base, int64 stride, int count,
                         std::vector<DeviceMemoryType>* views,
                         std::vector<DeviceMemoryType*>* ptrs) {
    views->clear();
    views->reserve(count);
    for (int i = 0; i < count; ++i) {
      views->push_back(AsDeviceMemory(base + i * stride));
    }
    // Take addresses only once `views` is complete so none can be
    // invalidated by a later reallocation.
    ptrs->clear();
    ptrs->reserve(count);
    for (DeviceMemoryType& view : *views) {
      ptrs->push_back(&view);
    }
  }

  // Enqueues one batched column-major SGEMM,
  // C_i = alpha * op(A_i) * op(B_i) + beta * C_i, on the op's stream.
  // Returns an Internal error naming `what` if the stream reports failure.
  static Status BatchedGemm(OpKernelContext* context,
                            perftools::gputools::blas::Transpose trans_a,
                            perftools::gputools::blas::Transpose trans_b,
                            int m, int n, int k, float alpha,
                            const std::vector<DeviceMemoryType*>& a, int lda,
                            const std::vector<DeviceMemoryType*>& b, int ldb,
                            float beta,
                            const std::vector<DeviceMemoryType*>& c, int ldc,
                            int batch_count,
                            perftools::gputools::ScratchAllocator* scratch,
                            const char* what) {
    const bool launched = context->op_device_context()
                              ->stream()
                              ->ThenBlasGemmBatchedWithScratch(
                                  trans_a, trans_b, m, n, k, alpha, a, lda, b,
                                  ldb, beta, c, ldc, batch_count, scratch)
                              .ok();
    if (!launched) {
      return errors::Internal("EquationConstruction: batched GEMM for ", what,
                              " failed to launch");
    }
    return Status::OK();
  }

public:
  explicit EquationConstruction(OpKernelConstruction* context):OpKernel(context){
    gpu_id=context->device()->tensorflow_gpu_device_info()->gpu_id;
    // buffer/memory/memory_ptr only have room for MAX_GPUS devices.
    OP_REQUIRES(context, gpu_id >= 0 && gpu_id < MAX_GPUS,
                errors::InvalidArgument(
                    "EquationConstruction supports gpu ids below ", MAX_GPUS,
                    ", got ", gpu_id));
  }

  void Compute( OpKernelContext* context ) override {
    const Tensor& jacobian=context->input(0);
    const Tensor& gradient=context->input(1);
    const Tensor& difference=context->input(2);

    OP_REQUIRES(context,
                jacobian.dims() == 4 && gradient.dims() == 4 && difference.dims() == 4,
                errors::InvalidArgument(
                    "jacobian, gradient and difference must be rank 4, got ",
                    jacobian.shape().DebugString(), ", ",
                    gradient.shape().DebugString(), " and ",
                    difference.shape().DebugString()));
    OP_REQUIRES(context,
                gradient.dim_size(0) == jacobian.dim_size(0) &&
                    gradient.dim_size(1) == jacobian.dim_size(1) &&
                    difference.dim_size(0) == jacobian.dim_size(0) &&
                    difference.dim_size(1) == jacobian.dim_size(1),
                errors::InvalidArgument(
                    "jacobian, gradient and difference must agree in their "
                    "batch and pixel dimensions, got ",
                    jacobian.shape().DebugString(), ", ",
                    gradient.shape().DebugString(), " and ",
                    difference.shape().DebugString()));

    const int batch_size=jacobian.dim_size(0);
    const int npixels=jacobian.dim_size(1);
    const int n_matrix=batch_size*npixels;
    const int jacobian_rows=jacobian.dim_size(2);
    const int jacobian_cols=jacobian.dim_size(3);
    const int gradient_rows=gradient.dim_size(2);
    const int gradient_cols=gradient.dim_size(3);
    const int difference_rows=difference.dim_size(2);
    const int difference_cols=difference.dim_size(3);

    OP_REQUIRES(context,
                jacobian_rows > 0 && jacobian_cols > 0 && gradient_rows > 0 &&
                    difference_cols > 0,
                errors::InvalidArgument(
                    "EquationConstruction requires non-empty per-pixel matrices"));
    // Inner dimensions of the products J^T (G^T G) J and (D^T G) J.
    OP_REQUIRES(context, gradient_cols == jacobian_rows,
                errors::InvalidArgument(
                    "gradient must have as many columns as jacobian has rows, got ",
                    gradient_cols, " and ", jacobian_rows));
    OP_REQUIRES(context, difference_rows == gradient_rows,
                errors::InvalidArgument(
                    "difference must have as many rows as gradient, got ",
                    difference_rows, " and ", gradient_rows));
    // The scratch layout below packs G^T G, D^T G and the right-hand terms
    // into regions sized by jacobian_cols.
    OP_REQUIRES(context,
                gradient_cols <= jacobian_cols && difference_cols <= jacobian_cols,
                errors::InvalidArgument(
                    "EquationConstruction requires gradient and difference to "
                    "have at most ", jacobian_cols, " columns, got ",
                    gradient_cols, " and ", difference_cols));

    TensorShape output_shape;
    output_shape.AddDim(batch_size);
    output_shape.AddDim(jacobian_cols);
    output_shape.AddDim(jacobian_cols);
    Tensor* output_tensor=nullptr;
    OP_REQUIRES_OK(context,context->allocate_output(0,output_shape,&output_tensor));
    TensorShape output2_shape;
    output2_shape.AddDim(batch_size);
    output2_shape.AddDim(jacobian_cols);
    output2_shape.AddDim(difference_cols);
    Tensor* output2_tensor=nullptr;
    OP_REQUIRES_OK(context,context->allocate_output(1,output2_shape,&output2_tensor));

    const cudaStream_t& cu_stream = GetCudaStream(context);

    if (n_matrix == 0) {
      // Summing over an empty pixel axis gives zeros.
      for (Tensor* out : {output_tensor, output2_tensor}) {
        if (out->NumElements() == 0) {
          continue;
        }
        const cudaError_t err = cudaMemsetAsync(
            out->flat<float>().data(), 0, out->NumElements() * sizeof(float),
            cu_stream);
        OP_REQUIRES(context, err == cudaSuccess,
                    errors::Internal("cudaMemsetAsync failed: ",
                                     cudaGetErrorString(err)));
      }
      return;
    }

    const int64 required_floats = static_cast<int64>(n_matrix) *
                                  (jacobian_rows + jacobian_cols) * jacobian_cols;
    if (buffer[gpu_id]==nullptr){
      TensorShape buffer_shape(jacobian.shape());
      buffer_shape.set_dim(2,jacobian_rows+jacobian_cols);
      // Intentionally never freed: only the raw Tensor* is cached in buffer[]
      // and used by later calls of both ops on this GPU.
      PersistentTensor* newtensor= new PersistentTensor();
      Tensor* allocated=nullptr;
      const Status status = context->allocate_persistent(
          jacobian.dtype(), buffer_shape, newtensor, &allocated);
      if (!status.ok()) {
        delete newtensor;
        context->SetStatus(status);
        return;
      }
      buffer[gpu_id]=allocated;
    }
    OP_REQUIRES(context, buffer[gpu_id]->NumElements() >= required_floats,
                errors::InvalidArgument(
                    "EquationConstruction: inputs need ", required_floats,
                    " scratch floats, but the buffer allocated on the first call "
                    "on gpu ", gpu_id, " holds ", buffer[gpu_id]->NumElements(),
                    "; inputs larger than those of the first call are not supported"));

    // Scratch layout: region A = [0, n_matrix*jc*jc), region B = the rest.
    // Each GEMM reads one region and writes the other (or reads only inputs),
    // so no batched GEMM reads memory it is writing.
    float* scratch = buffer[gpu_id]->flat<float>().data();
    float* region_a = scratch;
    float* region_b = scratch + static_cast<int64>(n_matrix) * jacobian_cols * jacobian_cols;
    std::vector<DeviceMemoryType>* slot_views = &memory[MAX_BUFFERS*gpu_id];
    std::vector<DeviceMemoryType*>* slots = &memory_ptr[MAX_BUFFERS*gpu_id];
    // Rebuilt every call so strides always match the current dimensions.
    BuildBatch(region_a, gradient_cols*gradient_cols, n_matrix, &slot_views[0], &slots[0]);    // G^T G
    BuildBatch(region_b, jacobian_rows*jacobian_cols, n_matrix, &slot_views[1], &slots[1]);    // J^T G^T G
    BuildBatch(region_a, jacobian_cols*jacobian_cols, n_matrix, &slot_views[2], &slots[2]);    // left terms
    BuildBatch(region_b, gradient_cols*difference_cols, n_matrix, &slot_views[3], &slots[3]);  // D^T G
    // Stride must equal the jacobian_cols*difference_cols row length that the
    // second reduction reads.
    BuildBatch(region_a, jacobian_cols*difference_cols, n_matrix, &slot_views[4], &slots[4]);  // right terms

    std::vector<DeviceMemoryType> jacobian_views;
    std::vector<DeviceMemoryType> gradient_views;
    std::vector<DeviceMemoryType> difference_views;
    std::vector<DeviceMemoryType*> jacobian_ptrs;
    std::vector<DeviceMemoryType*> gradient_ptrs;
    std::vector<DeviceMemoryType*> difference_ptrs;
    BuildBatch(jacobian.flat<float>().data(), jacobian_rows*jacobian_cols,
               n_matrix, &jacobian_views, &jacobian_ptrs);
    BuildBatch(gradient.flat<float>().data(), gradient_rows*gradient_cols,
               n_matrix, &gradient_views, &gradient_ptrs);
    BuildBatch(difference.flat<float>().data(), difference_rows*difference_cols,
               n_matrix, &difference_views, &difference_ptrs);

    using perftools::gputools::blas::Transpose;
    CublasScratchAllocator scratch_allocator(context);
    // cuBLAS is column-major: a row-major r x c matrix is seen as its c x r
    // transpose.
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kTranspose,
        gradient_cols, gradient_cols, gradient_rows, 1.0f,
        gradient_ptrs, gradient_cols, gradient_ptrs, gradient_cols,
        0.0f, slots[0], gradient_cols, n_matrix, &scratch_allocator, "G^T G"));
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kNoTranspose,
        jacobian_cols, gradient_cols, jacobian_rows, 1.0f,
        jacobian_ptrs, jacobian_cols, slots[0], gradient_cols,
        0.0f, slots[1], jacobian_cols, n_matrix, &scratch_allocator, "J^T G^T G"));
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kTranspose,
        jacobian_cols, jacobian_cols, gradient_cols, 1.0f,
        slots[1], jacobian_cols, jacobian_ptrs, jacobian_cols,
        0.0f, slots[2], jacobian_cols, n_matrix, &scratch_allocator, "J^T G^T G J"));
    // Ordered on the same stream, so it finishes reading region A before the
    // right-hand GEMMs below overwrite it.
    Launch3DYReduction(cu_stream, output_tensor->flat<float>().data(), scratch,
                       batch_size, npixels, jacobian_cols*jacobian_cols);

    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kTranspose,
        difference_cols, gradient_cols, gradient_rows, 1.0f,
        difference_ptrs, difference_cols, gradient_ptrs, gradient_cols,
        0.0f, slots[3], difference_cols, n_matrix, &scratch_allocator, "D^T G"));
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kTranspose,
        difference_cols, jacobian_cols, jacobian_rows, 1.0f,
        slots[3], difference_cols, jacobian_ptrs, jacobian_cols,
        0.0f, slots[4], difference_cols, n_matrix, &scratch_allocator, "D^T G J"));
    Launch3DYReduction(cu_stream, output2_tensor->flat<float>().data(), scratch,
                       batch_size, npixels, jacobian_cols*difference_cols);
  }
};
REGISTER_KERNEL_BUILDER(Name("EquationConstruction").Device(DEVICE_GPU),EquationConstruction);
// Op registration for "EquationConstructionGrad": takes the three forward
// inputs plus the gradients of `left` and `right`, and returns one gradient
// per forward input. Each gradient has exactly the shape of the corresponding
// forward input.
REGISTER_OP("EquationConstructionGrad")
    .Input("jacobian:float")
    .Input("gradient:float")
    .Input("difference:float")
    .Input("left_grad:float")
    .Input("right_grad:float")
    .Output("jacobian_grad:float")
    .Output("gradient_grad:float")
    .Output("difference_grad:float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      c->set_output(2, c->input(2));
      return Status::OK();
    });
// Process-wide scratch state for EquationConstructionGrad, indexed by GPU id.
// NOTE(review): not synchronized across concurrent Compute calls on one GPU.
// Number of DeviceMemory tables (batched-GEMM operand slots) kept per GPU.
#define GRAD_MAX_BUFFERS 4
// Slot indices into grad_memory / grad_memory_ptr, offset by
// GRAD_MAX_BUFFERS*gpu_id.
enum{
// views of the tiled left_grad stored in buffer[gpu_id]
GRAD_0,
// views of the tiled right_grad stored in buffer[gpu_id]
GRAD_1,
// first GEMM's output (jacobian_cols x gradient_rows, column-major), second half of grad_buffer
JACOBIAN_GRAD,
// accumulated intermediate (jacobian_cols x gradient_rows, column-major), first half of grad_buffer
JACOBIAN_GRAD_GRAD
};
// grad_buffer[gpu_id]: scratch tensor allocated via allocate_persistent when
// EquationConstructionGrad first runs on that GPU.
static Tensor* grad_buffer[MAX_GPUS]={nullptr};
static std::vector<DeviceMemoryType> grad_memory[GRAD_MAX_BUFFERS*MAX_GPUS];
static std::vector<DeviceMemoryType*> grad_memory_ptr[GRAD_MAX_BUFFERS*MAX_GPUS];
// Broadcasts each plane's row of `in` ([num_planes, num_cols]) into every
// row of `out` ([num_planes, num_rows, num_cols], row-major), so that
// out[plane, row, :] = in[plane, :].
// Launch: 1-D, one thread per (plane, row) pair, each copying num_cols floats.
// Threads beyond num_planes * num_rows return without writing.
__global__ void tile_kernel(const float* in,float* out,
                            int num_planes, int num_rows, int num_cols) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int plane = gid / num_rows;
  if (plane >= num_planes) {
    return;
  }
  const int row = gid % num_rows;
  const float* src = in + plane * num_cols;
  float* dst = out + (plane * num_rows + row) * num_cols;
  for (int col = 0; col < num_cols; ++col) {
    dst[col] = src[col];
  }
}
// Enqueues tile_kernel on `stream`. `out` becomes
// [batch_size, npixels, nelements] with out[b, p, :] = in[b, :], where `in`
// is [batch_size, nelements]. Uses 128 threads per block, one thread per
// (batch, pixel) pair. Launch errors are not checked here.
void tile_gpu(const cudaStream_t& stream,float* out,const float* in,const int batch_size,const int npixels,const int nelements){
  const int threads_per_block = 128;
  const int num_rows_total = batch_size * npixels;
  // Nothing to copy. A zero-block grid would also be an invalid launch
  // configuration.
  if (num_rows_total <= 0) {
    return;
  }
  const int num_blocks = (num_rows_total + threads_per_block - 1) / threads_per_block;
  tile_kernel<<<num_blocks,threads_per_block,0,stream>>>(in,out,batch_size,npixels,nelements);
}
// GPU kernel for EquationConstructionGrad.
//
// Backpropagates dL = left_grad ([batch, jacobian_cols, jacobian_cols]) and
// dR = right_grad ([batch, jacobian_cols, difference_cols]) of
// EquationConstruction to jacobian (J), gradient (G) and difference (D).
// The forward op summed over pixels, so both incoming gradients are first
// tiled over the pixel axis into the forward op's scratch buffer[gpu_id].
// With M = G J per pixel, the batched GEMMs below compute
//   dM              = 2 M dL + D dR^T
//   difference_grad = M dR
//   jacobian_grad   = G^T dM
//   gradient_grad   = dM J^T
// NOTE(review): 2 M dL equals the exact M (dL + dL^T) only when left_grad is
// symmetric.
// Intermediates live in the per-GPU grad_buffer[gpu_id], allocated on the
// first call. Inputs needing more scratch than the first call are rejected.
class EquationConstructionGrad: public OpKernel
{
private:
  int gpu_id;

  // Fills `views`/`ptrs` with `count` views of consecutive matrices of
  // `stride` floats, starting at `base`. This is the form expected by
  // ThenBlasGemmBatchedWithScratch.
  static void BuildBatch(const float* base, int64 stride, int count,
                         std::vector<DeviceMemoryType>* views,
                         std::vector<DeviceMemoryType*>* ptrs) {
    views->clear();
    views->reserve(count);
    for (int i = 0; i < count; ++i) {
      views->push_back(AsDeviceMemory(base + i * stride));
    }
    // Take addresses only once `views` is complete so none can be
    // invalidated by a later reallocation.
    ptrs->clear();
    ptrs->reserve(count);
    for (DeviceMemoryType& view : *views) {
      ptrs->push_back(&view);
    }
  }

  // Enqueues one batched column-major SGEMM,
  // C_i = alpha * op(A_i) * op(B_i) + beta * C_i, on the op's stream.
  // Returns an Internal error naming `what` if the stream reports failure.
  static Status BatchedGemm(OpKernelContext* context,
                            perftools::gputools::blas::Transpose trans_a,
                            perftools::gputools::blas::Transpose trans_b,
                            int m, int n, int k, float alpha,
                            const std::vector<DeviceMemoryType*>& a, int lda,
                            const std::vector<DeviceMemoryType*>& b, int ldb,
                            float beta,
                            const std::vector<DeviceMemoryType*>& c, int ldc,
                            int batch_count,
                            perftools::gputools::ScratchAllocator* scratch,
                            const char* what) {
    const bool launched = context->op_device_context()
                              ->stream()
                              ->ThenBlasGemmBatchedWithScratch(
                                  trans_a, trans_b, m, n, k, alpha, a, lda, b,
                                  ldb, beta, c, ldc, batch_count, scratch)
                              .ok();
    if (!launched) {
      return errors::Internal("EquationConstructionGrad: batched GEMM for ",
                              what, " failed to launch");
    }
    return Status::OK();
  }

public:
  explicit EquationConstructionGrad(OpKernelConstruction* context):OpKernel(context){
    gpu_id=context->device()->tensorflow_gpu_device_info()->gpu_id;
    // buffer/grad_buffer/grad_memory only have room for MAX_GPUS devices.
    OP_REQUIRES(context, gpu_id >= 0 && gpu_id < MAX_GPUS,
                errors::InvalidArgument(
                    "EquationConstructionGrad supports gpu ids below ",
                    MAX_GPUS, ", got ", gpu_id));
  }

  void Compute( OpKernelContext* context ) override {
    const Tensor& jacobian=context->input(0);
    const Tensor& gradient=context->input(1);
    const Tensor& difference=context->input(2);
    const Tensor& grad0=context->input(3);
    const Tensor& grad1=context->input(4);

    OP_REQUIRES(context,
                jacobian.dims() == 4 && gradient.dims() == 4 && difference.dims() == 4,
                errors::InvalidArgument(
                    "jacobian, gradient and difference must be rank 4, got ",
                    jacobian.shape().DebugString(), ", ",
                    gradient.shape().DebugString(), " and ",
                    difference.shape().DebugString()));
    OP_REQUIRES(context,
                gradient.dim_size(0) == jacobian.dim_size(0) &&
                    gradient.dim_size(1) == jacobian.dim_size(1) &&
                    difference.dim_size(0) == jacobian.dim_size(0) &&
                    difference.dim_size(1) == jacobian.dim_size(1),
                errors::InvalidArgument(
                    "jacobian, gradient and difference must agree in their "
                    "batch and pixel dimensions, got ",
                    jacobian.shape().DebugString(), ", ",
                    gradient.shape().DebugString(), " and ",
                    difference.shape().DebugString()));

    const int batch_size=jacobian.dim_size(0);
    const int npixels=jacobian.dim_size(1);
    const int n_matrix=batch_size*npixels;
    const int jacobian_rows=jacobian.dim_size(2);
    const int jacobian_cols=jacobian.dim_size(3);
    const int gradient_rows=gradient.dim_size(2);
    const int gradient_cols=gradient.dim_size(3);
    const int difference_rows=difference.dim_size(2);
    const int difference_cols=difference.dim_size(3);

    OP_REQUIRES(context,
                jacobian_rows > 0 && jacobian_cols > 0 && gradient_rows > 0 &&
                    difference_cols > 0,
                errors::InvalidArgument(
                    "EquationConstructionGrad requires non-empty per-pixel matrices"));
    OP_REQUIRES(context, gradient_cols == jacobian_rows,
                errors::InvalidArgument(
                    "gradient must have as many columns as jacobian has rows, got ",
                    gradient_cols, " and ", jacobian_rows));
    OP_REQUIRES(context, difference_rows == gradient_rows,
                errors::InvalidArgument(
                    "difference must have as many rows as gradient, got ",
                    difference_rows, " and ", gradient_rows));
    // tile_gpu reads batch_size * (jacobian_cols * cols) floats from each.
    OP_REQUIRES(context,
                grad0.dims() == 3 && grad0.dim_size(0) == batch_size &&
                    grad0.dim_size(1) == jacobian_cols &&
                    grad0.dim_size(2) == jacobian_cols,
                errors::InvalidArgument(
                    "left_grad must have shape [", batch_size, ", ",
                    jacobian_cols, ", ", jacobian_cols, "], got ",
                    grad0.shape().DebugString()));
    OP_REQUIRES(context,
                grad1.dims() == 3 && grad1.dim_size(0) == batch_size &&
                    grad1.dim_size(1) == jacobian_cols &&
                    grad1.dim_size(2) == difference_cols,
                errors::InvalidArgument(
                    "right_grad must have shape [", batch_size, ", ",
                    jacobian_cols, ", ", difference_cols, "], got ",
                    grad1.shape().DebugString()));

    Tensor* jacobian_grad_tensor=nullptr;
    OP_REQUIRES_OK(context,context->allocate_output(0,jacobian.shape(),&jacobian_grad_tensor));
    Tensor* gradient_grad_tensor=nullptr;
    OP_REQUIRES_OK(context,context->allocate_output(1,gradient.shape(),&gradient_grad_tensor));
    Tensor* difference_grad_tensor=nullptr;
    OP_REQUIRES_OK(context,context->allocate_output(2,difference.shape(),&difference_grad_tensor));

    // Every output has the shape of its forward input, so all are empty.
    if (n_matrix == 0) {
      return;
    }

    // The tiled incoming gradients are written into the forward op's scratch.
    OP_REQUIRES(context, buffer[gpu_id] != nullptr,
                errors::FailedPrecondition(
                    "EquationConstructionGrad on gpu ", gpu_id,
                    " requires EquationConstruction to have run on that gpu first"));
    const int64 tiled_floats = static_cast<int64>(n_matrix) * jacobian_cols *
                               (jacobian_cols + difference_cols);
    OP_REQUIRES(context, buffer[gpu_id]->NumElements() >= tiled_floats,
                errors::InvalidArgument(
                    "EquationConstructionGrad: tiled gradients need ",
                    tiled_floats, " floats, but the forward scratch buffer on gpu ",
                    gpu_id, " holds ", buffer[gpu_id]->NumElements()));

    const int64 grad_floats =
        static_cast<int64>(n_matrix) * 2 * gradient_rows * jacobian_cols;
    if (grad_buffer[gpu_id]==nullptr){
      TensorShape buffer_shape(jacobian.shape());
      buffer_shape.set_dim(2,2*gradient_rows);
      // Intentionally never freed: only the raw Tensor* is cached in
      // grad_buffer[] and reused by later calls on this GPU.
      PersistentTensor* newtensor= new PersistentTensor();
      Tensor* allocated=nullptr;
      const Status status = context->allocate_persistent(
          jacobian.dtype(), buffer_shape, newtensor, &allocated);
      if (!status.ok()) {
        delete newtensor;
        context->SetStatus(status);
        return;
      }
      grad_buffer[gpu_id]=allocated;
    }
    OP_REQUIRES(context, grad_buffer[gpu_id]->NumElements() >= grad_floats,
                errors::InvalidArgument(
                    "EquationConstructionGrad: inputs need ", grad_floats,
                    " scratch floats, but the buffer allocated on the first call "
                    "on gpu ", gpu_id, " holds ", grad_buffer[gpu_id]->NumElements(),
                    "; inputs larger than those of the first call are not supported"));

    float* tiled_grad0_ptr=buffer[gpu_id]->flat<float>().data();
    float* tiled_grad1_ptr=tiled_grad0_ptr+static_cast<int64>(n_matrix)*jacobian_cols*jacobian_cols;
    float* grad_region_a=grad_buffer[gpu_id]->flat<float>().data();
    float* grad_region_b=grad_region_a+static_cast<int64>(n_matrix)*gradient_rows*jacobian_cols;

    // Rebuilt every call so strides and base pointers match the current
    // dimensions and the current buffer[gpu_id].
    std::vector<DeviceMemoryType>* slot_views = &grad_memory[GRAD_MAX_BUFFERS*gpu_id];
    std::vector<DeviceMemoryType*>* slots = &grad_memory_ptr[GRAD_MAX_BUFFERS*gpu_id];
    BuildBatch(grad_region_b, gradient_rows*jacobian_cols, n_matrix,
               &slot_views[JACOBIAN_GRAD], &slots[JACOBIAN_GRAD]);
    BuildBatch(grad_region_a, gradient_rows*jacobian_cols, n_matrix,
               &slot_views[JACOBIAN_GRAD_GRAD], &slots[JACOBIAN_GRAD_GRAD]);
    BuildBatch(tiled_grad0_ptr, jacobian_cols*jacobian_cols, n_matrix,
               &slot_views[GRAD_0], &slots[GRAD_0]);
    BuildBatch(tiled_grad1_ptr, difference_cols*jacobian_cols, n_matrix,
               &slot_views[GRAD_1], &slots[GRAD_1]);

    std::vector<DeviceMemoryType> jacobian_views, gradient_views, difference_views;
    std::vector<DeviceMemoryType> jacobian_grad_views, gradient_grad_views, difference_grad_views;
    std::vector<DeviceMemoryType*> jacobian_ptrs, gradient_ptrs, difference_ptrs;
    std::vector<DeviceMemoryType*> jacobian_grad_ptrs, gradient_grad_ptrs, difference_grad_ptrs;
    BuildBatch(jacobian.flat<float>().data(), jacobian_rows*jacobian_cols,
               n_matrix, &jacobian_views, &jacobian_ptrs);
    BuildBatch(gradient.flat<float>().data(), gradient_rows*gradient_cols,
               n_matrix, &gradient_views, &gradient_ptrs);
    BuildBatch(difference.flat<float>().data(), difference_rows*difference_cols,
               n_matrix, &difference_views, &difference_ptrs);
    BuildBatch(jacobian_grad_tensor->flat<float>().data(), jacobian_rows*jacobian_cols,
               n_matrix, &jacobian_grad_views, &jacobian_grad_ptrs);
    BuildBatch(gradient_grad_tensor->flat<float>().data(), gradient_rows*gradient_cols,
               n_matrix, &gradient_grad_views, &gradient_grad_ptrs);
    BuildBatch(difference_grad_tensor->flat<float>().data(), difference_rows*difference_cols,
               n_matrix, &difference_grad_views, &difference_grad_ptrs);

    const cudaStream_t& cu_stream = GetCudaStream(context);
    tile_gpu(cu_stream, tiled_grad0_ptr, grad0.flat<float>().data(),
             batch_size, npixels, jacobian_cols*jacobian_cols);
    tile_gpu(cu_stream, tiled_grad1_ptr, grad1.flat<float>().data(),
             batch_size, npixels, jacobian_cols*difference_cols);

    using perftools::gputools::blas::Transpose;
    CublasScratchAllocator scratch_allocator(context);
    // cuBLAS is column-major: a row-major r x c matrix is seen as its c x r
    // transpose.
    // JACOBIAN_GRAD <- J^T G^T  (= M^T)
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kNoTranspose,
        jacobian_cols, gradient_rows, jacobian_rows, 1.0f,
        jacobian_ptrs, jacobian_cols, gradient_ptrs, gradient_cols,
        0.0f, slots[JACOBIAN_GRAD], jacobian_cols,
        n_matrix, &scratch_allocator, "M^T"));
    // difference_grad <- dR^T M^T  (row-major: M dR)
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kNoTranspose,
        difference_cols, gradient_rows, jacobian_cols, 1.0f,
        slots[GRAD_1], difference_cols, slots[JACOBIAN_GRAD], jacobian_cols,
        0.0f, difference_grad_ptrs, difference_cols,
        n_matrix, &scratch_allocator, "difference_grad"));
    // JACOBIAN_GRAD_GRAD <- 2 dL^T M^T
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kNoTranspose,
        jacobian_cols, gradient_rows, jacobian_cols, 2.0f,
        slots[GRAD_0], jacobian_cols, slots[JACOBIAN_GRAD], jacobian_cols,
        0.0f, slots[JACOBIAN_GRAD_GRAD], jacobian_cols,
        n_matrix, &scratch_allocator, "left contribution to dM"));
    // JACOBIAN_GRAD_GRAD += dR D^T  (now dM^T)
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kTranspose, Transpose::kNoTranspose,
        jacobian_cols, gradient_rows, difference_cols, 1.0f,
        slots[GRAD_1], difference_cols, difference_ptrs, difference_cols,
        1.0f, slots[JACOBIAN_GRAD_GRAD], jacobian_cols,
        n_matrix, &scratch_allocator, "right contribution to dM"));
    // jacobian_grad <- dM^T G  (row-major: G^T dM)
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kNoTranspose, Transpose::kTranspose,
        jacobian_cols, gradient_cols, gradient_rows, 1.0f,
        slots[JACOBIAN_GRAD_GRAD], jacobian_cols, gradient_ptrs, gradient_cols,
        0.0f, jacobian_grad_ptrs, jacobian_cols,
        n_matrix, &scratch_allocator, "jacobian_grad"));
    // gradient_grad <- J dM^T  (row-major: dM J^T)
    OP_REQUIRES_OK(context, BatchedGemm(
        context, Transpose::kTranspose, Transpose::kNoTranspose,
        gradient_cols, gradient_rows, jacobian_cols, 1.0f,
        jacobian_ptrs, jacobian_cols, slots[JACOBIAN_GRAD_GRAD], jacobian_cols,
        0.0f, gradient_grad_ptrs, gradient_cols,
        n_matrix, &scratch_allocator, "gradient_grad"));
  }
};
REGISTER_KERNEL_BUILDER(Name("EquationConstructionGrad").Device(DEVICE_GPU),EquationConstructionGrad);