WaitFor (#980)

`NormalizeFlowCountQuery` is stunting decoupled sync/normalize workflows,
so replace it with `WaitFor`.

Besides, I just don't like this `ExitAfterRecords` way of doing things.
e2e tests are integration tests: the implementation should be treated as a black box as much as possible.
Temporal has plenty of capabilities for mocking activities, so the more intrusive tests that'd be necessary to raise branch coverage can be written as unit tests instead.

`WaitFor` presents the ideal mechanism for testing convergent processes:
update the source, then wait for the destination to reflect the change.
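
A minimal sketch of the idea in Go (the name and signature here are illustrative, not necessarily the exact helper this repo ships):

```go
package e2e

import (
	"testing"
	"time"
)

// WaitFor polls cond until it returns true or the timeout elapses,
// failing the test on timeout. Sleeping a full second between polls
// keeps pressure off the destination warehouse.
func WaitFor(t *testing.T, timeout time.Duration, cond func() bool) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for !cond() {
		if time.Now().After(deadline) {
			t.Fatal("WaitFor: destination did not converge before timeout")
		}
		time.Sleep(time.Second)
	}
}
```

A test then reduces to: write rows to the source, then `WaitFor` a predicate like "destination row count equals rows written".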

To make this change work, however,
I needed to call `env.CancelWorkflow` after completing tests,
since I now want the workflow to run indefinitely.
It turned out our code didn't adequately handle cancellation,
so this change implements that too.
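
A self-contained sketch of the cancellation pattern, with a toy workflow standing in for the real CDC flow (Temporal's test environment runs on mocked time, so the delayed callback fires without real waiting):

```go
package e2e

import (
	"testing"
	"time"

	"go.temporal.io/sdk/testsuite"
	"go.temporal.io/sdk/workflow"
)

// toyFlow stands in for the real flow: it loops until cancelled.
func toyFlow(ctx workflow.Context) error {
	for {
		// workflow.Sleep returns an error once the workflow is
		// cancelled; propagating it ends the workflow cleanly.
		if err := workflow.Sleep(ctx, time.Minute); err != nil {
			return err
		}
	}
}

func TestCancelAfterAssertions(t *testing.T) {
	var suite testsuite.WorkflowTestSuite
	env := suite.NewTestWorkflowEnvironment()

	// Run assertions from a delayed callback, then cancel the
	// otherwise-indefinite workflow so ExecuteWorkflow can return.
	env.RegisterDelayedCallback(func() {
		// ... WaitFor the destination to reflect the change ...
		env.CancelWorkflow()
	}, time.Hour)

	env.ExecuteWorkflow(toyFlow)
	if !env.IsWorkflowCompleted() {
		t.Fatal("workflow did not complete after cancellation")
	}
	// env.GetWorkflowError() now reports the cancellation, which is
	// the expected terminal state here.
}
```
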
serprex authored Jan 12, 2024
1 parent bf2544e commit 267fd25
Showing 19 changed files with 657 additions and 631 deletions.
12 changes: 6 additions & 6 deletions flow/activities/flowable.go
@@ -223,6 +223,12 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,

go a.recordSlotSizePeriodically(errCtx, srcConn, slotNameForMetrics, input.FlowConnectionConfigs.Source.Name)

shutdown := utils.HeartbeatRoutine(ctx, 10*time.Second, func() string {
jobName := input.FlowConnectionConfigs.FlowJobName
return fmt.Sprintf("transferring records for job - %s", jobName)
})
defer shutdown()

// start a goroutine to pull records from the source
recordBatch := model.NewCDCRecordStream()
startTime := time.Now()
@@ -283,12 +289,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
return syncResponse, nil
}

shutdown := utils.HeartbeatRoutine(ctx, 10*time.Second, func() string {
jobName := input.FlowConnectionConfigs.FlowJobName
return fmt.Sprintf("pushing records for job - %s", jobName)
})
defer shutdown()

syncStartTime := time.Now()
res, err := dstConn.SyncRecords(&model.SyncRecordsRequest{
Records: recordBatch,
32 changes: 20 additions & 12 deletions flow/connectors/postgres/cdc.go
@@ -263,8 +263,8 @@ func (p *PostgresCDCSource) consumeStream(

if cdcRecordsStorage.Len() == 1 {
records.SignalAsNotEmpty()
p.logger.Info(fmt.Sprintf("pushing the standby deadline to %s", time.Now().Add(standbyMessageTimeout)))
nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout)
p.logger.Info(fmt.Sprintf("pushing the standby deadline to %s", nextStandbyMessageDeadline))
}
return nil
}
@@ -297,17 +297,19 @@ func (p *PostgresCDCSource) consumeStream(
}
}

if (cdcRecordsStorage.Len() >= int(req.MaxBatchSize)) && !p.commitLock {
return nil
}
if !p.commitLock {
if cdcRecordsStorage.Len() >= int(req.MaxBatchSize) {
return nil
}

if waitingForCommit && !p.commitLock {
p.logger.Info(fmt.Sprintf(
"[%s] commit received, returning currently accumulated records - %d",
p.flowJobName,
cdcRecordsStorage.Len()),
)
return nil
if waitingForCommit {
p.logger.Info(fmt.Sprintf(
"[%s] commit received, returning currently accumulated records - %d",
p.flowJobName,
cdcRecordsStorage.Len()),
)
return nil
}
}

// if we are past the next standby deadline (?)
@@ -340,9 +342,15 @@ func (p *PostgresCDCSource) consumeStream(
} else {
ctx, cancel = context.WithDeadline(p.ctx, nextStandbyMessageDeadline)
}

rawMsg, err := conn.ReceiveMessage(ctx)
cancel()

utils.RecordHeartbeatWithRecover(p.ctx, "consumeStream ReceiveMessage")
ctxErr := p.ctx.Err()
if ctxErr != nil {
return fmt.Errorf("consumeStream preempted: %w", ctxErr)
}

if err != nil && !p.commitLock {
if pgconn.Timeout(err) {
p.logger.Info(fmt.Sprintf("Stand-by deadline reached, returning currently accumulated records - %d",
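
The hunk above is the cancellation-handling piece: a sketch of the pattern (assumed shape, not this repo's exact code) in which each blocking receive is bounded by a child deadline, and the parent context's `Err()` is checked afterwards so cancelling the whole stream isn't mistaken for a routine standby timeout:

```go
package postgres

import (
	"context"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgproto3"
)

// receiveOnce performs one iteration of the consume loop's receive.
func receiveOnce(parent context.Context, conn *pgconn.PgConn,
	deadline time.Time,
) (pgproto3.BackendMessage, error) {
	ctx, cancel := context.WithDeadline(parent, deadline)
	rawMsg, err := conn.ReceiveMessage(ctx)
	cancel()

	if ctxErr := parent.Err(); ctxErr != nil {
		// The stream itself was cancelled: surface that instead of
		// treating it like a timeout and flushing a partial batch.
		return nil, fmt.Errorf("consumeStream preempted: %w", ctxErr)
	}
	if err != nil && pgconn.Timeout(err) {
		// Only the per-call deadline fired; the caller returns the
		// records accumulated so far.
		return nil, nil
	}
	return rawMsg, err
}
```
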
6 changes: 2 additions & 4 deletions flow/connectors/snowflake/snowflake.go
@@ -248,8 +248,7 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T
return nil, fmt.Errorf("error querying Snowflake peer for schema of table %s: %w", tableName, err)
}
defer func() {
// not sure if the errors these two return are same or different?
err = errors.Join(rows.Close(), rows.Err())
err = rows.Close()
if err != nil {
c.logger.Error("error while closing rows for reading schema of table",
slog.String("tableName", tableName),
@@ -289,8 +288,7 @@ func (c *SnowflakeConnector) GetLastOffset(jobName string) (int64, error) {
return 0, fmt.Errorf("error querying Snowflake peer for last syncedID: %w", err)
}
defer func() {
// not sure if the errors these two return are same or different?
err = errors.Join(rows.Close(), rows.Err())
err = rows.Close()
if err != nil {
c.logger.Error("error while closing rows for reading last offset", slog.Any("error", err))
}
2 changes: 1 addition & 1 deletion flow/connectors/utils/heartbeat.go
@@ -30,7 +30,7 @@ func HeartbeatRoutine(
}
}
}()
return func() { shutdown <- struct{}{} }
return func() { close(shutdown) }
}

// if the functions are being called outside the context of a Temporal workflow,
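
The one-line change above swaps a channel send for a `close`. The heartbeat goroutine also exits when `ctx.Done()` fires, so by the time the deferred shutdown function runs there may be no receiver left; a send on an unbuffered channel would then block forever, while `close` always returns immediately. Roughly the routine's shape, reconstructed around the visible lines (the real one also guards against running outside a Temporal activity context, per the comment above):

```go
package utils

import (
	"context"
	"time"

	"go.temporal.io/sdk/activity"
)

// HeartbeatRoutine records an activity heartbeat every interval until
// either the returned shutdown function is called or ctx is done.
func HeartbeatRoutine(ctx context.Context, interval time.Duration,
	message func() string,
) func() {
	shutdown := make(chan struct{})
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-shutdown:
				return
			case <-ctx.Done():
				return
			case <-ticker.C:
				activity.RecordHeartbeat(ctx, message())
			}
		}
	}()
	// close, not send: safe even if the goroutine already returned.
	return func() { close(shutdown) }
}
```
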
50 changes: 27 additions & 23 deletions flow/e2e/bigquery/bigquery_helper.go
@@ -29,8 +29,6 @@ type BigQueryTestHelper struct {
Peer *protos.Peer
// client to talk to BigQuery
client *bigquery.Client
// dataset to use for testing.
datasetName string
}

// NewBigQueryTestHelper creates a new BigQueryTestHelper.
@@ -51,7 +49,7 @@ func NewBigQueryTestHelper() (*BigQueryTestHelper, error) {
return nil, fmt.Errorf("failed to read file: %w", err)
}

var config protos.BigqueryConfig
var config *protos.BigqueryConfig
err = json.Unmarshal(content, &config)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal json: %w", err)
@@ -60,7 +58,7 @@ func NewBigQueryTestHelper() (*BigQueryTestHelper, error) {
// suffix the dataset with the runID to namespace stateful schemas.
config.DatasetId = fmt.Sprintf("%s_%d", config.DatasetId, runID)

bqsa, err := peer_bq.NewBigQueryServiceAccount(&config)
bqsa, err := peer_bq.NewBigQueryServiceAccount(config)
if err != nil {
return nil, fmt.Errorf("failed to create BigQueryServiceAccount: %v", err)
}
@@ -70,14 +68,13 @@ func NewBigQueryTestHelper() (*BigQueryTestHelper, error) {
return nil, fmt.Errorf("failed to create helper BigQuery client: %v", err)
}

peer := generateBQPeer(&config)
peer := generateBQPeer(config)

return &BigQueryTestHelper{
runID: runID,
Config: &config,
client: client,
datasetName: config.DatasetId,
Peer: peer,
runID: runID,
Config: config,
client: client,
Peer: peer,
}, nil
}

@@ -115,12 +112,12 @@ func (b *BigQueryTestHelper) datasetExists(datasetName string) (bool, error) {

// RecreateDataset recreates the dataset, i.e, deletes it if exists and creates it again.
func (b *BigQueryTestHelper) RecreateDataset() error {
exists, err := b.datasetExists(b.datasetName)
exists, err := b.datasetExists(b.Config.DatasetId)
if err != nil {
return fmt.Errorf("failed to check if dataset %s exists: %w", b.Config.DatasetId, err)
}

dataset := b.client.Dataset(b.datasetName)
dataset := b.client.Dataset(b.Config.DatasetId)
if exists {
err := dataset.DeleteWithContents(context.Background())
if err != nil {
@@ -158,7 +155,9 @@ func (b *BigQueryTestHelper) DropDataset(datasetName string) error {

// RunCommand runs the given command.
func (b *BigQueryTestHelper) RunCommand(command string) error {
_, err := b.client.Query(command).Read(context.Background())
q := b.client.Query(command)
q.DisableQueryCache = true
_, err := q.Read(context.Background())
if err != nil {
return fmt.Errorf("failed to run command: %w", err)
}
@@ -168,7 +167,7 @@ func (b *BigQueryTestHelper) RunCommand(command string) error {

// countRows(tableName) returns the number of rows in the given table.
func (b *BigQueryTestHelper) countRows(tableName string) (int, error) {
return b.countRowsWithDataset(b.datasetName, tableName, "")
return b.countRowsWithDataset(b.Config.DatasetId, tableName, "")
}

func (b *BigQueryTestHelper) countRowsWithDataset(dataset, tableName string, nonNullCol string) (int, error) {
@@ -177,7 +176,9 @@ func (b *BigQueryTestHelper) countRowsWithDataset(dataset, tableName string, non
command = fmt.Sprintf("SELECT COUNT(CASE WHEN " + nonNullCol +
" IS NOT NULL THEN 1 END) AS non_null_count FROM `" + dataset + "." + tableName + "`;")
}
it, err := b.client.Query(command).Read(context.Background())
q := b.client.Query(command)
q.DisableQueryCache = true
it, err := q.Read(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to run command: %w", err)
}
@@ -305,7 +306,9 @@ func bqSchemaToQRecordSchema(schema bigquery.Schema) (*model.QRecordSchema, erro
}

func (b *BigQueryTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecordBatch, error) {
it, err := b.client.Query(query).Read(context.Background())
q := b.client.Query(query)
q.DisableQueryCache = true
it, err := q.Read(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to run command: %w", err)
}
@@ -358,18 +361,17 @@ func (b *BigQueryTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecor
}, nil
}

/*
if the function errors or there are nulls, the function returns false
else true
*/
// returns whether the function errors or there are nulls
func (b *BigQueryTestHelper) CheckNull(tableName string, ColName []string) (bool, error) {
if len(ColName) == 0 {
return true, nil
}
joinedString := strings.Join(ColName, " is null or ") + " is null"
command := fmt.Sprintf("SELECT COUNT(*) FROM `%s.%s` WHERE %s",
b.Config.DatasetId, tableName, joinedString)
it, err := b.client.Query(command).Read(context.Background())
q := b.client.Query(command)
q.DisableQueryCache = true
it, err := q.Read(context.Background())
if err != nil {
return false, fmt.Errorf("failed to run command: %w", err)
}
Expand Down Expand Up @@ -401,7 +403,9 @@ func (b *BigQueryTestHelper) CheckDoubleValues(tableName string, ColName []strin
csep := strings.Join(ColName, ",")
command := fmt.Sprintf("SELECT %s FROM `%s.%s`",
csep, b.Config.DatasetId, tableName)
it, err := b.client.Query(command).Read(context.Background())
q := b.client.Query(command)
q.DisableQueryCache = true
it, err := q.Read(context.Background())
if err != nil {
return false, fmt.Errorf("failed to run command: %w", err)
}
@@ -474,7 +478,7 @@ func (b *BigQueryTestHelper) CreateTable(tableName string, schema *model.QRecord
fields = append(fields, fmt.Sprintf("`%s` %s", field.Name, bqType))
}

command := fmt.Sprintf("CREATE TABLE %s.%s (%s)", b.datasetName, tableName, strings.Join(fields, ", "))
command := fmt.Sprintf("CREATE TABLE %s.%s (%s)", b.Config.DatasetId, tableName, strings.Join(fields, ", "))

err := b.RunCommand(command)
if err != nil {
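
The repeated `q.DisableQueryCache = true` above exists because BigQuery serves identical repeated queries from its result cache by default; a `WaitFor` loop re-issuing the same `COUNT(*)` would otherwise keep observing a stale result. A hypothetical helper (not part of this diff) folding the pattern into one place:

```go
package e2e_bigquery

import (
	"context"

	"cloud.google.com/go/bigquery"
)

// queryUncached builds a query that bypasses BigQuery's result cache,
// so each poll sees the live table. Hypothetical; the helper file's
// call sites set DisableQueryCache inline instead.
func (b *BigQueryTestHelper) queryUncached(sql string) (*bigquery.RowIterator, error) {
	q := b.client.Query(sql)
	q.DisableQueryCache = true
	return q.Read(context.Background())
}
```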