Skip to content

Commit

Permalink
pubsub
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed May 7, 2024
1 parent ab60ed3 commit 1f084e7
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 43 deletions.
96 changes: 53 additions & 43 deletions flow/connectors/pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,57 @@ func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, err
}, nil
}

func (c *PubSubConnector) createPool(
ctx context.Context,
script string,
flowJobName string,
topiccache *topicCache,
publish chan<- *pubsub.PublishResult,
wgErr func(error),
) (*utils.LPool[[]PubSubMessage], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, flowJobName, strings.Join(ss, "\t"))
return 0
})
if err != nil {
return nil, err
}
if script == "" {
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(messages []PubSubMessage) {
for _, message := range messages {
topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) {
topicClient := c.client.Topic(message.Topic)
exists, err := topicClient.Exists(ctx)
if err != nil {
return nil, fmt.Errorf("error checking if topic exists: %w", err)
}
if !exists {
topicClient, err = c.client.CreateTopic(ctx, message.Topic)
if err != nil {
return nil, fmt.Errorf("error creating topic: %w", err)
}
}
return topicClient, nil
})
if err != nil {
wgErr(err)
return
}

publish <- topicClient.Publish(ctx, message.Message)
}
})
}

type topicCache struct {
cache map[string]*pubsub.Topic
lock sync.RWMutex
Expand Down Expand Up @@ -175,51 +226,10 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan *pubsub.PublishResult, 32)

waitChan := make(chan struct{})
wgCtx, wgErr := context.WithCancelCause(ctx)

pool, err := utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, req.Script, func(ls *lua.LState) int {
top := ls.GetTop()
ss := make([]string, top)
for i := range top {
ss[i] = ls.ToStringMeta(ls.Get(i + 1)).String()
}
_ = c.LogFlowInfo(ctx, req.FlowJobName, strings.Join(ss, "\t"))
return 0
})
if err != nil {
return nil, err
}
if req.Script == "" {
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(messages []PubSubMessage) {
for _, message := range messages {
topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) {
topicClient := c.client.Topic(message.Topic)
exists, err := topicClient.Exists(wgCtx)
if err != nil {
return nil, fmt.Errorf("error checking if topic exists: %w", err)
}
if !exists {
topicClient, err = c.client.CreateTopic(wgCtx, message.Topic)
if err != nil {
return nil, fmt.Errorf("error creating topic: %w", err)
}
}
return topicClient, nil
})
if err != nil {
wgErr(err)
return
}

publish <- topicClient.Publish(ctx, message.Message)
}
})
wgCtx, wgErr := context.WithCancelCause(ctx)
pool, err := c.createPool(wgCtx, req.Script, req.FlowJobName, &topiccache, publish, wgErr)
if err != nil {
return nil, err
}
Expand Down
128 changes: 128 additions & 0 deletions flow/connectors/pubsub/qrep.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package connpubsub

import (
"context"
"fmt"
"sync/atomic"
"time"

"cloud.google.com/go/pubsub"
lua "github.com/yuin/gopher-lua"

"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/pua"
)

func (*PubSubConnector) SetupQRepMetadataTables(_ context.Context, _ *protos.QRepConfig) error {
return nil
}

func (c *PubSubConnector) SyncQRepRecords(
ctx context.Context,
config *protos.QRepConfig,
partition *protos.QRepPartition,
stream *model.QRecordStream,
) (int, error) {
startTime := time.Now()
numRecords := atomic.Int64{}
schema := stream.Schema()
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan *pubsub.PublishResult, 32)
waitChan := make(chan struct{})

wgCtx, wgErr := context.WithCancelCause(ctx)
pool, err := c.createPool(wgCtx, config.Script, config.FlowJobName, &topiccache, publish, wgErr)
if err != nil {
return 0, err
}
defer pool.Close()

go func() {
for curpub := range publish {
if _, err := curpub.Get(ctx); err != nil {
wgErr(err)
break
}
}
close(waitChan)
}()

