forked from secretflow/spu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bit_utils.h
164 lines (143 loc) · 4.95 KB
/
bit_utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright 2021 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "absl/numeric/bits.h"
#include "yacl/base/int128.h"
namespace spu {
// Returns floor(log2(n)), i.e. the index of the most significant set bit of
// `n`. Both 0 and 1 map to 0.
inline constexpr int Log2Floor(uint64_t n) {
  if (n > 1) {
    // bit_width(n) == 64 - countl_zero(n), so this equals 63 - countl_zero(n).
    return absl::bit_width(n) - 1;
  }
  return 0;
}
// Returns ceil(log2(n)): the number of bits needed to index `n` distinct
// values. Both 0 and 1 map to 0.
inline constexpr int Log2Ceil(uint64_t n) {
  if (n > 1) {
    // bit_width(n - 1) == 64 - countl_zero(n - 1).
    return absl::bit_width(n - 1);
  }
  return 0;
}
// Returns the number of bits needed to represent `v`: one plus the index of
// its most significant set bit, or 0 when `v` is zero.
//
// 16-byte types are split into high/low 64-bit halves with
// yacl::DecomposeUInt128 and measured half by half; every other type is passed
// straight to absl::bit_width.
// TODO: move to constexpr when yacl is ready.
template <typename T>
size_t BitWidth(const T& v) {
  if constexpr (sizeof(T) != 16) {
    return absl::bit_width(v);
  } else {
    const auto [high, low] = yacl::DecomposeUInt128(v);
    if (high == 0) {
      // Everything fits in the low word.
      return absl::bit_width(low);
    }
    // The whole low word counts, plus the significant bits of the high word.
    return absl::bit_width(high) + 64;
  }
}
// Implementation details shared by BitDeintl / BitIntl below.
namespace detail {
// Out-of-line 64-bit de-interleave / interleave kernels; their definitions are
// not in this header. Judging only by the name they presumably use the BMI2
// pdep/pext instructions (TODO confirm). In this header they are referenced
// only from the commented-out fast paths in BitDeintl / BitIntl.
uint64_t BitDeintlWithPdepext(uint64_t in, int64_t stride);
uint64_t BitIntlWithPdepext(uint64_t in, int64_t stride);
// Per-level masks for the log(n) swap network in BitDeintl / BitIntl.
//
// At level `i` (the array index) the shift is S = 1 << i, and each step
// computes
//   r = (r & Keep[i]) ^ ((r >> S) & Swap[i]) ^ ((r & Swap[i]) << S)
// i.e. the bits selected by Swap[i] trade places with the bits selected by
// (Swap[i] << S), while the bits in Keep[i] stay put. For every level the sets
// Keep[i], Swap[i] and (Swap[i] << S) are disjoint and together cover all 128
// bits. Within each group the pattern is: Swap[i] is the second quarter and
// (Swap[i] << S) is the third quarter, so each step exchanges the two middle
// quarters.
//
// The trailing "// Nbit" comment is the period (group size) of that level's
// bit pattern, i.e. 1 << (i + 2).
inline constexpr std::array<uint128_t, 6> kBitIntlSwapMasks = {{
yacl::MakeUint128(0x2222222222222222, 0x2222222222222222), // 4bit
yacl::MakeUint128(0x0C0C0C0C0C0C0C0C, 0x0C0C0C0C0C0C0C0C), // 8bit
yacl::MakeUint128(0x00F000F000F000F0, 0x00F000F000F000F0), // 16bit
yacl::MakeUint128(0x0000FF000000FF00, 0x0000FF000000FF00), // 32bit
yacl::MakeUint128(0x00000000FFFF0000, 0x00000000FFFF0000), // 64bit
yacl::MakeUint128(0x0000000000000000, 0xFFFFFFFF00000000), // 128bit
}};
// Complement of (Swap[i] | (Swap[i] << (1 << i))): the bits left in place at
// each level. Same indexing and period comments as kBitIntlSwapMasks.
inline constexpr std::array<uint128_t, 6> kBitIntlKeepMasks = {{
yacl::MakeUint128(0x9999999999999999, 0x9999999999999999), // 4bit
yacl::MakeUint128(0xC3C3C3C3C3C3C3C3, 0xC3C3C3C3C3C3C3C3), // 8bit
yacl::MakeUint128(0xF00FF00FF00FF00F, 0xF00FF00FF00FF00F), // 16bit
yacl::MakeUint128(0xFF0000FFFF0000FF, 0xFF0000FFFF0000FF), // 32bit
yacl::MakeUint128(0xFFFF00000000FFFF, 0xFFFF00000000FFFF), // 64bit
yacl::MakeUint128(0xFFFFFFFF00000000, 0x00000000FFFFFFFF), // 128bit
}};
} // namespace detail
// Bit de-interleave function; the inverse of BitIntl.
//
// Gathers the bits at even positions into the lower half of the result and
// the bits at odd positions into the upper half (bit strings below are written
// most-significant bit first):
//
//   aXbYcZdW -> abcdXYZW
//
// `stride` is the log2 of the interleaving granularity: the input is treated
// as alternating groups of (1 << stride) bits, which move as indivisible units.
//
//   01010101 -> 00001111   stride = 0
//   00110011 -> 00001111   stride = 1
//   00001111 -> 00001111   stride = 2
//
// `nbits` is the width being de-interleaved and defaults to the bit width of
// T. It is rounded up to a power of two, and the permutation is applied
// independently to every aligned group of that many bits of `in`.
//
// Out-of-range arguments are clamped rather than indexing past the
// detail::kBitIntl*Masks tables or shifting by the width of T: a negative
// `stride` behaves like 0, and an `nbits` wider than T (or a negative value
// other than -1) behaves like the width of T.
template <typename T, std::enable_if_t<std::is_unsigned_v<T>, bool> = true>
T BitDeintl(T in, int64_t stride, int64_t nbits = -1) {
  constexpr int64_t kTypeBits = static_cast<int64_t>(sizeof(T) * 8);
  if (nbits == -1) {
    nbits = kTypeBits;
  }
  // TODO:
  // 1. handle nbits
  // 2. enable this when benchmark test passed.
  // if constexpr (std::is_same_v<T, uint64_t>) {
  //   return detail::BitDeintlWithPdepext(in, stride);
  // }

  // The general log(n) algorithm: each level swaps the two middle quarters of
  // every group, and the group size doubles from one level to the next.
  //        0101010101010101
  //  swap   ^^  ^^  ^^  ^^
  //        0011001100110011
  //  swap    ^^^^    ^^^^
  //        0000111100001111
  //  swap      ^^^^^^^^
  //        0000000011111111
  static_assert(detail::kBitIntlKeepMasks.size() ==
                detail::kBitIntlSwapMasks.size());
  // Highest level that stays inside the mask tables and whose shift is
  // narrower than T.
  const int64_t max_level = std::min<int64_t>(
      Log2Ceil(kTypeBits) - 2,
      static_cast<int64_t>(detail::kBitIntlKeepMasks.size()) - 1);
  // For valid arguments (stride >= 0, nbits <= width of T) this is exactly the
  // range [stride, Log2Ceil(nbits) - 2].
  const int64_t first_level = std::max<int64_t>(stride, 0);
  const int64_t last_level = std::min<int64_t>(Log2Ceil(nbits) - 2, max_level);

  T r = in;
  for (int64_t level = first_level; level <= last_level; ++level) {
    const T K = static_cast<T>(detail::kBitIntlKeepMasks[level]);
    const T M = static_cast<T>(detail::kBitIntlSwapMasks[level]);
    const int S = 1 << level;
    r = (r & K) ^ ((r >> S) & M) ^ ((r & M) << S);
  }
  return r;
}
// Bit interleave function; the inverse of BitDeintl. Also called a Morton
// number (an outer perfect shuffle of the bits).
//
// Interleaves the two halves of `in`: the bits of the lower half land on the
// even positions and the bits of the upper half on the odd positions (bit
// strings below are written most-significant bit first):
//
//   abcdXYZW -> aXbYcZdW
//
// `stride` is the log2 of the interleaving granularity: the halves are
// interleaved in groups of (1 << stride) bits.
//
//   00001111 -> 01010101   stride = 0
//   00001111 -> 00110011   stride = 1
//   00001111 -> 00001111   stride = 2
//
// `nbits` is the width being interleaved and defaults to the bit width of T.
// It is rounded up to a power of two, and the permutation is applied
// independently to every aligned group of that many bits of `in`.
//
// Out-of-range arguments are clamped rather than indexing past the
// detail::kBitIntl*Masks tables or shifting by the width of T: a negative
// `stride` behaves like 0, and an `nbits` wider than T (or a negative value
// other than -1) behaves like the width of T.
template <typename T, std::enable_if_t<std::is_unsigned_v<T>, bool> = true>
T BitIntl(T in, int64_t stride, int64_t nbits = -1) {
  constexpr int64_t kTypeBits = static_cast<int64_t>(sizeof(T) * 8);
  if (nbits == -1) {
    nbits = kTypeBits;
  }
  // TODO: fast path for intrinsic.
  // 1. handle nbits
  // 2. enable this when benchmark test passed.
  // if constexpr (std::is_same_v<T, uint64_t>) {
  //   return detail::BitIntlWithPdepext(in, stride);
  // }

  // The general log(n) algorithm: each level swaps the two middle quarters of
  // every group, and the group size halves from one level to the next.
  //        0000000011111111
  //  swap      ^^^^^^^^
  //        0000111100001111
  //  swap    ^^^^    ^^^^
  //        0011001100110011
  //  swap   ^^  ^^  ^^  ^^
  //        0101010101010101
  static_assert(detail::kBitIntlKeepMasks.size() ==
                detail::kBitIntlSwapMasks.size());
  // Highest level that stays inside the mask tables and whose shift is
  // narrower than T.
  const int64_t max_level = std::min<int64_t>(
      Log2Ceil(kTypeBits) - 2,
      static_cast<int64_t>(detail::kBitIntlKeepMasks.size()) - 1);
  // For valid arguments (stride >= 0, nbits <= width of T) this is exactly the
  // range [stride, Log2Ceil(nbits) - 2], walked from the top down.
  const int64_t first_level = std::max<int64_t>(stride, 0);
  const int64_t last_level = std::min<int64_t>(Log2Ceil(nbits) - 2, max_level);

  T r = in;
  for (int64_t level = last_level; level >= first_level; --level) {
    const T K = static_cast<T>(detail::kBitIntlKeepMasks[level]);
    const T M = static_cast<T>(detail::kBitIntlSwapMasks[level]);
    const int S = 1 << level;
    r = (r & K) ^ ((r >> S) & M) ^ ((r & M) << S);
  }
  return r;
}
} // namespace spu