forked from facebookresearch/FBTT-Embedding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hashtbl_cuda_utils.cuh
154 lines (134 loc) · 3.92 KB
/
hashtbl_cuda_utils.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
/*
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/TensorAccessor.h>
#include <thrust/pair.h>
#include <cassert>
#include <iostream>
#include <iterator>
#include <limits>
#include <type_traits>
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
// Multiplicative constants of the MurmurHash3 32-bit block mix, consumed by
// the murmor_hash_3_32 overloads in this header.
constexpr uint32_t c1 = 0xcc9e2d51u;
constexpr uint32_t c2 = 0x1b873593u;
// Qualifier for small device-only helpers that must always be inlined.
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
// 64-bit signed compare-and-swap.
//
// CUDA's atomicCAS has no signed 64-bit overload, so the operands are
// reinterpreted bit-for-bit as unsigned long long, which has one.
// Atomically: if *p == compare, store val into *p.
// Returns the value *p held before the operation (whether or not it swapped).
DEVICE_INLINE int64_t gpuAtomicCAS(int64_t* p, int64_t compare, int64_t val) {
  using ull_t = unsigned long long int;
  static_assert(
      sizeof(int64_t) == sizeof(ull_t),
      "expected int64_t to be unsigned long long");
  ull_t* const address = reinterpret_cast<ull_t*>(p);
  const ull_t expected = static_cast<ull_t>(compare);
  const ull_t desired = static_cast<ull_t>(val);
  const ull_t previous = atomicCAS(address, expected, desired);
  return static_cast<int64_t>(previous);
}
// 32-bit signed compare-and-swap.
//
// Forwards directly to CUDA's native int overload of atomicCAS; this relies on
// int32_t being int on the target.
// Returns the value *p held before the operation.
DEVICE_INLINE int32_t gpuAtomicCAS(int32_t* p, int32_t compare, int32_t val) {
  const int32_t previous = atomicCAS(p, compare, val);
  return previous;
}
// Rotates x left by r bits (r is taken modulo 32).
//
// The unmasked form `(x << r) | (x >> (32 - r))` shifts by 32 when r == 0,
// which is undefined behavior for a 32-bit operand; any r outside [1, 31] is
// likewise UB there. Masking both shift counts into [0, 31] makes every r
// well-defined, with r == 0 returning x unchanged. For r in [1, 31] the result
// is identical to the unmasked form.
__host__ DEVICE_INLINE uint32_t rotl32(uint32_t x, int8_t r) {
  const uint32_t s = static_cast<uint32_t>(r) & 31u;
  return (x << s) | (x >> ((32u - s) & 31u));
}
// Hashes a 64-bit key with a MurmurHash3 (x86_32-style) mix and maps the
// 32-bit hash onto [0, C) using multiply-shift instead of a modulo.
//
// h_in: key to hash; its low 32-bit word is mixed first, then the high word.
// C:    number of buckets; expected to be > 0. A negative C is sign-extended
//       to a huge uint64_t and the result is then not a valid bucket index.
// Returns a bucket index in [0, C) when C > 0.
//
// NOTE(review): the seed is fixed at 0, and the finalizer XORs in 2 (the
// number of 32-bit words). Reference MurmurHash3_x86_32 XORs the key length
// in bytes (8), so outputs differ from reference MurmurHash3. This is kept
// as-is because changing it would change every bucket assignment.
__host__ DEVICE_INLINE uint32_t murmor_hash_3_32(int64_t h_in, int32_t C) {
  // Extract the two words arithmetically instead of reading the int64_t
  // through a uint32_t*. That aliasing violates C++ strict aliasing, and this
  // function is also compiled for the host (__host__). On little-endian
  // targets, including NVIDIA GPUs, words[0] and words[1] equal the former
  // ptr[0] and ptr[1], so hash values are unchanged. The result no longer
  // depends on host endianness.
  const uint64_t bits = static_cast<uint64_t>(h_in);
  const uint32_t words[2] = {
      static_cast<uint32_t>(bits), static_cast<uint32_t>(bits >> 32)};

  uint32_t h = 0;  // seed
  for (int i = 0; i < 2; ++i) {
    uint32_t k = words[i];
    k *= c1;
    k = rotl32(k, 15);
    k *= c2;
    h ^= k;
    h = rotl32(h, 13);
    h = h * 5 + 0xe6546b64;
  }
  h ^= 2;

  // MurmurHash3 32-bit finalization mix (fmix32).
  h ^= h >> 16;
  h *= 0x85ebca6b;
  h ^= h >> 13;
  h *= 0xc2b2ae35;
  h ^= h >> 16;

  // Multiply-shift range reduction:
  // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
  return static_cast<uint32_t>(
      (static_cast<uint64_t>(h) * static_cast<uint64_t>(C)) >> 32);
}
// Hashes a 32-bit key with a single MurmurHash3 (x86_32-style) block mix and
// maps the 32-bit hash onto [0, C) using multiply-shift instead of a modulo.
//
// h_in: key to hash (its bit pattern is used as an unsigned word).
// C:    number of buckets; expected to be > 0.
// Returns a bucket index in [0, C) when C > 0.
//
// NOTE(review): the seed is fixed at 0, and the finalizer XORs in 1 (the
// number of 32-bit words) rather than the byte length (4) used by reference
// MurmurHash3_x86_32, so outputs differ from reference MurmurHash3.
__host__ DEVICE_INLINE uint32_t murmor_hash_3_32(int32_t h_in, int32_t C) {
  // Block mix of the only word; with a zero seed, `0 ^ block == block`.
  uint32_t block = static_cast<uint32_t>(h_in);
  block = rotl32(block * c1, 15) * c2;
  uint32_t hash = rotl32(block, 13) * 5u + 0xe6546b64u;

  // Length step followed by the MurmurHash3 32-bit finalization mix (fmix32).
  hash ^= 1u;
  hash ^= hash >> 16;
  hash *= 0x85ebca6bu;
  hash ^= hash >> 13;
  hash *= 0xc2b2ae35u;
  hash ^= hash >> 16;

  // Multiply-shift range reduction:
  // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
  const uint64_t scaled =
      static_cast<uint64_t>(hash) * static_cast<uint64_t>(C);
  return static_cast<uint32_t>(scaled >> 32);
}
// Sentinel key marking an empty slot in hashtbl_keys. A key equal to it cannot
// be stored, because after insertion its slot would still read as empty.
// Presumably callers fill hashtbl_keys with this value before inserting.
#define UNUSED_KEY (-1)
// Inserts insert_key into an open-addressing table that uses linear probing.
//
// Probing starts at the key's hashed slot and visits at most max_probes
// consecutive slots, wrapping at hashtbl_size. Each probe tries to claim its
// slot with an atomic compare-and-swap from UNUSED_KEY to insert_key.
//   accumulate == true:  if the slot was empty or already holds insert_key,
//                        insert_value is atomically added to the slot's value
//                        (presumably hashtbl_values starts zeroed - confirm).
//   accumulate == false: if the slot was empty, insert_value is stored there;
//                        if it already holds insert_key, the existing value is
//                        left untouched.
// Returns the slot index on success, or -1 if no empty or matching slot was
// found within max_probes probes.
//
// Assumed preconditions (not checked): hashtbl_size > 0; hashtbl_keys and
// hashtbl_values each hold hashtbl_size elements; empty slots hold
// UNUSED_KEY; insert_key != UNUSED_KEY.
// NOTE(review): gpuAtomicAdd is not declared in this header. It presumably
// comes from ATen's CUDA atomics - confirm the include chain.
// NOTE(review): in the non-accumulate path the value store is not fenced
// against the key CAS, so a concurrent reader could see the key before its
// value. Presumably inserts and lookups run in separate kernels - confirm.
template <typename key_type, typename value_type, bool accumulate>
__forceinline__ __device__ int32_t hashtbl_insert(
    key_type insert_key,
    value_type insert_value,
    int32_t hashtbl_size,
    int32_t max_probes,
    key_type* __restrict__ hashtbl_keys,
    value_type* __restrict__ hashtbl_values) {
  int32_t slot = murmor_hash_3_32(insert_key, hashtbl_size);
  for (int32_t probe = 0; probe < max_probes; ++probe) {
    // Claim the slot if it is empty; either way, learn which key now owns it.
    const key_type owner =
        gpuAtomicCAS(&hashtbl_keys[slot], UNUSED_KEY, insert_key);
    const bool claimed = (owner == UNUSED_KEY);
    const bool present = (owner == insert_key);

    if (accumulate) {
      if (claimed || present) {
        gpuAtomicAdd(&hashtbl_values[slot], insert_value);
        return slot;
      }
    } else {
      if (claimed) {
        hashtbl_values[slot] = insert_value;
        return slot;
      }
      if (present) {
        return slot;
      }
    }

    // Slot held by another key: advance to the neighbour (linear probing).
    slot = (slot + 1) % hashtbl_size;
  }
  return -1;
}
// Looks up key in an open-addressing table filled by hashtbl_insert.
//
// Probing starts at the key's hashed slot and visits at most max_probes
// consecutive slots, wrapping at hashtbl_size. It stops at the first slot that
// holds key, or at the first empty (UNUSED_KEY) slot.
// Returns the index of the slot holding key, or -1 if key is not found, if
// key is the UNUSED_KEY sentinel itself, or if hashtbl_size <= 0.
//
// Stopping at an empty slot is valid because hashtbl_insert claims the first
// empty slot along the same probe sequence, and (assuming entries are never
// removed - no delete path is visible here) slots only change from empty to
// occupied. A key therefore can never sit beyond an empty slot on its path.
// Assumed precondition: hashtbl_keys holds hashtbl_size elements.
template <typename key_type>
__forceinline__ __device__ int32_t hashtbl_find(
    key_type key,
    int32_t hashtbl_size,
    int32_t max_probes,
    const key_type* __restrict__ hashtbl_keys) {
  // The sentinel is never a stored key; without this check it would "match"
  // the first empty slot on its probe path. An empty table would otherwise
  // read hashtbl_keys[0] out of bounds and evaluate `% 0` below.
  if (key == UNUSED_KEY || hashtbl_size <= 0) {
    return -1;
  }
  int32_t hashtbl_idx = murmor_hash_3_32(key, hashtbl_size);
  for (int32_t probe = 0; probe < max_probes; ++probe) {
    const key_type slot_key = hashtbl_keys[hashtbl_idx];
    if (slot_key == key) {
      return hashtbl_idx;
    }
    // Test the probed slot, not the search key. The original compared `key`
    // with UNUSED_KEY, so a miss never ended early and always scanned all
    // max_probes slots.
    if (slot_key == UNUSED_KEY) {
      return -1;
    }
    // linear probe
    hashtbl_idx = (hashtbl_idx + 1) % hashtbl_size;
  }
  return -1;
}