Skip to content

Commit dafa65d

Browse files
committed
Add flush method to writer
1 parent b2b17ac commit dafa65d

File tree

2 files changed

+113
-6
lines changed

2 files changed

+113
-6
lines changed

writer.go

+47-6
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,41 @@ func (w *Writer) spawn(f func()) {
548548
}()
549549
}
550550

551+
// Flush writes all currently buffered messages to the kafka cluster. This will
552+
// block until all messages in the batch has been written to kafka, or until the
553+
// context is canceled.
554+
func (w *Writer) Flush(ctx context.Context) error {
555+
w.mutex.Lock()
556+
557+
var wg sync.WaitGroup
558+
559+
// flush all writers
560+
for _, writer := range w.writers {
561+
w := writer
562+
wg.Add(1)
563+
go func() {
564+
b := w.flush()
565+
<-b.done
566+
wg.Done()
567+
}()
568+
}
569+
570+
w.mutex.Unlock()
571+
done := make(chan struct{})
572+
573+
go func() {
574+
wg.Wait()
575+
close(done)
576+
}()
577+
578+
select {
579+
case <-done:
580+
return nil
581+
case <-ctx.Done():
582+
return ctx.Err()
583+
}
584+
}
585+
551586
// Close flushes pending writes, and waits for all writes to complete before
552587
// returning. Calling Close also prevents new writes from being submitted to
553588
// the writer, further calls to WriteMessages and the like will fail with
@@ -1184,17 +1219,23 @@ func (ptw *partitionWriter) writeBatch(batch *writeBatch) {
11841219
batch.complete(err)
11851220
}
11861221

1187-
func (ptw *partitionWriter) close() {
1222+
func (ptw *partitionWriter) flush() *writeBatch {
11881223
ptw.mutex.Lock()
11891224
defer ptw.mutex.Unlock()
11901225

1191-
if ptw.currBatch != nil {
1192-
batch := ptw.currBatch
1193-
ptw.queue.Put(batch)
1194-
ptw.currBatch = nil
1195-
batch.trigger()
1226+
if ptw.currBatch == nil {
1227+
return nil
11961228
}
11971229

1230+
batch := ptw.currBatch
1231+
ptw.queue.Put(batch)
1232+
ptw.currBatch = nil
1233+
batch.trigger()
1234+
return batch
1235+
}
1236+
1237+
func (ptw *partitionWriter) close() {
1238+
ptw.flush()
11981239
ptw.queue.Close()
11991240
}
12001241

writer_test.go

+66
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ func TestWriter(t *testing.T) {
121121
function: testWriterMaxBytes,
122122
},
123123

124+
{
125+
scenario: "writing a batch of message and flush",
126+
function: testWriterFlush,
127+
},
128+
124129
{
125130
scenario: "writing a batch of message based on batch byte size",
126131
function: testWriterBatchBytes,
@@ -503,6 +508,67 @@ func testWriterBatchBytes(t *testing.T) {
503508
}
504509
}
505510

511+
func testWriterFlush(t *testing.T) {
512+
topic := makeTopic()
513+
createTopic(t, topic, 1)
514+
defer deleteTopic(t, topic)
515+
516+
offset, err := readOffset(topic, 0)
517+
if err != nil {
518+
t.Fatal(err)
519+
}
520+
521+
w := newTestWriter(WriterConfig{
522+
Topic: topic,
523+
// Set the batch timeout to a large value to avoid the timeout
524+
BatchSize: 1000,
525+
BatchBytes: 1000000,
526+
BatchTimeout: 1000 * time.Second,
527+
Balancer: &RoundRobin{},
528+
Async: true,
529+
})
530+
defer w.Close()
531+
532+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
533+
defer cancel()
534+
if err := w.WriteMessages(ctx, []Message{
535+
{Value: []byte("M0")}, // 25 Bytes
536+
{Value: []byte("M1")}, // 25 Bytes
537+
{Value: []byte("M2")}, // 25 Bytes
538+
{Value: []byte("M3")}, // 25 Bytes
539+
}...); err != nil {
540+
t.Error(err)
541+
return
542+
}
543+
544+
if err := w.Flush(ctx); err != nil {
545+
t.Errorf("flush error %v", err)
546+
return
547+
}
548+
549+
if w.Stats().Writes != 1 {
550+
t.Error("didn't create expected batches")
551+
return
552+
}
553+
msgs, err := readPartition(topic, 0, offset)
554+
if err != nil {
555+
t.Error("error reading partition", err)
556+
return
557+
}
558+
559+
if len(msgs) != 4 {
560+
t.Error("bad messages in partition", msgs)
561+
return
562+
}
563+
564+
for i, m := range msgs {
565+
if string(m.Value) == "M"+strconv.Itoa(i) {
566+
continue
567+
}
568+
t.Error("bad messages in partition", string(m.Value))
569+
}
570+
}
571+
506572
func testWriterBatchSize(t *testing.T) {
507573
topic := makeTopic()
508574
createTopic(t, topic, 1)

0 commit comments

Comments
 (0)