Skip to content

Commit

Permalink
kafka/pubsub: fix LSN potentially being updated too early
Browse files Browse the repository at this point in the history
The LSN should not be updated before success is confirmed,
since an intermediate value may now be read when the queue is flushed, before an error aborts the sync.

Also fix Pub/Sub requiring EnableMessageOrdering to be explicitly enabled: it is now set on the topic whenever a message carries an ordering key.
  • Loading branch information
serprex committed May 10, 2024
1 parent a2e9ab8 commit 317e35d
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 48 deletions.
55 changes: 36 additions & 19 deletions flow/connectors/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,18 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error)
return kr, nil
}

// poolResult is the value returned by a pool.Run callback for one source record.
// It pairs the Kafka records built by the script with the checkpoint ID of the
// source record, so the pool's result handler can advance lastSeenLSN only once
// every record has been produced without error (or right away when there are no records).
type poolResult struct {
// records are the Kafka records to produce; may be empty.
records []*kgo.Record
// lsn is the source record's checkpoint ID (record.GetCheckpointID()).
// SyncQRepRecords leaves it at zero and passes a nil lastSeenLSN, so it is unused there.
lsn int64
}

func (c *KafkaConnector) createPool(
ctx context.Context,
script string,
flowJobName string,
lastSeenLSN *atomic.Int64,
queueErr func(error),
) (*utils.LPool[[]*kgo.Record], error) {
produceCb := func(_ *kgo.Record, err error) {
if err != nil {
queueErr(err)
}
}

) (*utils.LPool[poolResult], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
Expand All @@ -194,25 +194,40 @@ func (c *KafkaConnector) createPool(
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(krs []*kgo.Record) {
for _, kr := range krs {
c.client.Produce(ctx, kr, produceCb)
}, func(result poolResult) {
lenRecords := int32(len(result.records))
if lenRecords == 0 {
if lastSeenLSN != nil {
shared.AtomicInt64Max(lastSeenLSN, result.lsn)
}
} else {
recordCounter := atomic.Int32{}
recordCounter.Store(lenRecords)
for _, kr := range result.records {
c.client.Produce(ctx, kr, func(_ *kgo.Record, err error) {
if err != nil {
queueErr(err)
} else if recordCounter.Add(-1) == 0 && lastSeenLSN != nil {
shared.AtomicInt64Max(lastSeenLSN, result.lsn)
}
})
}
}
})
}

func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
numRecords := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
lastSeenLSN := atomic.Int64{}

queueCtx, queueErr := context.WithCancelCause(ctx)
pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, queueErr)
pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, &lastSeenLSN, queueErr)
if err != nil {
return nil, err
}
defer pool.Close()

