Skip to content

Commit 987e665

Browse files
committed
Add flush method to writer
1 parent b2b17ac commit 987e665

File tree

2 files changed

+117
-12
lines changed

2 files changed

+117
-12
lines changed

Diff for: writer.go

+47-6
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,41 @@ func (w *Writer) spawn(f func()) {
548548
}()
549549
}
550550

551+
// Flush writes all currently buffered messages to the kafka cluster. This will
552+
// block until all messages in the batch has been written to kafka, or until the
553+
// context is canceled.
554+
func (w *Writer) Flush(ctx context.Context) error {
555+
w.mutex.Lock()
556+
557+
var wg sync.WaitGroup
558+
559+
// flush all writers
560+
for _, writer := range w.writers {
561+
w := writer
562+
wg.Add(1)
563+
go func() {
564+
b := w.flush()
565+
<-b.done
566+
wg.Done()
567+
}()
568+
}
569+
570+
w.mutex.Unlock()
571+
done := make(chan struct{})
572+
573+
go func() {
574+
wg.Wait()
575+
close(done)
576+
}()
577+
578+
select {
579+
case <-done:
580+
return nil
581+
case <-ctx.Done():
582+
return ctx.Err()
583+
}
584+
}
585+
551586
// Close flushes pending writes, and waits for all writes to complete before
552587
// returning. Calling Close also prevents new writes from being submitted to
553588
// the writer, further calls to WriteMessages and the like will fail with
@@ -1184,17 +1219,23 @@ func (ptw *partitionWriter) writeBatch(batch *writeBatch) {
11841219
batch.complete(err)
11851220
}
11861221

1187-
func (ptw *partitionWriter) close() {
1222+
func (ptw *partitionWriter) flush() *writeBatch {
11881223
ptw.mutex.Lock()
11891224
defer ptw.mutex.Unlock()
11901225

1191-
if ptw.currBatch != nil {
1192-
batch := ptw.currBatch
1193-
ptw.queue.Put(batch)
1194-
ptw.currBatch = nil
1195-
batch.trigger()
1226+
if ptw.currBatch == nil {
1227+
return nil
11961228
}
11971229

1230+
batch := ptw.currBatch
1231+
ptw.queue.Put(batch)
1232+
ptw.currBatch = nil
1233+
batch.trigger()
1234+
return batch
1235+
}
1236+
1237+
func (ptw *partitionWriter) close() {
1238+
ptw.flush()
11981239
ptw.queue.Close()
11991240
}
12001241

Diff for: writer_test.go

+70-6
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ func TestWriter(t *testing.T) {
105105
scenario: "closing a writer right after creating it returns promptly with no error",
106106
function: testWriterClose,
107107
},
108-
109108
{
110109
scenario: "writing 1 message through a writer using round-robin balancing produces 1 message to the first partition",
111110
function: testWriterRoundRobin1,
@@ -130,6 +129,10 @@ func TestWriter(t *testing.T) {
130129
scenario: "writing a batch of messages",
131130
function: testWriterBatchSize,
132131
},
132+
{
133+
scenario: "writing and flushing a batch of messages",
134+
function: testsWriterFlush,
135+
},
133136

134137
{
135138
scenario: "writing messages with a small batch byte size",
@@ -450,7 +453,7 @@ func readPartition(topic string, partition int, offset int64) (msgs []Message, e
450453
}
451454
}
452455

453-
func testWriterBatchBytes(t *testing.T) {
456+
func testsWriterFlush(t *testing.T) {
454457
topic := makeTopic()
455458
createTopic(t, topic, 1)
456459
defer deleteTopic(t, topic)
@@ -461,10 +464,13 @@ func testWriterBatchBytes(t *testing.T) {
461464
}
462465

463466
w := newTestWriter(WriterConfig{
464-
Topic: topic,
465-
BatchBytes: 50,
466-
BatchTimeout: math.MaxInt32 * time.Second,
467+
Topic: topic,
468+
// Set the batch timeout to a large value to avoid the timeout
469+
BatchSize: 1000,
470+
BatchBytes: 1000000,
471+
BatchTimeout: 1000 * time.Second,
467472
Balancer: &RoundRobin{},
473+
Async: true,
468474
})
469475
defer w.Close()
470476

@@ -480,7 +486,65 @@ func testWriterBatchBytes(t *testing.T) {
480486
return
481487
}
482488

483-
if w.Stats().Writes != 2 {
489+
if err := w.Flush(ctx); err != nil {
490+
t.Errorf("flush error %v", err)
491+
return
492+
}
493+
494+
if w.Stats().Writes != 1 {
495+
t.Error("didn't create expected batches")
496+
return
497+
}
498+
msgs, err := readPartition(topic, 0, offset)
499+
if err != nil {
500+
t.Error("error reading partition", err)
501+
return
502+
}
503+
504+
if len(msgs) != 4 {
505+
t.Error("bad messages in partition", msgs)
506+
return
507+
}
508+
509+
for i, m := range msgs {
510+
if string(m.Value) == "M"+strconv.Itoa(i) {
511+
continue
512+
}
513+
t.Error("bad messages in partition", string(m.Value))
514+
}
515+
}
516+
517+
func testWriterBatchBytes(t *testing.T) {
518+
topic := makeTopic()
519+
createTopic(t, topic, 1)
520+
defer deleteTopic(t, topic)
521+
522+
offset, err := readOffset(topic, 0)
523+
if err != nil {
524+
t.Fatal(err)
525+
}
526+
527+
w := newTestWriter(WriterConfig{
528+
Topic: topic,
529+
BatchBytes: 50,
530+
BatchTimeout: math.MaxInt32 * time.Second,
531+
Balancer: &RoundRobin{},
532+
})
533+
defer w.Close()
534+
535+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
536+
defer cancel()
537+
if err := w.WriteMessages(ctx, []Message{
538+
{Value: []byte("M0")},
539+
{Value: []byte("M1")},
540+
{Value: []byte("M2")},
541+
{Value: []byte("M3")},
542+
}...); err != nil {
543+
t.Error(err)
544+
return
545+
}
546+
547+
if w.Stats().Writes != 1 {
484548
t.Error("didn't create expected batches")
485549
return
486550
}

0 commit comments

Comments
 (0)