// shard_map.go (forked from a8m/kinesis-producer)

package producer
import (
"context"
"crypto/md5"
"math/big"
"sort"
"sync"
k "github.com/aws/aws-sdk-go-v2/service/kinesis"
"github.com/aws/aws-sdk-go-v2/service/kinesis/types"
)
// maxHashKeyRange is the largest possible hash key. Hash keys are 0-indexed
// and span [0, 2^128), so the maximum value is 2^128 - 1.
var maxHashKeyRange = "340282366920938463463374607431768211455"
// ShardLister is the interface that wraps the KinesisAPI.ListShards method.
type ShardLister interface {
ListShards(ctx context.Context, params *k.ListShardsInput, optFns ...func(*k.Options)) (*k.ListShardsOutput, error)
}
// GetKinesisShardsFunc returns a GetShardsFunc that gets the list of currently open shards from the Kinesis ListShards API.
func GetKinesisShardsFunc(client ShardLister, streamName string) GetShardsFunc {
return func(old []types.Shard) ([]types.Shard, bool, error) {
var (
shards []types.Shard
next *string
)
for {
input := &k.ListShardsInput{}
if next != nil {
input.NextToken = next
} else {
input.StreamName = &streamName
}
resp, err := client.ListShards(context.Background(), input)
if err != nil {
return nil, false, err
}
for _, shard := range resp.Shards {
// There may be many shards with overlapping HashKeyRanges due to prior merge and
// split operations. The currently open shards are the ones that do not have a
// SequenceNumberRange.EndingSequenceNumber.
if shard.SequenceNumberRange.EndingSequenceNumber == nil {
shards = append(shards, shard)
}
}
next = resp.NextToken
if next == nil {
break
}
}
sort.Sort(ShardSlice(shards))
if shardsEqual(old, shards) {
return nil, false, nil
}
return shards, true, nil
}
}
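
// exampleListOpenShards is an illustrative sketch (not part of the original file)
// showing how GetKinesisShardsFunc might be wired up. The stream name "my-stream"
// is a placeholder assumption; in practice client would be a *k.Client from the
// AWS SDK v2, which satisfies ShardLister.
func exampleListOpenShards(client ShardLister) ([]types.Shard, error) {
	getShards := GetKinesisShardsFunc(client, "my-stream")
	// Passing nil means "no previously known shards"; the bool reports whether the
	// returned list differs from the old one.
	shards, _, err := getShards(nil)
	return shards, err
}
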
// StaticGetShardsFunc returns a GetShardsFunc that, when called, generates a static
// list of count shards whose HashKeyRanges are evenly distributed across the hash key space.
func StaticGetShardsFunc(count int) GetShardsFunc {
return func(old []types.Shard) ([]types.Shard, bool, error) {
if count == 0 {
return nil, false, nil
}
step := big.NewInt(int64(0))
step, _ = step.SetString(maxHashKeyRange, 10)
bCount := big.NewInt(int64(count))
step = step.Div(step, bCount)
b1 := big.NewInt(int64(1))
shards := make([]types.Shard, count)
key := big.NewInt(int64(0))
for i := 0; i < count; i++ {
bI := big.NewInt(int64(i))
// starting key range (step * i)
key = key.Mul(bI, step)
startingHashKey := key.String()
// ending key range ((step * (i + 1)) - 1)
bINext := big.NewInt(int64(i + 1))
key = key.Mul(bINext, step)
key = key.Sub(key, b1)
endingHashKey := key.String()
shards[i].HashKeyRange = &types.HashKeyRange{
StartingHashKey: &startingHashKey,
EndingHashKey: &endingHashKey,
}
}
// Set last shard end range to max to account for small rounding errors
shards[len(shards)-1].HashKeyRange.EndingHashKey = &maxHashKeyRange
return shards, false, nil
}
}
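
// exampleStaticShards is an illustrative sketch (not part of the original file).
// With count = 4, each shard covers roughly one quarter of the [0, 2^128-1] hash
// key space, and the last shard's EndingHashKey is pinned to the maximum to absorb
// integer-division rounding.
func exampleStaticShards() []types.Shard {
	shards, _, _ := StaticGetShardsFunc(4)(nil)
	return shards
}
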
// ShardSlice implements sort.Interface, ordering shards by StartingHashKey ascending.
type ShardSlice []types.Shard
func (p ShardSlice) Len() int { return len(p) }
func (p ShardSlice) Less(i, j int) bool {
a, _ := new(big.Int).SetString(*p[i].HashKeyRange.StartingHashKey, 10)
b, _ := new(big.Int).SetString(*p[j].HashKeyRange.StartingHashKey, 10)
// a < b
return a.Cmp(b) == -1
}
func (p ShardSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// shardsEqual checks whether two shard lists have the same hash key ranges.
func shardsEqual(a, b []types.Shard) bool {
if len(a) != len(b) {
return false
}
for i, ashard := range a {
bshard := b[i]
if *ashard.HashKeyRange.StartingHashKey != *bshard.HashKeyRange.StartingHashKey ||
*ashard.HashKeyRange.EndingHashKey != *bshard.HashKeyRange.EndingHashKey {
return false
}
}
return true
}
// ShardMap maps UserRecords to per-shard aggregators based on their hash key.
type ShardMap struct {
sync.RWMutex
shards []types.Shard
aggregators []*Aggregator
// aggregateBatchCount determines the maximum number of items to pack into an aggregated record.
aggregateBatchCount int
}
// NewShardMap initializes an aggregator for each shard.
// UserRecords that map to the same shard, based on the MD5 hash of their partition
// key (the same method Kinesis uses), will be aggregated together. Aggregators use an
// ExplicitHashKey from their assigned shard when creating a types.PutRecordsRequestEntry.
// A ShardMap with an empty shards slice falls back to unsharded behavior with a single
// aggregator. That aggregator instead uses the PartitionKey of its first UserRecord and
// no ExplicitHashKey.
func NewShardMap(shards []types.Shard, aggregateBatchCount int) *ShardMap {
return &ShardMap{
shards: shards,
aggregators: makeAggregators(shards),
aggregateBatchCount: aggregateBatchCount,
}
}
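
// exampleNewShardMap is an illustrative sketch (not part of the original file).
// Passing a shard list yields one aggregator per shard; passing nil (or an empty
// slice) falls back to a single aggregator with no ExplicitHashKey. The aggregate
// batch count of 100 is an arbitrary assumption.
func exampleNewShardMap(shards []types.Shard) (sharded, unsharded *ShardMap) {
	sharded = NewShardMap(shards, 100)
	unsharded = NewShardMap(nil, 100)
	return sharded, unsharded
}
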
// Put puts a UserRecord into the aggregator that maps to its partition key.
func (m *ShardMap) Put(userRecord UserRecord) (*AggregatedRecordRequest, error) {
m.RLock()
drained, err := m.put(userRecord)
// Not using defer to avoid runtime overhead
m.RUnlock()
return drained, err
}
// Size returns the number of bytes stored across all the aggregators,
// including partition keys.
func (m *ShardMap) Size() int {
m.RLock()
size := 0
for _, a := range m.aggregators {
a.RLock()
size += a.Size()
a.RUnlock()
}
m.RUnlock()
return size
}
// Drain drains all the aggregators and returns a list of the results
func (m *ShardMap) Drain() ([]*AggregatedRecordRequest, []error) {
m.RLock()
var (
requests []*AggregatedRecordRequest
errs []error
)
for _, a := range m.aggregators {
a.Lock()
req, err := a.Drain()
a.Unlock()
if err != nil {
errs = append(errs, err)
} else if req != nil {
requests = append(requests, req)
}
}
m.RUnlock()
return requests, errs
}
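
// exampleBuffer is an illustrative sketch (not part of the original file) of the
// Put/Drain flow: requests drained because an aggregator filled up are collected
// as they appear, and whatever is still buffered is drained at the end.
func exampleBuffer(m *ShardMap, records []UserRecord) ([]*AggregatedRecordRequest, []error) {
	var ready []*AggregatedRecordRequest
	for _, r := range records {
		drained, err := m.Put(r)
		if err != nil {
			return ready, []error{err}
		}
		if drained != nil {
			ready = append(ready, drained)
		}
	}
	remaining, errs := m.Drain()
	return append(ready, remaining...), errs
}
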
// Shards returns the list of shards
func (m *ShardMap) Shards() []types.Shard {
m.RLock()
shards := m.shards
m.RUnlock()
return shards
}
// UpdateShards updates the list of shards and redistributes buffered user records.
// It returns any records that were drained due to redistribution.
// Shards are not updated if an error occurs during redistribution.
// TODO: Can we optimize this?
// TODO: How to handle shard splitting? If a shard splits but we don't remap before sending
// records to the new shards, then once we do update our mapping, user records may end up
// in a new shard and we would lose the shard ordering. A consumer can probably figure
// it out since we retain original partition keys (but not explicit hash keys).
// Shard merging should not be an issue since records from both shards should fall
// into the merged hash key range.
func (m *ShardMap) UpdateShards(shards []types.Shard, pendingRecords []*AggregatedRecordRequest) ([]*AggregatedRecordRequest, error) {
m.Lock()
defer m.Unlock()
update := NewShardMap(shards, m.aggregateBatchCount)
var drained []*AggregatedRecordRequest
// first put any pending UserRecords from inflight requests
for _, record := range pendingRecords {
for _, userRecord := range record.UserRecords {
req, err := update.put(userRecord)
if err != nil {
// if we encounter an error trying to redistribute the records, return the pending
// records so the Producer can try to send them again. They won't be redistributed
// across new shards, but at least they won't be lost.
return pendingRecords, err
}
if req != nil {
drained = append(drained, req)
}
}
}
// then redistribute the records still being aggregated
for _, agg := range m.aggregators {
// We don't need to get the aggregator lock because we have the shard map write lock
for _, userRecord := range agg.buf {
req, err := update.put(userRecord)
if err != nil {
return pendingRecords, err
}
if req != nil {
drained = append(drained, req)
}
}
}
// Only update m if we successfully redistributed all the user records
m.shards = update.shards
m.aggregators = update.aggregators
return drained, nil
}
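
// exampleRefreshShards is an illustrative sketch (not part of the original file)
// of how a caller might combine a GetShardsFunc with UpdateShards: buffered records
// are redistributed only when the shard list actually changed, and anything drained
// during redistribution is returned so the caller can send it.
func exampleRefreshShards(m *ShardMap, getShards GetShardsFunc, pending []*AggregatedRecordRequest) ([]*AggregatedRecordRequest, error) {
	shards, changed, err := getShards(m.Shards())
	if err != nil || !changed {
		return nil, err
	}
	return m.UpdateShards(shards, pending)
}
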
// put stores a UserRecord in the aggregator that maps to its partition key.
// Not thread safe; acquire the lock before calling.
func (m *ShardMap) put(userRecord UserRecord) (*AggregatedRecordRequest, error) {
bucket := m.bucket(userRecord)
if bucket == -1 {
return nil, &ShardBucketError{UserRecord: userRecord}
}
a := m.aggregators[bucket]
a.Lock()
var (
needToDrain = a.WillOverflow(userRecord) || a.Count() >= m.aggregateBatchCount
drained *AggregatedRecordRequest
err error
)
if needToDrain {
drained, err = a.Drain()
}
a.Put(userRecord)
a.Unlock()
return drained, err
}
// bucket returns the index of the shard that the given user record's hash key maps to.
// Returns -1 if the hash key is outside the shard range.
// Assumes shards are ordered by contiguous HashKeyRange, ascending. If there are gaps in
// the shard hash key ranges and the hash key falls into one of the gaps, the record is
// placed in the shard with the larger StartingHashKey.
// Not thread safe; acquire the lock before calling.
// TODO: Can we optimize this? Cache for pk -> bucket?
func (m *ShardMap) bucket(userRecord UserRecord) int {
if len(m.shards) == 0 {
return 0
}
hk := userRecord.ExplicitHashKey()
if hk == nil {
hk = hashKey(userRecord.PartitionKey())
}
sortFunc := func(i int) bool {
shard := m.shards[i]
end := big.NewInt(int64(0))
end, _ = end.SetString(*shard.HashKeyRange.EndingHashKey, 10)
// end >= hk
return end.Cmp(hk) > -1
}
// Search uses binary search to find and return the smallest index i in [0, n)
// at which f(i) is true
// See https://golang.org/pkg/sort/#Search
bucket := sort.Search(len(m.shards), sortFunc)
if bucket == len(m.shards) {
return -1
}
return bucket
}
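
// exampleBucketing is an illustrative sketch (not part of the original file)
// mirroring the lookup that bucket performs: a binary search for the first shard
// (in ascending StartingHashKey order) whose EndingHashKey is >= the record's
// hash key. A result equal to len(shards) corresponds to bucket's -1 (the hash
// key is beyond every shard's range).
func exampleBucketing(shards []types.Shard, hk *big.Int) int {
	return sort.Search(len(shards), func(i int) bool {
		end, _ := new(big.Int).SetString(*shards[i].HashKeyRange.EndingHashKey, 10)
		return end.Cmp(hk) >= 0
	})
}
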
// hashKey calculates a new explicit hash key based on the given partition key
// (following the algorithm from the original KPL).
// Copied from: https://github.com/a8m/kinesis-producer/issues/1#issuecomment-524620994
func hashKey(pk string) *big.Int {
h := md5.New()
h.Write([]byte(pk))
sum := h.Sum(nil)
hk := big.NewInt(int64(0))
for i := 0; i < md5.Size; i++ {
p := big.NewInt(int64(sum[i]))
p = p.Lsh(p, uint((16-i-1)*8))
hk = hk.Add(hk, p)
}
return hk
}
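
// exampleHashKeyCheck is an illustrative sketch (not part of the original file).
// It confirms that hashKey is the big-endian interpretation of the 16-byte MD5
// digest of the partition key, i.e. equivalent to big.Int.SetBytes over the digest.
func exampleHashKeyCheck(pk string) bool {
	sum := md5.Sum([]byte(pk))
	return new(big.Int).SetBytes(sum[:]).Cmp(hashKey(pk)) == 0
}
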
func makeAggregators(shards []types.Shard) []*Aggregator {
count := len(shards)
if count == 0 {
return []*Aggregator{NewAggregator(nil)}
}
aggregators := make([]*Aggregator, count)
for i := 0; i < count; i++ {
shard := shards[i]
// Is using the StartingHashKey sufficient?
aggregators[i] = NewAggregator(shard.HashKeyRange.StartingHashKey)
}
return aggregators
}