Loop:
for {
select {
case qrecord, ok := <-stream.Records:
if !ok {
c.logger.Info("flushing batches because no more records")
break Loop
}

pool.Run(func(ls *lua.LState) []PubSubMessage {
items := model.NewRecordItems(len(qrecord))
for i, val := range qrecord {
items.AddColumn(schema.Fields[i].Name, val)
}
record := &model.InsertRecord[model.RecordItems]{
BaseRecord: model.BaseRecord{},
Items: items,
SourceTableName: config.WatermarkTable,
DestinationTableName: config.DestinationTableIdentifier,
CommitID: 0,
}

lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
wgErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
wgErr(fmt.Errorf("script failed: %w", err))
return nil
}

args := ls.GetTop()
results := make([]PubSubMessage, 0, args)
for i := range args {
msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
if err != nil {
wgErr(err)
return nil
}
if msg.Message != nil {
if msg.Topic == "" {
msg.Topic = record.GetDestinationTableName()
}
results = append(results, msg)
}
}
ls.SetTop(0)
numRecords.Add(1)
return results
})

case <-wgCtx.Done():
break Loop
}
}

close(publish)
if err := pool.Wait(wgCtx); err != nil {
return 0, err
}
topiccache.Stop(wgCtx)
select {
case <-wgCtx.Done():
return 0, wgCtx.Err()
case <-waitChan:
}

if err := c.FinishQRepPartition(ctx, partition, config.FlowJobName, startTime); err != nil {
return 0, err
}
return int(numRecords.Load()), nil
}
71 changes: 71 additions & 0 deletions flow/e2e/pubsub/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,74 @@ func (s PubSubSuite) TestSimple() {
env.Cancel()
e2e.RequireEnvCanceled(s.t, env)
}

func (s PubSubSuite) TestInitialLoad() {
srcTableName := e2e.AttachSchema(s, "psinitial")

_, err := s.Conn().Exec(context.Background(), fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id SERIAL PRIMARY KEY,
val text
);
`, srcTableName))
require.NoError(s.t, err)

sa, err := ServiceAccount()
require.NoError(s.t, err)

_, err = s.Conn().Exec(context.Background(), `insert into public.scripts (name, lang, source) values
('e2e_psinitial', 'lua', 'function onRecord(r) return r.row and r.row.val end') on conflict do nothing`)
require.NoError(s.t, err)

flowName := e2e.AddSuffix(s, "e2epsinitial")
connectionGen := e2e.FlowConnectionGenerationConfig{
FlowJobName: flowName,
TableNameMapping: map[string]string{srcTableName: flowName},
Destination: s.Peer(sa),
}
flowConnConfig := connectionGen.GenerateFlowConnectionConfigs()
flowConnConfig.Script = "e2e_psinitial"
flowConnConfig.DoInitialSnapshot = true

psclient, err := sa.CreatePubSubClient(context.Background())
require.NoError(s.t, err)
defer psclient.Close()
topic, err := psclient.CreateTopic(context.Background(), flowName)
require.NoError(s.t, err)
sub, err := psclient.CreateSubscription(context.Background(), flowName, pubsub.SubscriptionConfig{
Topic: topic,
RetentionDuration: 10 * time.Minute,
ExpirationPolicy: 24 * time.Hour,
})
require.NoError(s.t, err)
_, err = s.Conn().Exec(context.Background(), fmt.Sprintf(`
INSERT INTO %s (id, val) VALUES (1, 'testval')
`, srcTableName))
require.NoError(s.t, err)

tc := e2e.NewTemporalClient(s.t)
env := e2e.ExecutePeerflow(tc, peerflow.CDCFlowWorkflow, flowConnConfig, nil)
e2e.SetupCDCFlowStatusQuery(s.t, env, connectionGen)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()

msgs := make(chan *pubsub.Message)
go func() {
_ = sub.Receive(ctx, func(_ context.Context, m *pubsub.Message) {
msgs <- m
})
}()
select {
case msg := <-msgs:
require.NotNil(s.t, msg)
require.Equal(s.t, "testval", string(msg.Data))
case <-ctx.Done():
s.t.Log("UNEXPECTED TIMEOUT PubSub subscription waiting on message")
s.t.Fail()
}

env.Cancel()
e2e.RequireEnvCanceled(s.t, env)
}

0 comments on commit 1f084e7

Please sign in to comment.