forked from wang-xinyu/tensorrtx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mish.cu
196 lines (160 loc) · 5.62 KB
/
mish.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <stdio.h>
#include "mish.h"
namespace nvinfer1
{
// Default constructor.
// Starts input_size_ at 0 so the member never holds an indeterminate value
// before getOutputDimensions() assigns the real element count; clone() and
// serialize() both read input_size_ and may run before that assignment.
MishPlugin::MishPlugin()
    : input_size_(0)
{
}
// Destructor. The body is empty: nothing is released here.
MishPlugin::~MishPlugin()
{
}
// Rebuilds a plugin at runtime from the byte stream written by serialize().
//   data   - serialized plugin state; holds exactly one int (input_size_)
//   length - size of `data` in bytes; asserted to equal sizeof(input_size_)
// The value is copied out with memcpy instead of being read through a cast
// pointer, because nothing visible here guarantees that `data` is aligned
// for an int.
MishPlugin::MishPlugin(const void* data, size_t length)
{
    assert(length == sizeof(input_size_));
    std::memcpy(&input_size_, data, sizeof(input_size_));
}
// Writes the plugin state (the single int input_size_) to `buffer`.
//   buffer - destination with room for at least getSerializationSize() bytes
// memcpy is used instead of a store through a cast pointer, because nothing
// visible here guarantees that `buffer` is aligned for an int.
void MishPlugin::serialize(void* buffer) const
{
    std::memcpy(buffer, &input_size_, sizeof(input_size_));
}
// Number of bytes serialize() writes: the size of input_size_.
size_t MishPlugin::getSerializationSize() const
{
    const size_t serializedBytes = sizeof(input_size_);
    return serializedBytes;
}
// Performs no setup work and always returns 0.
int MishPlugin::initialize()
{
    const int status = 0;
    return status;
}
// Reports the shape of the single output, which equals the first three
// extents of input 0.
//   index       - output index; asserted to be 0
//   inputs      - input shapes; only inputs[0] is read
//   nbInputDims - number of entries in `inputs`; asserted to be 1
// Side effect: stores the product of those three extents in input_size_,
// which forwardGpu() later uses as the per-batch element count.
Dims MishPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims)
{
    assert(nbInputDims == 1);
    assert(index == 0);

    const Dims& shape = inputs[0];
    input_size_ = shape.d[0] * shape.d[1] * shape.d[2];

    // The output has the same three extents as the input.
    return Dims3(shape.d[0], shape.d[1], shape.d[2]);
}
// Records the namespace this plugin belongs to.
// NOTE(review): if mPluginNamespace is a raw const char* (declared in mish.h,
// not visible here), only the pointer is stored and the caller's string must
// outlive this object - confirm.
void MishPlugin::setPluginNamespace(const char* pluginNamespace)
{
    this->mPluginNamespace = pluginNamespace;
}
// Returns the namespace previously stored by setPluginNamespace().
const char* MishPlugin::getPluginNamespace() const
{
    return this->mPluginNamespace;
}
// Data type of the plugin output at the requested index.
// Always DataType::kFLOAT; none of the arguments are consulted.
DataType MishPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
{
    const DataType outputType = DataType::kFLOAT;
    return outputType;
}
// Whether the output tensor is broadcast across a batch.
// Always false; none of the arguments are consulted.
bool MishPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const
{
    const bool outputIsBroadcast = false;
    return outputIsBroadcast;
}
// Whether the plugin can consume an input that is broadcast across the batch
// without replication. Always false; inputIndex is not consulted.
bool MishPlugin::canBroadcastInputAcrossBatch(int inputIndex) const
{
    const bool acceptsBroadcastInput = false;
    return acceptsBroadcastInput;
}
// Configuration hook taking the input/output tensor descriptions.
// No-op: the body is empty and none of the arguments are read.
void MishPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput)
{
}
// Attach the plugin object to an execution context and grant the plugin
// access to some context resources (cuDNN handle, cuBLAS handle, GPU allocator).
// No-op: the body is empty and none of the arguments are stored or used.
void MishPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
{
}
// Detach the plugin object from its execution context.
// No-op: attachToContext() keeps nothing, so there is nothing to release.
void MishPlugin::detachFromContext() {}
// Plugin type string: the same "Mish_TRT" value that
// MishPluginCreator::getPluginName() returns.
const char* MishPlugin::getPluginType() const
{
    static const char* const kPluginType = "Mish_TRT";
    return kPluginType;
}
// Plugin version string: the same "1" value that
// MishPluginCreator::getPluginVersion() returns.
const char* MishPlugin::getPluginVersion() const
{
    static const char* const kPluginVersion = "1";
    return kPluginVersion;
}
// Releases this plugin by deleting the object itself.
// The object must not be touched after this call returns.
void MishPlugin::destroy()
{
delete this;
}
// Creates a heap-allocated copy of this plugin.
// The copy is default-constructed, then receives this object's namespace and
// input_size_. The returned pointer is a new MishPlugin allocated with `new`.
IPluginV2IOExt* MishPlugin::clone() const
{
    MishPlugin* copy = new MishPlugin();
    copy->setPluginNamespace(mPluginNamespace);
    copy->input_size_ = input_size_;
    return copy;
}
// Device helper: hyperbolic tangent evaluated as 2 / (1 + e^(-2x)) - 1.
__device__ float tanh_activate_kernel(float x)
{
    const float denom = 1.0f + expf(-2.0f * x);
    return 2.0f / denom - 1.0f;
}
// Device helper: softplus, log(1 + e^x), with shortcuts outside +/-threshold.
//   x         - input value
//   threshold - cut-off (default 20) beyond which an approximation is used
// Returns x itself when x > threshold, e^x when x < -threshold, and the
// full log(e^x + 1) expression otherwise.
__device__ float softplus_kernel(float x, float threshold = 20)
{
    if (x > threshold) {
        return x;          // too large: log(1 + e^x) is approximated by x
    }
    if (x < -threshold) {
        return expf(x);    // too small: log(1 + e^x) is approximated by e^x
    }
    const float ex = expf(x);
    return logf(ex + 1.0f);
}
// Element-wise Mish activation: output[i] = input[i] * tanh(softplus(input[i])).
//   input    - source values
//   output   - destination values, written at the same index as read
//   num_elem - number of elements to process
// Each thread handles the one element at index
// blockIdx.x * blockDim.x + threadIdx.x; threads whose index is
// >= num_elem do nothing.
__global__ void mish_kernel(const float *input, float *output, int num_elem)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_elem) {
        const float v = input[i];
        output[i] = v * tanh_activate_kernel(softplus_kernel(v));
    }
}
// Launches mish_kernel over input_size_ * batchSize elements.
//   inputs    - array of input pointers; only inputs[0] is passed to the kernel
//   output    - destination pointer passed to the kernel
//   stream    - CUDA stream the kernel is enqueued on
//   batchSize - number of batch items; multiplies the per-item element count
// Uses thread_count_ threads per block and enough blocks to cover every
// element (ceiling division).
void MishPlugin::forwardGpu(const float *const * inputs, float* output, cudaStream_t stream, int batchSize) {
    const int num_elem = input_size_ * batchSize;
    // A grid of zero blocks is an invalid launch configuration, so there is
    // nothing to launch when there are no elements.
    if (num_elem <= 0) {
        return;
    }

    const int block_size = thread_count_;
    const int grid_size = (num_elem + block_size - 1) / block_size;

    // Launch on the caller-provided stream (4th launch parameter) so the work
    // is ordered with everything else enqueued on that stream, rather than
    // running on the default stream.
    mish_kernel<<<grid_size, block_size, 0, stream>>>(inputs[0], output, num_elem);
}
// Runs the Mish activation for one batch.
//   batchSize - number of batch items to process
//   inputs    - input pointers; treated as an array of const float*
//   outputs   - output pointers; outputs[0] is treated as float*
//   workspace - unused
//   stream    - CUDA stream forwarded to forwardGpu()
// Returns 0 on success, or 1 if the CUDA runtime reports an error after the
// launch (for example an invalid launch configuration).
int MishPlugin::enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
{
    forwardGpu(reinterpret_cast<const float* const*>(inputs),
               static_cast<float*>(outputs[0]),
               stream,
               batchSize);

    // A kernel launch does not return a status itself; launch failures are
    // only visible through the runtime's last-error state.
    const cudaError_t status = cudaGetLastError();
    return (status == cudaSuccess) ? 0 : 1;
}
// Definitions of MishPluginCreator's static members.
// mFC is value-initialized here; the MishPluginCreator constructor then points
// it at mPluginAttributes, and getFieldNames() returns its address.
PluginFieldCollection MishPluginCreator::mFC{};
std::vector<PluginField> MishPluginCreator::mPluginAttributes;
// Resets the shared attribute list to empty and points the static field
// collection mFC at it, so mFC ends up describing zero fields.
MishPluginCreator::MishPluginCreator()
{
    mPluginAttributes.clear();

    // The two assignments are independent; both read the now-empty vector.
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}
// Name of the plugin this creator produces: the same "Mish_TRT" value that
// MishPlugin::getPluginType() returns.
const char* MishPluginCreator::getPluginName() const
{
    static const char* const kPluginName = "Mish_TRT";
    return kPluginName;
}
// Version of the plugin this creator produces: the same "1" value that
// MishPlugin::getPluginVersion() returns.
const char* MishPluginCreator::getPluginVersion() const
{
    static const char* const kPluginVersion = "1";
    return kPluginVersion;
}
// Returns the address of the static field collection mFC, which the
// constructor fills from mPluginAttributes.
const PluginFieldCollection* MishPluginCreator::getFieldNames()
{
    const PluginFieldCollection* fieldNames = &mFC;
    return fieldNames;
}
// Creates a new default-constructed MishPlugin and tags it with this
// creator's namespace.
//   name - not used
//   fc   - not used
// Returns a MishPlugin allocated with `new`; MishPlugin::destroy() deletes it.
IPluginV2IOExt* MishPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
    MishPlugin* plugin = new MishPlugin();
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
// Recreates a MishPlugin from serialized bytes and tags it with this
// creator's namespace.
//   name         - not used
//   serialData   - bytes previously produced by MishPlugin::serialize()
//   serialLength - size of serialData in bytes
// Returns a MishPlugin allocated with `new`. It is released through
// MishPlugin::destroy(); per the original author's note, that call happens
// when the owning network is destroyed.
IPluginV2IOExt* MishPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength)
{
    MishPlugin* plugin = new MishPlugin(serialData, serialLength);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
}