lastSeenLSN := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
flushLoopDone := make(chan struct{})
go func() {
ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds())
Expand Down Expand Up @@ -251,20 +266,20 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []*kgo.Record {
pool.Run(func(ls *lua.LState) poolResult {
lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -273,7 +288,7 @@ Loop:
kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if kr != nil {
if kr.Topic == "" {
Expand All @@ -285,8 +300,10 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
return results
return poolResult{
records: results,
lsn: record.GetCheckpointID(),
}
})

case <-queueCtx.Done():
Expand Down
12 changes: 6 additions & 6 deletions flow/connectors/kafka/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (c *KafkaConnector) SyncQRepRecords(
schema := stream.Schema()

queueCtx, queueErr := context.WithCancelCause(ctx)
pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, queueErr)
pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, nil, queueErr)
if err != nil {
return 0, err
}
Expand All @@ -44,7 +44,7 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []*kgo.Record {
pool.Run(func(ls *lua.LState) poolResult {
items := model.NewRecordItems(len(qrecord))
for i, val := range qrecord {
items.AddColumn(schema.Fields[i].Name, val)
Expand All @@ -61,15 +61,15 @@ Loop:
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -78,7 +78,7 @@ Loop:
kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if kr != nil {
if kr.Topic == "" {
Expand All @@ -89,7 +89,7 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
return results
return poolResult{records: results}
})

case <-queueCtx.Done():
Expand Down
51 changes: 37 additions & 14 deletions flow/connectors/pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ type PubSubMessage struct {
Topic string
}

// poolResult is the value returned by a pool.Run callback for one source record.
// It pairs the Pub/Sub messages built by the script with the checkpoint ID of
// the source record; the pool's result handler publishes the messages and then
// forwards lsn on the publish channel.
type poolResult struct {
// messages are the messages to publish; may be empty.
messages []PubSubMessage
// lsn is the source record's checkpoint ID (record.GetCheckpointID()).
// SyncQRepRecords leaves it at zero.
lsn int64
}

// publishResult is the element type of the publish channel.
// It takes one of two forms:
//   - PublishResult non-nil: a pending publish whose outcome the consumer
//     goroutine awaits with Get, reporting any error through queueErr;
//   - PublishResult nil: a marker sent after all messages of a record have been
//     queued, carrying that record's lsn so SyncRecords can advance lastSeenLSN
//     (SyncQRepRecords ignores these markers).
type publishResult struct {
// PublishResult is the handle returned by topicClient.Publish; nil for lsn markers.
*pubsub.PublishResult
// lsn is the checkpoint ID to record; only set on marker entries.
lsn int64
}

func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, error) {
var topic string
var msg *pubsub.Message
Expand Down Expand Up @@ -126,9 +136,9 @@ func (c *PubSubConnector) createPool(
script string,
flowJobName string,
topiccache *topicCache,
publish chan<- *pubsub.PublishResult,
publish chan<- publishResult,
queueErr func(error),
) (*utils.LPool[[]PubSubMessage], error) {
) (*utils.LPool[poolResult], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
Expand All @@ -146,10 +156,14 @@ func (c *PubSubConnector) createPool(
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(messages []PubSubMessage) {
for _, message := range messages {
}, func(result poolResult) {
for _, message := range result.messages {
topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) {
topicClient := c.client.Topic(message.Topic)
if message.OrderingKey != "" {
topicClient.EnableMessageOrdering = true
}

exists, err := topicClient.Exists(ctx)
if err != nil {
return nil, fmt.Errorf("error checking if topic exists: %w", err)
Expand All @@ -167,7 +181,12 @@ func (c *PubSubConnector) createPool(
return
}

publish <- topicClient.Publish(ctx, message.Message)
publish <- publishResult{
PublishResult: topicClient.Publish(ctx, message.Message),
}
}
publish <- publishResult{
lsn: result.lsn,
}
})
}
Expand Down Expand Up @@ -223,9 +242,10 @@ func (tc *topicCache) GetOrSet(topic string, f func() (*pubsub.Topic, error)) (*

func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
numRecords := atomic.Int64{}
lastSeenLSN := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan *pubsub.PublishResult, 32)
publish := make(chan publishResult, 32)
waitChan := make(chan struct{})

queueCtx, queueErr := context.WithCancelCause(ctx)
Expand All @@ -237,15 +257,16 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord

go func() {
for curpub := range publish {
if _, err := curpub.Get(ctx); err != nil {
if curpub.PublishResult == nil {
shared.AtomicInt64Max(&lastSeenLSN, curpub.lsn)
} else if _, err := curpub.Get(ctx); err != nil {
queueErr(err)
break
}
}
close(waitChan)
}()

lastSeenLSN := atomic.Int64{}
flushLoopDone := make(chan struct{})
go func() {
ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds())
Expand Down Expand Up @@ -281,20 +302,20 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []PubSubMessage {
pool.Run(func(ls *lua.LState) poolResult {
lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -303,7 +324,7 @@ Loop:
msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if msg.Message != nil {
if msg.Topic == "" {
Expand All @@ -315,8 +336,10 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
return results
return poolResult{
messages: results,
lsn: record.GetCheckpointID(),
}
})

case <-queueCtx.Done():
Expand Down
20 changes: 11 additions & 9 deletions flow/connectors/pubsub/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (c *PubSubConnector) SyncQRepRecords(
numRecords := atomic.Int64{}
schema := stream.Schema()
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan *pubsub.PublishResult, 32)
publish := make(chan publishResult, 32)
waitChan := make(chan struct{})

queueCtx, queueErr := context.WithCancelCause(ctx)
Expand All @@ -40,9 +40,11 @@ func (c *PubSubConnector) SyncQRepRecords(

go func() {
for curpub := range publish {
if _, err := curpub.Get(ctx); err != nil {
queueErr(err)
break
if curpub.PublishResult != nil {
if _, err := curpub.Get(ctx); err != nil {
queueErr(err)
break
}
}
}
close(waitChan)
Expand All @@ -57,7 +59,7 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []PubSubMessage {
pool.Run(func(ls *lua.LState) poolResult {
items := model.NewRecordItems(len(qrecord))
for i, val := range qrecord {
items.AddColumn(schema.Fields[i].Name, val)
Expand All @@ -74,15 +76,15 @@ Loop:
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -91,7 +93,7 @@ Loop:
msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if msg.Message != nil {
if msg.Topic == "" {
Expand All @@ -102,7 +104,7 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
return results
return poolResult{messages: results}
})

case <-queueCtx.Done():
Expand Down

0 comments on commit 317e35d

Please sign in to comment.