From 60bcef925579301e77de0b0279ebef34e6285372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Mon, 15 Apr 2024 19:19:02 +0000 Subject: [PATCH] parallelized lua script execution with serialized output (#1613) --- flow/connectors/kafka/kafka.go | 109 ++++++++++------- flow/connectors/pubsub/pubsub.go | 193 +++++++++++++++++-------------- flow/connectors/utils/lua.go | 94 +++++++++++++++ flow/go.mod | 2 +- flow/go.sum | 4 +- flow/peerdbenv/config.go | 4 + flow/pua/peerdb_test.go | 2 +- 7 files changed, 269 insertions(+), 139 deletions(-) diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index bb36aee8ef..35260f1ecf 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -169,38 +169,43 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error) func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) { var wg sync.WaitGroup wgCtx, wgErr := context.WithCancelCause(ctx) - produceCb := func(r *kgo.Record, err error) { + produceCb := func(_ *kgo.Record, err error) { if err != nil { wgErr(err) } wg.Done() } - numRecords := int64(0) + numRecords := atomic.Int64{} tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) - ls, err := utils.LoadScript(wgCtx, req.Script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() + pool, err := utils.LuaPool(func() (*lua.LState, error) { + ls, err := utils.LoadScript(wgCtx, req.Script, func(ls *lua.LState) int { + top := ls.GetTop() + ss := make([]string, top) + for i := range top { + ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() + } + _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) + return 0 + }) + if err != nil { + return nil, err + } + if req.Script == "" { + ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) + } + return ls, nil + }, func(krs []*kgo.Record) { + wg.Add(len(krs)) + for _, kr := range krs { + c.client.Produce(wgCtx, kr, produceCb) } - _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) - return 0 }) if err != nil { return nil, err } - defer ls.Close() - if req.Script == "" { - ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) - } - - lfn := ls.Env.RawGetString("onRecord") - fn, ok := lfn.(*lua.LFunction) - if !ok { - return nil, fmt.Errorf("script should define `onRecord` as function, not %s", lfn) - } + defer pool.Close() lastSeenLSN := atomic.Int64{} flushLoopDone := make(chan struct{}) @@ -242,39 +247,54 @@ Loop: break Loop } - ls.Push(fn) - ls.Push(pua.LuaRecord.New(ls, record)) - err := ls.PCall(1, -1, nil) - if err != nil { - return nil, fmt.Errorf("script failed: %w", err) - } - args := ls.GetTop() - for i := range args { - kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args)) + pool.Run(func(ls *lua.LState) []*kgo.Record { + lfn := ls.Env.RawGetString("onRecord") + fn, ok := lfn.(*lua.LFunction) + if !ok { + wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) + return nil + } + + ls.Push(fn) + ls.Push(pua.LuaRecord.New(ls, record)) + err := ls.PCall(1, -1, nil) if err != nil { - return nil, err + wgErr(fmt.Errorf("script failed: %w", err)) + return nil } - if kr != nil { - if kr.Topic == "" { - kr.Topic = record.GetDestinationTableName() - } - wg.Add(1) - c.client.Produce(wgCtx, kr, produceCb) - record.PopulateCountMap(tableNameRowsMapping) + args := 
ls.GetTop() + results := make([]*kgo.Record, 0, args) + for i := range args { + kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args)) + if err != nil { + wgErr(err) + return nil + } + if kr != nil { + if kr.Topic == "" { + kr.Topic = record.GetDestinationTableName() + } + results = append(results, kr) + record.PopulateCountMap(tableNameRowsMapping) + } } - } - ls.SetTop(0) - numRecords += 1 - shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) + ls.SetTop(0) + numRecords.Add(1) + shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) + return results + }) case <-wgCtx.Done(): - return nil, wgCtx.Err() + break Loop } } close(flushLoopDone) - if err := c.client.Flush(ctx); err != nil { + if err := pool.Wait(wgCtx); err != nil { + return nil, err + } + if err := c.client.Flush(wgCtx); err != nil { return nil, fmt.Errorf("[kafka] final flush error: %w", err) } waitChan := make(chan struct{}) @@ -289,15 +309,14 @@ Loop: } lastCheckpoint := req.Records.GetLastCheckpoint() - err = c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) - if err != nil { + if err := c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint); err != nil { return nil, err } return &model.SyncResponse{ CurrentSyncBatchID: req.SyncBatchID, LastSyncedCheckpointID: lastCheckpoint, - NumRecordsSynced: numRecords, + NumRecordsSynced: numRecords.Load(), TableNameRowsMapping: tableNameRowsMapping, TableSchemaDeltas: req.Records.SchemaDeltas, }, nil diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go index 3b0340d0b9..b6f5debe0a 100644 --- a/flow/connectors/pubsub/pubsub.go +++ b/flow/connectors/pubsub/pubsub.go @@ -72,7 +72,12 @@ func (c *PubSubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName return nil } -func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (string, *pubsub.Message, error) { +type PubSubMessage struct { + *pubsub.Message + Topic string +} + +func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, error) { var topic string var msg *pubsub.Message switch v := value.(type) { @@ -83,15 +88,15 @@ func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (string, *pubsub.Me case *lua.LTable: key, err := utils.LVAsStringOrNil(ls, ls.GetField(v, "key")) if err != nil { - return "", nil, fmt.Errorf("invalid key, %w", err) + return PubSubMessage{}, fmt.Errorf("invalid key, %w", err) } value, err := utils.LVAsReadOnlyBytes(ls, ls.GetField(v, "value")) if err != nil { - return "", nil, fmt.Errorf("invalid value, %w", err) + return PubSubMessage{}, fmt.Errorf("invalid value, %w", err) } topic, err = utils.LVAsStringOrNil(ls, ls.GetField(v, "topic")) if err != nil { - return "", nil, fmt.Errorf("invalid topic, %w", err) + return PubSubMessage{}, fmt.Errorf("invalid topic, %w", err) } msg = &pubsub.Message{ OrderingKey: key, @@ -104,13 +109,16 @@ func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (string, *pubsub.Me msg.Attributes[k.String()] = v.String() }) } else if lua.LVAsBool(lheaders) { - return "", nil, fmt.Errorf("invalid headers, must be nil or table: %s", lheaders) + return PubSubMessage{}, fmt.Errorf("invalid headers, must be nil or table: %s", lheaders) } case *lua.LNilType: default: - return "", nil, fmt.Errorf("script returned invalid value: %s", value) + return PubSubMessage{}, fmt.Errorf("script returned invalid value: %s", value) } - return topic, msg, nil + return PubSubMessage{ + Message: msg, + Topic: topic, + }, nil } type topicCache struct { @@ -163,57 +171,70 @@ func (tc 
*topicCache) GetOrSet(topic string, f func() (*pubsub.Topic, error)) (* } func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) { - numRecords := int64(0) + numRecords := atomic.Int64{} tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) + topiccache := topicCache{cache: make(map[string]*pubsub.Topic)} + publish := make(chan *pubsub.PublishResult, 32) - ls, err := utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int { - top := ls.GetTop() - ss := make([]string, top) - for i := range top { - ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() + waitChan := make(chan struct{}) + wgCtx, wgErr := context.WithCancelCause(ctx) + + pool, err := utils.LuaPool(func() (*lua.LState, error) { + ls, err := utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int { + top := ls.GetTop() + ss := make([]string, top) + for i := range top { + ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String() + } + _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) + return 0 + }) + if err != nil { + return nil, err + } + if req.Script == "" { + ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) + } + return ls, nil + }, func(messages []PubSubMessage) { + for _, message := range messages { + topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) { + topicClient := c.client.Topic(message.Topic) + exists, err := topicClient.Exists(wgCtx) + if err != nil { + return nil, fmt.Errorf("error checking if topic exists: %w", err) + } + if !exists { + topicClient, err = c.client.CreateTopic(wgCtx, message.Topic) + if err != nil { + return nil, fmt.Errorf("error creating topic: %w", err) + } + } + return topicClient, nil + }) + if err != nil { + wgErr(err) + return + } + + publish <- topicClient.Publish(ctx, message.Message) } - _ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t")) - return 0 }) if err != nil { return nil, err } - defer ls.Close() - if req.Script == "" { - ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) - } - - lfn := ls.Env.RawGetString("onRecord") - fn, ok := lfn.(*lua.LFunction) - if !ok { - return nil, fmt.Errorf("script should define `onRecord` as function, not %s", lfn) - } + defer pool.Close() - var wg sync.WaitGroup - wgCtx, wgErr := context.WithCancelCause(ctx) - publish := make(chan *pubsub.PublishResult, 60) go func() { - var curpub *pubsub.PublishResult - for { - select { - case curpub, ok = <-publish: - if !ok { - return - } - case <-ctx.Done(): - wgErr(ctx.Err()) - return - } - _, err := curpub.Get(ctx) - if err != nil { + for curpub := range publish { + if _, err := curpub.Get(ctx); err != nil { wgErr(err) - return + break } - wg.Done() } + close(waitChan) }() - topiccache := topicCache{cache: make(map[string]*pubsub.Topic)} lastSeenLSN := atomic.Int64{} flushLoopDone := make(chan struct{}) go func() { @@ -250,63 +271,56 @@ Loop: c.logger.Info("flushing batches because no more records") break Loop } - ls.Push(fn) - ls.Push(pua.LuaRecord.New(ls, record)) - err := ls.PCall(1, -1, nil) - if err != nil { - return nil, fmt.Errorf("script failed: %w", err) - } - args := ls.GetTop() - for i := range args { - topic, msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args)) + + pool.Run(func(ls *lua.LState) []PubSubMessage { + lfn := ls.Env.RawGetString("onRecord") + fn, ok := lfn.(*lua.LFunction) + if !ok { + wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) + return nil + } + + ls.Push(fn) + 
ls.Push(pua.LuaRecord.New(ls, record)) + err := ls.PCall(1, -1, nil) if err != nil { - return nil, err + wgErr(fmt.Errorf("script failed: %w", err)) + return nil } - if msg != nil { - if topic == "" { - topic = record.GetDestinationTableName() + + args := ls.GetTop() + results := make([]PubSubMessage, 0, args) + for i := range args { + msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args)) + if err != nil { + wgErr(err) + return nil } - topicClient, err := topiccache.GetOrSet(topic, func() (*pubsub.Topic, error) { - topicClient := c.client.Topic(topic) - exists, err := topicClient.Exists(wgCtx) - if err != nil { - return nil, fmt.Errorf("error checking if topic exists: %w", err) + if msg.Message != nil { + if msg.Topic == "" { + msg.Topic = record.GetDestinationTableName() } - if !exists { - topicClient, err = c.client.CreateTopic(wgCtx, topic) - if err != nil { - return nil, fmt.Errorf("error creating topic: %w", err) - } - } - return topicClient, nil - }) - if err != nil { - return nil, err + results = append(results, msg) + record.PopulateCountMap(tableNameRowsMapping) } - - pubresult := topicClient.Publish(ctx, msg) - wg.Add(1) - publish <- pubresult - record.PopulateCountMap(tableNameRowsMapping) } - } - ls.SetTop(0) - numRecords += 1 - shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) + ls.SetTop(0) + numRecords.Add(1) + shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) + return results + }) case <-wgCtx.Done(): - return nil, wgCtx.Err() + break Loop } } close(flushLoopDone) close(publish) + if err := pool.Wait(wgCtx); err != nil { + return nil, err + } topiccache.Stop(wgCtx) - waitChan := make(chan struct{}) - go func() { - wg.Wait() - close(waitChan) - }() select { case <-wgCtx.Done(): return nil, wgCtx.Err() @@ -314,15 +328,14 @@ Loop: } lastCheckpoint := req.Records.GetLastCheckpoint() - err = c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) - if err != nil { + if err := c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint); err != nil { return nil, err } return &model.SyncResponse{ CurrentSyncBatchID: req.SyncBatchID, LastSyncedCheckpointID: lastCheckpoint, - NumRecordsSynced: numRecords, + NumRecordsSynced: numRecords.Load(), TableNameRowsMapping: tableNameRowsMapping, TableSchemaDeltas: req.Records.SchemaDeltas, }, nil diff --git a/flow/connectors/utils/lua.go b/flow/connectors/utils/lua.go index fff574da86..47676721b3 100644 --- a/flow/connectors/utils/lua.go +++ b/flow/connectors/utils/lua.go @@ -9,6 +9,7 @@ import ( "github.com/PeerDB-io/gluaflatbuffers" "github.com/PeerDB-io/gluajson" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/peerdbenv" "github.com/PeerDB-io/peer-flow/pua" "github.com/PeerDB-io/peer-flow/shared" ) @@ -80,3 +81,96 @@ func DefaultOnRecord(ls *lua.LState) int { ls.Call(1, 1) return 1 } + +type LPoolMessage[T any] struct { + f func(*lua.LState) T + ret chan<- T +} +type LPool[T any] struct { + messages chan LPoolMessage[T] + returns chan<- (<-chan T) + wait <-chan struct{} + cons func() (*lua.LState, error) + maxSize int + size int + closed bool +} + +func LuaPool[T any](cons func() (*lua.LState, error), merge func(T)) (*LPool[T], error) { + maxSize := peerdbenv.PeerDBQueueParallelism() + returns := make(chan (<-chan T), maxSize) + wait := make(chan struct{}) + go func() { + for ret := range returns { + for val := range ret { + merge(val) + } + } + close(wait) + }() + + pool := &LPool[T]{ + messages: make(chan LPoolMessage[T]), + returns: returns, + wait: wait, + cons: 
cons, + maxSize: maxSize, + size: 0, + closed: false, + } + if err := pool.Spawn(); err != nil { + pool.Close() + return nil, err + } + return pool, nil +} + +func (pool *LPool[T]) Spawn() error { + ls, err := pool.cons() + if err != nil { + return err + } + pool.size += 1 + go func() { + defer ls.Close() + for message := range pool.messages { + message.ret <- message.f(ls) + close(message.ret) + } + }() + return nil +} + +func (pool *LPool[T]) Close() { + if !pool.closed { + close(pool.returns) + close(pool.messages) + pool.closed = true + } +} + +func (pool *LPool[T]) Run(f func(*lua.LState) T) { + ret := make(chan T, 1) + msg := LPoolMessage[T]{f: f, ret: ret} + if pool.size < pool.maxSize { + select { + case pool.messages <- msg: + pool.returns <- ret + return + default: + _ = pool.Spawn() + } + } + pool.messages <- msg + pool.returns <- ret +} + +func (pool *LPool[T]) Wait(ctx context.Context) error { + pool.Close() + select { + case <-pool.wait: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/flow/go.mod b/flow/go.mod index 19cd29da3d..382250af47 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -15,7 +15,7 @@ require ( github.com/PeerDB-io/gluabit32 v1.0.2 github.com/PeerDB-io/gluaflatbuffers v1.0.1 github.com/PeerDB-io/gluajson v1.0.2 - github.com/PeerDB-io/gluamsgpack v1.0.2 + github.com/PeerDB-io/gluamsgpack v1.0.4 github.com/PeerDB-io/gluautf8 v1.0.0 github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2/config v1.27.11 diff --git a/flow/go.sum b/flow/go.sum index b4852d9ce6..9f6c24becf 100644 --- a/flow/go.sum +++ b/flow/go.sum @@ -66,8 +66,8 @@ github.com/PeerDB-io/gluaflatbuffers v1.0.1 h1:Oxlv0VlMYoQ05Q5n/k4hXAsvtDnuVNC99 github.com/PeerDB-io/gluaflatbuffers v1.0.1/go.mod h1:unZOM4Mm2Sn+aAFuVjoJDZ2Dji7jlDWrt4Hvq79as2g= github.com/PeerDB-io/gluajson v1.0.2 h1:Kv5Qabj2Md6gxRZsX5QVUOQDf5WMOQEF8lIkKXguajM= github.com/PeerDB-io/gluajson v1.0.2/go.mod h1:arRzpblxlLiWfBAluxP9Xibf6J8UkUIfoY4FPHTsz4Q= -github.com/PeerDB-io/gluamsgpack v1.0.2 h1:J5VhMSfJdWfCMJ1wszDNC2BVD4F+ATgXwGVs1lNft9g= -github.com/PeerDB-io/gluamsgpack v1.0.2/go.mod h1:1ufs5NK2DczzQS78Nhy0AkCA0dOVyt/KVEk39lbWzyU= +github.com/PeerDB-io/gluamsgpack v1.0.4 h1:JrZtdNAAkE6RtllVuhEuWWa26lQZ/K5BJWiI3q9EhL0= +github.com/PeerDB-io/gluamsgpack v1.0.4/go.mod h1:1ufs5NK2DczzQS78Nhy0AkCA0dOVyt/KVEk39lbWzyU= github.com/PeerDB-io/gluautf8 v1.0.0 h1:Ubhy6FVnrED5jrosdUOxzAkf3YdcgebYJzX2YBdGedE= github.com/PeerDB-io/gluautf8 v1.0.0/go.mod h1:+4RQlCVFCMikYFmiKUA9ADZftgGAZseMmdErxfE1EZQ= github.com/alecthomas/assert/v2 v2.6.0 h1:o3WJwILtexrEUk3cUVal3oiQY2tfgr/FHWiz/v2n4FU= diff --git a/flow/peerdbenv/config.go b/flow/peerdbenv/config.go index db252c84f0..be580d5884 100644 --- a/flow/peerdbenv/config.go +++ b/flow/peerdbenv/config.go @@ -40,6 +40,10 @@ func PeerDBQueueFlushTimeoutSeconds() time.Duration { return time.Duration(x) * time.Second } +func PeerDBQueueParallelism() int { + return getEnvInt("PEERDB_QUEUE_PARALLELISM", 4) +} + // env variable doesn't exist anymore, but tests appear to depend on this // in lieu of an actual value of IdleTimeoutSeconds func PeerDBCDCIdleTimeoutSeconds(providedValue int) time.Duration { diff --git a/flow/pua/peerdb_test.go b/flow/pua/peerdb_test.go index 2085fb26c0..d08c9b4db6 100644 --- a/flow/pua/peerdb_test.go +++ b/flow/pua/peerdb_test.go @@ -66,6 +66,6 @@ assert(msgpack.encode(uuid) == string.char(0xc4, 16, 2, 3, 5, 7, 11, 13, 17, 19, local json = require "json" assert(json.encode(row) == "{\"a\":5040}") -print(json.encode(row_empty_array.a) == "[]") 
+assert(json.encode(row_empty_array.a) == "[]") `) }
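
Usage note (not part of the patch): the new utils.LuaPool runs the onRecord script on several per-worker lua.LState instances in parallel, yet the merge callback still observes each task's results in submission order, because Run hands every task's result channel to a single merge goroutine in call order. Below is a minimal, hypothetical sketch of driving the pool outside the connectors touched above; the import paths, the toy squaring chunk, and the main() scaffolding are illustrative assumptions, not code from this patch.

package main

import (
	"context"
	"fmt"
	"log"

	lua "github.com/yuin/gopher-lua"

	"github.com/PeerDB-io/peer-flow/connectors/utils"
)

func main() {
	ctx := context.Background()

	// One fresh interpreter per worker; pool size comes from PEERDB_QUEUE_PARALLELISM (default 4).
	pool, err := utils.LuaPool(func() (*lua.LState, error) {
		return lua.NewState(), nil
	}, func(results []int) {
		// merge runs on a single goroutine and sees task outputs in submission order,
		// even though the tasks themselves execute concurrently.
		fmt.Println(results)
	})
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close()

	for i := 1; i <= 10; i++ {
		n := i
		pool.Run(func(ls *lua.LState) []int {
			// Each task executes on whichever worker picks it up; error handling is elided here.
			if err := ls.DoString(fmt.Sprintf("return %d * %d", n, n)); err != nil {
				return nil
			}
			square := int(lua.LVAsNumber(ls.Get(-1)))
			ls.SetTop(0)
			return []int{square}
		})
	}

	// Wait closes the pool and blocks until all workers and the merge goroutine have finished.
	if err := pool.Wait(ctx); err != nil {
		log.Fatal(err)
	}
}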