diff --git a/flow/connectors/core.go b/flow/connectors/core.go
index dc7afe0495..0814feabd9 100644
--- a/flow/connectors/core.go
+++ b/flow/connectors/core.go
@@ -329,8 +329,9 @@ var (
 	_ QRepSyncConnector = &connpostgres.PostgresConnector{}
 	_ QRepSyncConnector = &connbigquery.BigQueryConnector{}
 	_ QRepSyncConnector = &connsnowflake.SnowflakeConnector{}
-	_ QRepSyncConnector = &connclickhouse.ClickhouseConnector{}
+	_ QRepSyncConnector = &connkafka.KafkaConnector{}
 	_ QRepSyncConnector = &conns3.S3Connector{}
+	_ QRepSyncConnector = &connclickhouse.ClickhouseConnector{}
 	_ QRepSyncConnector = &connelasticsearch.ElasticsearchConnector{}
 
 	_ QRepConsolidateConnector = &connsnowflake.SnowflakeConnector{}
diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go
index 30b370e51e..cfe4652598 100644
--- a/flow/connectors/kafka/kafka.go
+++ b/flow/connectors/kafka/kafka.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"log/slog"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"time"
 
@@ -166,42 +165,48 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error)
 	return kr, nil
 }
 
-func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
-	var wg sync.WaitGroup
-	wgCtx, wgErr := context.WithCancelCause(ctx)
+func (c *KafkaConnector) createPool(
+	ctx context.Context,
+	script string,
+	flowJobName string,
+	queueErr func(error),
+) (*utils.LPool[[]*kgo.Record], error) {
 	produceCb := func(_ *kgo.Record, err error) {
 		if err != nil {
-			wgErr(err)
+			queueErr(err)
 		}
-		wg.Done()
 	}
 
-	numRecords := atomic.Int64{}
-	tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
-
-	pool, err := utils.LuaPool(func() (*lua.LState, error) {
-		ls, err := utils.LoadScript(wgCtx, req.Script, func(ls *lua.LState) int {
+	return utils.LuaPool(func() (*lua.LState, error) {
+		ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
 			top := ls.GetTop()
 			ss := make([]string, top)
 			for i := range top {
 				ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
 			}
-			_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
+			_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
 			return 0
 		})
 		if err != nil {
 			return nil, err
 		}
-		if req.Script == "" {
+		if script == "" {
 			ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
 		}
 		return ls, nil
 	}, func(krs []*kgo.Record) {
-		wg.Add(len(krs))
 		for _, kr := range krs {
-			c.client.Produce(wgCtx, kr, produceCb)
+			c.client.Produce(ctx, kr, produceCb)
 		}
 	})
+}
+
+func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
+	numRecords := atomic.Int64{}
+	tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
+
+	queueCtx, queueErr := context.WithCancelCause(ctx)
+	pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, queueErr)
 	if err != nil {
 		return nil, err
 	}
@@ -250,7 +255,7 @@ Loop:
 			lfn := ls.Env.RawGetString("onRecord")
 			fn, ok := lfn.(*lua.LFunction)
 			if !ok {
-				wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
+				queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
 				return nil
 			}
 
@@ -258,7 +263,7 @@ Loop:
 			ls.Push(pua.LuaRecord.New(ls, record))
 			err := ls.PCall(1, -1, nil)
 			if err != nil {
-				wgErr(fmt.Errorf("script failed: %w", err))
+				queueErr(fmt.Errorf("script failed: %w", err))
 				return nil
 			}
 
@@ -267,7 +272,7 @@ Loop:
 			for i := range args {
 				kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
 				if err != nil {
-					wgErr(err)
+					queueErr(err)
 					return nil
 				}
 				if kr != nil {
@@ -284,28 +289,18 @@ Loop:
 				return results
 			})
 
-		case <-wgCtx.Done():
+		case <-queueCtx.Done():
 			break Loop
 		}
 	}
 
 	close(flushLoopDone)
-	if err := pool.Wait(wgCtx); err != nil {
+	if err := pool.Wait(queueCtx); err != nil {
 		return nil, err
 	}
-	if err := c.client.Flush(wgCtx); err != nil {
+	if err := c.client.Flush(queueCtx); err != nil {
 		return nil, fmt.Errorf("[kafka] final flush error: %w", err)
 	}
-	waitChan := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(waitChan)
-	}()
-	select {
-	case <-wgCtx.Done():
-		return nil, wgCtx.Err()
-	case <-waitChan:
-	}
 
 	lastCheckpoint := req.Records.GetLastCheckpoint()
 	if err := c.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint); err != nil {
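The kafka.go change above drops the sync.WaitGroup plus produce-callback bookkeeping in favor of a single context.WithCancelCause that CDC and the new qrep path share through createPool. A minimal, self-contained sketch of that cancellation pattern (standard library only; names are illustrative, not PeerDB's):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func main() {
	// One cancellable context replaces WaitGroup-plus-error-channel plumbing.
	queueCtx, queueErr := context.WithCancelCause(context.Background())

	// Async producers report failures by cancelling the shared context.
	go func() {
		time.Sleep(10 * time.Millisecond)
		queueErr(errors.New("broker unreachable"))
	}()

	// The sync loop only has to watch one Done channel; the first cause wins.
	<-queueCtx.Done()
	fmt.Println("stopped:", context.Cause(queueCtx)) // prints the produce error
}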
diff --git a/flow/connectors/kafka/qrep.go b/flow/connectors/kafka/qrep.go
new file mode 100644
index 0000000000..a856ad1ccf
--- /dev/null
+++ b/flow/connectors/kafka/qrep.go
@@ -0,0 +1,111 @@
+package connkafka
+
+import (
+	"context"
+	"fmt"
+	"sync/atomic"
+	"time"
+
+	"github.com/twmb/franz-go/pkg/kgo"
+	lua "github.com/yuin/gopher-lua"
+
+	"github.com/PeerDB-io/peer-flow/generated/protos"
+	"github.com/PeerDB-io/peer-flow/model"
+	"github.com/PeerDB-io/peer-flow/pua"
+)
+
+func (*KafkaConnector) SetupQRepMetadataTables(_ context.Context, _ *protos.QRepConfig) error {
+	return nil
+}
+
+func (c *KafkaConnector) SyncQRepRecords(
+	ctx context.Context,
+	config *protos.QRepConfig,
+	partition *protos.QRepPartition,
+	stream *model.QRecordStream,
+) (int, error) {
+	startTime := time.Now()
+	numRecords := atomic.Int64{}
+	schema := stream.Schema()
+
+	queueCtx, queueErr := context.WithCancelCause(ctx)
+	pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, queueErr)
+	if err != nil {
+		return 0, err
+	}
+	defer pool.Close()
+
+Loop:
+	for {
+		select {
+		case qrecord, ok := <-stream.Records:
+			if !ok {
+				c.logger.Info("flushing batches because no more records")
+				break Loop
+			}
+
+			pool.Run(func(ls *lua.LState) []*kgo.Record {
+				items := model.NewRecordItems(len(qrecord))
+				for i, val := range qrecord {
+					items.AddColumn(schema.Fields[i].Name, val)
+				}
+				record := &model.InsertRecord[model.RecordItems]{
+					BaseRecord:           model.BaseRecord{},
+					Items:                items,
+					SourceTableName:      config.WatermarkTable,
+					DestinationTableName: config.DestinationTableIdentifier,
+					CommitID:             0,
+				}
+
+				lfn := ls.Env.RawGetString("onRecord")
+				fn, ok := lfn.(*lua.LFunction)
+				if !ok {
+					queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
+					return nil
+				}
+
+				ls.Push(fn)
+				ls.Push(pua.LuaRecord.New(ls, record))
+				err := ls.PCall(1, -1, nil)
+				if err != nil {
+					queueErr(fmt.Errorf("script failed: %w", err))
+					return nil
+				}
+
+				args := ls.GetTop()
+				results := make([]*kgo.Record, 0, args)
+				for i := range args {
+					kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
+					if err != nil {
+						queueErr(err)
+						return nil
+					}
+					if kr != nil {
+						if kr.Topic == "" {
+							kr.Topic = record.GetDestinationTableName()
+						}
+						results = append(results, kr)
+					}
+				}
+				ls.SetTop(0)
+				numRecords.Add(1)
+				return results
+			})
+
+		case <-queueCtx.Done():
+			break Loop
+		}
+	}
+
+	if err := pool.Wait(queueCtx); err != nil {
+		return 0, err
+	}
+	if err := c.client.Flush(queueCtx); err != nil {
+		return 0, fmt.Errorf("[kafka] final flush error: %w", err)
+	}
+
+	if err := c.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime); err != nil {
+		return 0, err
+	}
+	return int(numRecords.Load()), nil
+}
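Both SyncRecords and the new SyncQRepRecords drive the user script through the same gopher-lua calling convention: PCall with nresults = -1 leaves a variable number of return values on the stack, which are then read back with negative indices (i - args, where -args is the first return and -1 the last). A standalone sketch of just that convention, with a hypothetical script and values:

package main

import (
	"fmt"

	lua "github.com/yuin/gopher-lua"
)

func main() {
	ls := lua.NewState()
	defer ls.Close()

	// A script returning a variable number of values, as onRecord may.
	if err := ls.DoString(`function onRecord(r) return r, r .. "!" end`); err != nil {
		panic(err)
	}

	ls.Push(ls.GetGlobal("onRecord"))
	ls.Push(lua.LString("row"))
	if err := ls.PCall(1, -1, nil); err != nil { // -1: keep every return value
		panic(err)
	}

	args := ls.GetTop()
	for i := range args {
		fmt.Println(ls.Get(i - args)) // negative index, counted from the top
	}
	ls.SetTop(0) // clear the stack before reusing this LState
}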
diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go
index b08150960c..54031f016d 100644
--- a/flow/connectors/pubsub/pubsub.go
+++ b/flow/connectors/pubsub/pubsub.go
@@ -121,6 +121,57 @@ func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, err
 	}, nil
 }
 
+func (c *PubSubConnector) createPool(
+	ctx context.Context,
+	script string,
+	flowJobName string,
+	topiccache *topicCache,
+	publish chan<- *pubsub.PublishResult,
+	queueErr func(error),
+) (*utils.LPool[[]PubSubMessage], error) {
+	return utils.LuaPool(func() (*lua.LState, error) {
+		ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
+			top := ls.GetTop()
+			ss := make([]string, top)
+			for i := range top {
+				ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
+			}
+			_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
+			return 0
+		})
+		if err != nil {
+			return nil, err
+		}
+		if script == "" {
+			ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
+		}
+		return ls, nil
+	}, func(messages []PubSubMessage) {
+		for _, message := range messages {
+			topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) {
+				topicClient := c.client.Topic(message.Topic)
+				exists, err := topicClient.Exists(ctx)
+				if err != nil {
+					return nil, fmt.Errorf("error checking if topic exists: %w", err)
+				}
+				if !exists {
+					topicClient, err = c.client.CreateTopic(ctx, message.Topic)
+					if err != nil {
+						return nil, fmt.Errorf("error creating topic: %w", err)
+					}
+				}
+				return topicClient, nil
+			})
+			if err != nil {
+				queueErr(err)
+				return
+			}
+
+			publish <- topicClient.Publish(ctx, message.Message)
+		}
+	})
+}
+
 type topicCache struct {
 	cache map[string]*pubsub.Topic
 	lock  sync.RWMutex
@@ -175,51 +226,10 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord
 	tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
 	topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
 	publish := make(chan *pubsub.PublishResult, 32)
-
 	waitChan := make(chan struct{})
 
-	wgCtx, wgErr := context.WithCancelCause(ctx)
-	pool, err := utils.LuaPool(func() (*lua.LState, error) {
-		ls, err := utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int {
-			top := ls.GetTop()
-			ss := make([]string, top)
-			for i := range top {
-				ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
-			}
-			_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
-			return 0
-		})
-		if err != nil {
-			return nil, err
-		}
-		if req.Script == "" {
-			ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
-		}
-		return ls, nil
-	}, func(messages []PubSubMessage) {
-		for _, message := range messages {
-			topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) {
-				topicClient := c.client.Topic(message.Topic)
-				exists, err := topicClient.Exists(wgCtx)
-				if err != nil {
-					return nil, fmt.Errorf("error checking if topic exists: %w", err)
-				}
-				if !exists {
-					topicClient, err = c.client.CreateTopic(wgCtx, message.Topic)
-					if err != nil {
-						return nil, fmt.Errorf("error creating topic: %w", err)
-					}
-				}
-				return topicClient, nil
-			})
-			if err != nil {
-				wgErr(err)
-				return
-			}
-
-			publish <- topicClient.Publish(ctx, message.Message)
-		}
-	})
+	queueCtx, queueErr := context.WithCancelCause(ctx)
+	pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, &topiccache, publish, queueErr)
 	if err != nil {
 		return nil, err
 	}
@@ -228,7 +238,7 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord
 	go func() {
 		for curpub := range publish {
 			if _, err := curpub.Get(ctx); err != nil {
-				wgErr(err)
+				queueErr(err)
 				break
 			}
 		}
@@ -275,7 +285,7 @@ Loop:
 			lfn := ls.Env.RawGetString("onRecord")
 			fn, ok := lfn.(*lua.LFunction)
 			if !ok {
-				wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
+				queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
 				return nil
 			}
 
@@ -283,7 +293,7 @@ Loop:
 			ls.Push(pua.LuaRecord.New(ls, record))
 			err := ls.PCall(1, -1, nil)
 			if err != nil {
-				wgErr(fmt.Errorf("script failed: %w", err))
+				queueErr(fmt.Errorf("script failed: %w", err))
 				return nil
 			}
 
@@ -292,7 +302,7 @@ Loop:
 			for i := range args {
 				msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
 				if err != nil {
-					wgErr(err)
+					queueErr(err)
 					return nil
 				}
 				if msg.Message != nil {
@@ -309,20 +319,20 @@ Loop:
 				return results
 			})
 
-		case <-wgCtx.Done():
+		case <-queueCtx.Done():
 			break Loop
 		}
 	}
 
 	close(flushLoopDone)
-	close(publish)
-	if err := pool.Wait(wgCtx); err != nil {
+	if err := pool.Wait(queueCtx); err != nil {
 		return nil, err
 	}
-	topiccache.Stop(wgCtx)
+	close(publish)
+	topiccache.Stop(queueCtx)
 	select {
-	case <-wgCtx.Done():
-		return nil, wgCtx.Err()
+	case <-queueCtx.Done():
+		return nil, queueCtx.Err()
 	case <-waitChan:
 	}
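The batched publish callback above leans on topiccache.GetOrSet, whose body lies outside this diff. Given the map-plus-sync.RWMutex shape of topicCache shown next to it, a plausible implementation is a read-locked fast path with a write-locked re-check; treat this sketch as an assumption, not PeerDB's actual code:

package connpubsub

import (
	"sync"

	"cloud.google.com/go/pubsub"
)

type topicCache struct {
	cache map[string]*pubsub.Topic
	lock  sync.RWMutex
}

// GetOrSet returns the cached topic client, or creates one via f.
func (tc *topicCache) GetOrSet(topic string, f func() (*pubsub.Topic, error)) (*pubsub.Topic, error) {
	tc.lock.RLock()
	client, ok := tc.cache[topic]
	tc.lock.RUnlock()
	if ok {
		return client, nil
	}

	tc.lock.Lock()
	defer tc.lock.Unlock()
	// Re-check: another worker may have created the topic while we
	// were waiting for the write lock.
	if client, ok := tc.cache[topic]; ok {
		return client, nil
	}
	client, err := f()
	if err != nil {
		return nil, err
	}
	tc.cache[topic] = client
	return client, nil
}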
diff --git a/flow/connectors/pubsub/qrep.go b/flow/connectors/pubsub/qrep.go
new file mode 100644
index 0000000000..c1f21edc4a
--- /dev/null
+++ b/flow/connectors/pubsub/qrep.go
@@ -0,0 +1,128 @@
+package connpubsub
+
+import (
+	"context"
+	"fmt"
+	"sync/atomic"
+	"time"
+
+	"cloud.google.com/go/pubsub"
+	lua "github.com/yuin/gopher-lua"
+
+	"github.com/PeerDB-io/peer-flow/generated/protos"
+	"github.com/PeerDB-io/peer-flow/model"
+	"github.com/PeerDB-io/peer-flow/pua"
+)
+
+func (*PubSubConnector) SetupQRepMetadataTables(_ context.Context, _ *protos.QRepConfig) error {
+	return nil
+}
+
+func (c *PubSubConnector) SyncQRepRecords(
+	ctx context.Context,
+	config *protos.QRepConfig,
+	partition *protos.QRepPartition,
+	stream *model.QRecordStream,
+) (int, error) {
+	startTime := time.Now()
+	numRecords := atomic.Int64{}
+	schema := stream.Schema()
+	topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
+	publish := make(chan *pubsub.PublishResult, 32)
+	waitChan := make(chan struct{})
+
+	queueCtx, queueErr := context.WithCancelCause(ctx)
+	pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, &topiccache, publish, queueErr)
+	if err != nil {
+		return 0, err
+	}
+	defer pool.Close()
+
+	go func() {
+		for curpub := range publish {
+			if _, err := curpub.Get(ctx); err != nil {
+				queueErr(err)
+				break
+			}
+		}
+		close(waitChan)
+	}()
+
+Loop:
+	for {
+		select {
+		case qrecord, ok := <-stream.Records:
+			if !ok {
+				c.logger.Info("flushing batches because no more records")
+				break Loop
+			}
+
+			pool.Run(func(ls *lua.LState) []PubSubMessage {
+				items := model.NewRecordItems(len(qrecord))
+				for i, val := range qrecord {
+					items.AddColumn(schema.Fields[i].Name, val)
+				}
+				record := &model.InsertRecord[model.RecordItems]{
+					BaseRecord:           model.BaseRecord{},
+					Items:                items,
+					SourceTableName:      config.WatermarkTable,
+					DestinationTableName: config.DestinationTableIdentifier,
+					CommitID:             0,
+				}
+
+				lfn := ls.Env.RawGetString("onRecord")
+				fn, ok := lfn.(*lua.LFunction)
+				if !ok {
+					queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
+					return nil
+				}
+
+				ls.Push(fn)
+				ls.Push(pua.LuaRecord.New(ls, record))
+				err := ls.PCall(1, -1, nil)
+				if err != nil {
+					queueErr(fmt.Errorf("script failed: %w", err))
+					return nil
+				}
+
+				args := ls.GetTop()
+				results := make([]PubSubMessage, 0, args)
+				for i := range args {
+					msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
+					if err != nil {
+						queueErr(err)
+						return nil
+					}
+					if msg.Message != nil {
+						if msg.Topic == "" {
+							msg.Topic = record.GetDestinationTableName()
+						}
+						results = append(results, msg)
+					}
+				}
+				ls.SetTop(0)
+				numRecords.Add(1)
+				return results
+			})
+
+		case <-queueCtx.Done():
+			break Loop
+		}
+	}
+
+	if err := pool.Wait(queueCtx); err != nil {
+		return 0, err
+	}
+	close(publish)
+	topiccache.Stop(queueCtx)
+	select {
+	case <-queueCtx.Done():
+		return 0, queueCtx.Err()
+	case <-waitChan:
+	}
+
+	if err := c.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime); err != nil {
+		return 0, err
+	}
+	return int(numRecords.Load()), nil
+}
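Both Pub/Sub paths now end with the same shutdown ordering, which the pubsub.go hunk also corrects: pool.Wait guarantees every message has been handed to Publish, and only then is publish closed, so the drain goroutine sees every outstanding result before waitChan closes. A standalone sketch of that ordering (standard library only; the error channel stands in for *pubsub.PublishResult):

package main

import (
	"context"
	"fmt"
)

func main() {
	ctx := context.Background()
	publish := make(chan error, 32) // stand-in for chan *pubsub.PublishResult
	waitChan := make(chan struct{})

	go func() { // drain goroutine: surfaces the first publish failure
		for err := range publish {
			if err != nil {
				fmt.Println("publish failed:", err)
				break
			}
		}
		close(waitChan)
	}()

	publish <- nil // ...queued by the worker pool while records flow...

	// pool.Wait(queueCtx) would run here: all sends precede the close.
	close(publish)
	select {
	case <-ctx.Done():
		fmt.Println("cancelled:", ctx.Err())
	case <-waitChan:
		fmt.Println("all publish results checked")
	}
}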
diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go
index f42c9f4cd8..fadd09deab 100644
--- a/flow/connectors/snowflake/qrep_avro_sync.go
+++ b/flow/connectors/snowflake/qrep_avro_sync.go
@@ -125,7 +125,7 @@ func (s *SnowflakeAvroSyncHandler) SyncQRepRecords(
 
 	err = s.connector.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime)
 	if err != nil {
-		return -1, err
+		return 0, err
 	}
 
 	activity.RecordHeartbeat(ctx, "finished syncing records")
diff --git a/flow/e2e/kafka/kafka_test.go b/flow/e2e/kafka/kafka_test.go
index 83d4a42dd2..180383bd73 100644
--- a/flow/e2e/kafka/kafka_test.go
+++ b/flow/e2e/kafka/kafka_test.go
@@ -192,3 +192,59 @@ func (s KafkaSuite) TestDefault() {
 	env.Cancel()
 	e2e.RequireEnvCanceled(s.t, env)
 }
+
+func (s KafkaSuite) TestInitialLoad() {
+	srcTableName := e2e.AttachSchema(s, "kainitial")
+
+	_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
+		CREATE TABLE IF NOT EXISTS %s (
+			id SERIAL PRIMARY KEY,
+			val text
+		);
+	`, srcTableName))
+	require.NoError(s.t, err)
+
+	flowName := e2e.AddSuffix(s, "kainitial")
+	connectionGen := e2e.FlowConnectionGenerationConfig{
+		FlowJobName:      flowName,
+		TableNameMapping: map[string]string{srcTableName: flowName},
+		Destination:      s.Peer(),
+	}
+	flowConnConfig := connectionGen.GenerateFlowConnectionConfigs()
+	flowConnConfig.DoInitialSnapshot = true
+
+	_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
+		INSERT INTO %s (id, val) VALUES (1, 'testval')
+	`, srcTableName))
+	require.NoError(s.t, err)
+
+	tc := e2e.NewTemporalClient(s.t)
+	env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil)
+	e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen)
+
+	e2e.EnvWaitFor(s.t, env, 3*time.Minute, "normalize insert", func() bool {
+		kafka, err := kgo.NewClient(
+			kgo.SeedBrokers("localhost:9092"),
+			kgo.ConsumeTopics(flowName),
+		)
+		if err != nil {
+			return false
+		}
+		defer kafka.Close()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
+		fetches := kafka.PollFetches(ctx)
+		fetches.EachTopic(func(ft kgo.FetchTopic) {
+			require.Equal(s.t, flowName, ft.Topic)
+			ft.EachRecord(func(r *kgo.Record) {
+				require.Contains(s.t, string(r.Value), "\"testval\"")
+				require.Equal(s.t, byte('{'), r.Value[0])
+				require.Equal(s.t, byte('}'), r.Value[len(r.Value)-1])
+			})
+		})
+		return true
+	})
+	env.Cancel()
+	e2e.RequireEnvCanceled(s.t, env)
+}
diff --git a/flow/e2e/pubsub/pubsub_test.go b/flow/e2e/pubsub/pubsub_test.go
index ec03fbb111..df6809b819 100644
--- a/flow/e2e/pubsub/pubsub_test.go
+++ b/flow/e2e/pubsub/pubsub_test.go
@@ -237,3 +237,73 @@ func (s PubSubSuite) TestSimple() {
 	env.Cancel()
 	e2e.RequireEnvCanceled(s.t, env)
 }
+
+func (s PubSubSuite) TestInitialLoad() {
+	srcTableName := e2e.AttachSchema(s, "psinitial")
+
+	_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
+		CREATE TABLE IF NOT EXISTS %s (
+			id SERIAL PRIMARY KEY,
+			val text
+		);
+	`, srcTableName))
+	require.NoError(s.t, err)
+
+	sa, err := ServiceAccount()
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(context.Background(), `insert into public.scripts (name, lang, source) values
+	('e2e_psinitial', 'lua', 'function onRecord(r) return r.row and r.row.val end') on conflict do nothing`)
+	require.NoError(s.t, err)
+
+	flowName := e2e.AddSuffix(s, "e2epsinitial")
+	connectionGen := e2e.FlowConnectionGenerationConfig{
+		FlowJobName:      flowName,
+		TableNameMapping: map[string]string{srcTableName: flowName},
+		Destination:      s.Peer(sa),
+	}
+	flowConnConfig := connectionGen.GenerateFlowConnectionConfigs()
+	flowConnConfig.Script = "e2e_psinitial"
+	flowConnConfig.DoInitialSnapshot = true
+
+	psclient, err := sa.CreatePubSubClient(context.Background())
+	require.NoError(s.t, err)
+	defer psclient.Close()
+	topic, err := psclient.CreateTopic(context.Background(), flowName)
+	require.NoError(s.t, err)
+	sub, err := psclient.CreateSubscription(context.Background(), flowName, pubsub.SubscriptionConfig{
+		Topic:             topic,
+		RetentionDuration: 10 * time.Minute,
+		ExpirationPolicy:  24 * time.Hour,
+	})
+	require.NoError(s.t, err)
+	_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
+		INSERT INTO %s (id, val) VALUES (1, 'testval')
+	`, srcTableName))
+	require.NoError(s.t, err)
+
+	tc := e2e.NewTemporalClient(s.t)
+	env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil)
+	e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+
+	msgs := make(chan *pubsub.Message)
+	go func() {
+		_ = sub.Receive(ctx, func(_ context.Context, m *pubsub.Message) {
+			msgs <- m
+		})
+	}()
+	select {
+	case msg := <-msgs:
+		require.NotNil(s.t, msg)
+		require.Equal(s.t, "testval", string(msg.Data))
+	case <-ctx.Done():
+		s.t.Log("UNEXPECTED TIMEOUT PubSub subscription waiting on message")
+		s.t.Fail()
+	}
+
+	env.Cancel()
+	e2e.RequireEnvCanceled(s.t, env)
+}
diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go
index fdec9d15c6..10757774d6 100644
--- a/flow/workflows/snapshot_flow.go
+++ b/flow/workflows/snapshot_flow.go
@@ -199,6 +199,7 @@ func (s *SnapshotFlowExecution) cloneTable(
 		SoftDeleteColName: s.config.SoftDeleteColName,
 		WriteMode:         snapshotWriteMode,
 		System:            s.config.System,
+		Script:            s.config.Script,
 	}
 
 	state := NewQRepFlowState()
diff --git a/protos/flow.proto b/protos/flow.proto
index 46516fea29..4d4343f71a 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -297,6 +297,7 @@ message QRepConfig {
   string soft_delete_col_name = 17;
 
   TypeSystem system = 18;
+  string script = 19;
 }
 
 message QRepPartition {