ported single statement MERGE and BQ txn remove
heavycrystal committed Dec 27, 2023
1 parent 6f3eec2 commit 46e5d23
Showing 3 changed files with 30 additions and 100 deletions.
73 changes: 8 additions & 65 deletions flow/connectors/bigquery/bigquery.go
@@ -774,24 +774,10 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 		return nil, fmt.Errorf("couldn't get tablename to unchanged cols mapping: %w", err)
 	}

-	stmts := []string{}
+	stmts := make([]string, 0, len(distinctTableNames)+1)
 	// append all the statements to one list
 	log.Printf("merge raw records to corresponding tables: %s %s %v", c.datasetID, rawTableName, distinctTableNames)

-	release, err := c.grabJobsUpdateLock()
-	if err != nil {
-		return nil, fmt.Errorf("failed to grab lock: %v", err)
-	}
-
-	defer func() {
-		err := release()
-		if err != nil {
-			log.Errorf("failed to release lock: %v", err)
-		}
-	}()
-
-	stmts = append(stmts, "BEGIN TRANSACTION;")
-
 	for _, tableName := range distinctTableNames {
 		mergeGen := &mergeStmtGenerator{
 			Dataset: c.datasetID,
@@ -803,23 +789,19 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 			UnchangedToastColumns: tableNametoUnchangedToastCols[tableName],
 		}
 		// normalize anything between last normalized batch id to last sync batchid
-		mergeStmts := mergeGen.generateMergeStmts()
-		stmts = append(stmts, mergeStmts...)
+		mergeStmts := mergeGen.generateMergeStmt()
+		stmts = append(stmts, mergeStmts)
 	}
 	// update metadata to make the last normalized batch id to the recent last sync batch id.
 	updateMetadataStmt := fmt.Sprintf(
-		"UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name = '%s';",
+		"UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
 		c.datasetID, MirrorJobsTable, syncBatchID, req.FlowJobName)
 	stmts = append(stmts, updateMetadataStmt)
-	stmts = append(stmts, "COMMIT TRANSACTION;")
-
-	// put this within a transaction
-	// TODO - not truncating rows in staging table as of now.
-	// err = c.truncateTable(staging...)

-	_, err = c.client.Query(strings.Join(stmts, "\n")).Read(c.ctx)
+	query := strings.Join(stmts, "\n")
+	_, err = c.client.Query(query).Read(c.ctx)
 	if err != nil {
-		return nil, fmt.Errorf("failed to execute statements %s in a transaction: %v", strings.Join(stmts, "\n"), err)
+		return nil, fmt.Errorf("failed to execute statements %s in a transaction: %v", query, err)
 	}

 	return &model.NormalizeResponse{
@@ -998,21 +980,9 @@ func (c *BigQueryConnector) SetupNormalizedTables(
 }

 func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error {
-	release, err := c.grabJobsUpdateLock()
-	if err != nil {
-		return fmt.Errorf("failed to grab lock: %w", err)
-	}
-
-	defer func() {
-		err := release()
-		if err != nil {
-			log.Printf("failed to release lock: %v", err)
-		}
-	}()
-
 	dataset := c.client.Dataset(c.datasetID)
 	// deleting PeerDB specific tables
-	err = dataset.Table(c.getRawTableName(jobName)).Delete(c.ctx)
+	err := dataset.Table(c.getRawTableName(jobName)).Delete(c.ctx)
 	if err != nil {
 		return fmt.Errorf("failed to delete raw table: %w", err)
 	}
@@ -1044,33 +1014,6 @@ func (c *BigQueryConnector) getStagingTableName(flowJobName string) string {
 	return fmt.Sprintf("_peerdb_staging_%s", flowJobName)
 }

-// Bigquery doesn't allow concurrent updates to the same table.
-// we grab a lock on catalog to ensure that only one job is updating
-// bigquery tables at a time.
-// returns a function to release the lock.
-func (c *BigQueryConnector) grabJobsUpdateLock() (func() error, error) {
-	tx, err := c.catalogPool.Begin(c.ctx)
-	if err != nil {
-		return nil, fmt.Errorf("failed to begin transaction: %w", err)
-	}
-
-	// grab an advisory lock based on the mirror jobs table hash
-	mjTbl := fmt.Sprintf("%s.%s", c.datasetID, MirrorJobsTable)
-	_, err = tx.Exec(c.ctx, "SELECT pg_advisory_xact_lock(hashtext($1))", mjTbl)
-	if err != nil {
-		err = tx.Rollback(c.ctx)
-		return nil, fmt.Errorf("failed to grab lock on %s: %w", mjTbl, err)
-	}
-
-	return func() error {
-		err = tx.Commit(c.ctx)
-		if err != nil {
-			return fmt.Errorf("failed to commit transaction: %w", err)
-		}
-		return nil
-	}, nil
-}
-
 func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) {
 	for _, renameRequest := range req.RenameTableOptions {
 		src := renameRequest.CurrentName
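A note on the two deletions above: per the removed comment, BigQuery rejects concurrent updates to the same table, which is why normalize jobs previously serialized through a Postgres advisory lock on the catalog and wrapped their statements in BEGIN/COMMIT. With each table's changes now folded into one self-contained MERGE, every statement in the batch is atomic on its own, so both the transaction wrapper and the lock can go. Below is a hypothetical, minimal sketch of the batch NormalizeRecords now sends; all dataset, table, and job names are illustrative, not values from this commit.

// Hypothetical sketch of the statement batch NormalizeRecords builds after
// this commit: one self-contained MERGE per destination table plus a
// metadata UPDATE, newline-joined into one query with no BEGIN/COMMIT pair.
// All names below are illustrative assumptions, not values from the commit.
package main

import (
	"fmt"
	"strings"
)

func main() {
	datasetID := "my_dataset"                           // assumed dataset
	tables := []string{"public_users", "public_orders"} // assumed targets
	syncBatchID := 42                                   // assumed batch id

	stmts := make([]string, 0, len(tables)+1)
	for _, tbl := range tables {
		// stand-in for mergeGen.generateMergeStmt(): the real statement
		// inlines the flattened and de-duplicated CTEs in its USING clause
		stmts = append(stmts, fmt.Sprintf(
			"MERGE %s.%s _peerdb_target USING (...) _peerdb_deduped ON ...;",
			datasetID, tbl))
	}
	stmts = append(stmts, fmt.Sprintf(
		"UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='my_mirror';",
		datasetID, "_peerdb_mirror_jobs", syncBatchID))

	// each MERGE is atomic by itself, so no explicit transaction (or
	// catalog advisory lock) is needed around the batch
	fmt.Println(strings.Join(stmts, "\n"))
}

The likely trade-off is that the batch as a whole is no longer all-or-nothing: if a later MERGE fails, earlier MERGEs stay committed and the metadata UPDATE never runs, leaving those tables to be re-merged on retry.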
28 changes: 5 additions & 23 deletions flow/connectors/bigquery/merge_statement_generator.go
@@ -8,7 +8,6 @@ import (
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
-	util "github.com/PeerDB-io/peer-flow/utils"
 )

@@ -28,24 +27,6 @@ type mergeStmtGenerator struct {
 	UnchangedToastColumns []string
 }

-// GenerateMergeStmt generates a merge statements.
-func (m *mergeStmtGenerator) generateMergeStmts() []string {
-	// return an empty array for now
-	flattenedCTE := m.generateFlattenedCTE()
-	deDupedCTE := m.generateDeDupedCTE()
-	tempTable := fmt.Sprintf("_peerdb_de_duplicated_data_%s", util.RandomString(5))
-	// create temp table stmt
-	createTempTableStmt := fmt.Sprintf(
-		"CREATE TEMP TABLE %s AS (%s, %s);",
-		tempTable, flattenedCTE, deDupedCTE)
-
-	mergeStmt := m.generateMergeStmt(tempTable)
-
-	dropTempTableStmt := fmt.Sprintf("DROP TABLE %s;", tempTable)
-
-	return []string{createTempTableStmt, mergeStmt, dropTempTableStmt}
-}
-
 // generateFlattenedCTE generates a flattened CTE.
 func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 	// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
@@ -61,7 +42,7 @@ func (m *mergeStmtGenerator) generateFlattenedCTE() string {

 	switch qvalue.QValueKind(colType) {
 	case qvalue.QValueKindJSON:
-		//if the type is JSON, then just extract JSON
+		// if the type is JSON, then just extract JSON
 		castStmt = fmt.Sprintf("CAST(JSON_VALUE(_peerdb_data, '$.%s') AS %s) AS `%s`",
 			colName, bqType, colName)
 	// expecting data in BASE64 format
@@ -124,7 +105,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 }

 // generateMergeStmt generates a merge statement.
-func (m *mergeStmtGenerator) generateMergeStmt(tempTable string) string {
+func (m *mergeStmtGenerator) generateMergeStmt() string {
 	// comma separated list of column names
 	backtickColNames := make([]string, 0)
 	pureColNames := make([]string, 0)
@@ -146,14 +127,15 @@ func (m *mergeStmtGenerator) generateMergeStmt(tempTable string) string {
 	pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")

 	return fmt.Sprintf(`
-	MERGE %s.%s _peerdb_target USING %s _peerdb_deduped
+	MERGE %s.%s _peerdb_target USING (%s,%s) _peerdb_deduped
 	ON %s
 	WHEN NOT MATCHED and (_peerdb_deduped._peerdb_record_type != 2) THEN
 		INSERT (%s) VALUES (%s)
 	%s
 	WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type = 2) THEN
 	DELETE;
-	`, m.Dataset, m.NormalizedTable, tempTable, pkeySelectSQL, csep, csep, updateStringToastCols)
+	`, m.Dataset, m.NormalizedTable, m.generateFlattenedCTE(), m.generateDeDupedCTE(),
+		pkeySelectSQL, csep, csep, updateStringToastCols)
 }

 /*
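To make the USING change above concrete, here is a hypothetical sketch of the shape generateMergeStmt now produces: the temp table that generateMergeStmts used to CREATE and DROP is replaced by a parenthesized subquery chaining the flattened CTE and the de-duplicated CTE straight into the USING clause. CTE bodies, key columns, and names are placeholders, and the toast-column UPDATE branch is omitted for brevity.

// Hypothetical sketch of the single-statement MERGE shape; the CTE bodies,
// dataset, table, and key columns are placeholders, not the generator's output.
package main

import "fmt"

func main() {
	// what generateFlattenedCTE/generateDeDupedCTE conceptually return:
	// an opening WITH clause, and a chained CTE ending in a SELECT
	flattenedCTE := "WITH _peerdb_flattened AS (SELECT /* casts from _peerdb_data */)"
	deDupedCTE := "_peerdb_de_duplicated AS (SELECT /* latest row per key */) SELECT * FROM _peerdb_de_duplicated"

	mergeStmt := fmt.Sprintf(`
	MERGE my_dataset.my_table _peerdb_target USING (%s,%s) _peerdb_deduped
	ON _peerdb_target.id = _peerdb_deduped.id
	WHEN NOT MATCHED and (_peerdb_deduped._peerdb_record_type != 2) THEN
		INSERT (id) VALUES (_peerdb_deduped.id)
	WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type = 2) THEN
	DELETE;
	`, flattenedCTE, deDupedCTE)

	// (%s,%s) expands to "(WITH a AS (...), b AS (...) SELECT ...)", which
	// BigQuery accepts as a table subquery, so no temp table or DROP is needed
	fmt.Println(mergeStmt)
}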
29 changes: 17 additions & 12 deletions flow/connectors/bigquery/qrep_avro_sync.go
@@ -24,7 +24,8 @@ type QRepAvroSyncMethod struct {
 }

 func NewQRepAvroSyncMethod(connector *BigQueryConnector, gcsBucket string,
-	flowJobName string) *QRepAvroSyncMethod {
+	flowJobName string,
+) *QRepAvroSyncMethod {
 	return &QRepAvroSyncMethod{
 		connector: connector,
 		gcsBucket: gcsBucket,
@@ -73,11 +74,12 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 	)

 	// execute the statements in a transaction
-	stmts := []string{}
-	stmts = append(stmts, "BEGIN TRANSACTION;")
-	stmts = append(stmts, insertStmt)
-	stmts = append(stmts, updateMetadataStmt)
-	stmts = append(stmts, "COMMIT TRANSACTION;")
+	stmts := []string{
+		"BEGIN TRANSACTION;",
+		insertStmt,
+		updateMetadataStmt,
+		"COMMIT TRANSACTION;",
+	}
 	_, err = bqClient.Query(strings.Join(stmts, "\n")).Read(s.connector.ctx)
 	if err != nil {
 		return -1, fmt.Errorf("failed to execute statements in a transaction: %v", err)
@@ -133,14 +135,11 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
 	bqClient := s.connector.client
 	datasetID := s.connector.datasetID
-	// Start a transaction
-	stmts := []string{"BEGIN TRANSACTION;"}

 	// Insert the records from the staging table into the destination table
 	insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;",
 		datasetID, dstTableName, datasetID, stagingTable)

-	stmts = append(stmts, insertStmt)

 	insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime)
 	if err != nil {
 		return -1, fmt.Errorf("failed to create metadata insert statement: %v", err)
@@ -149,8 +148,13 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
 		"flowName": flowJobName,
 	}).Infof("Performing transaction inside QRep sync function for partition ID %s",
 		partition.PartitionId)
-	stmts = append(stmts, insertMetadataStmt)
-	stmts = append(stmts, "COMMIT TRANSACTION;")
+
+	stmts := []string{
+		"BEGIN TRANSACTION;",
+		insertStmt,
+		insertMetadataStmt,
+		"COMMIT TRANSACTION;",
+	}
 	// Execute the statements in a transaction
 	_, err = bqClient.Query(strings.Join(stmts, "\n")).Read(s.connector.ctx)
 	if err != nil {
@@ -187,7 +191,8 @@ type AvroSchema struct {
 }

 func DefineAvroSchema(dstTableName string,
-	dstTableMetadata *bigquery.TableMetadata) (*model.QRecordAvroSchemaDefinition, error) {
+	dstTableMetadata *bigquery.TableMetadata,
+) (*model.QRecordAvroSchemaDefinition, error) {
 	avroFields := []AvroField{}
 	nullableFields := make(map[string]struct{})
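Unlike the normalize path, both QRep sync paths above still bundle a data INSERT with a metadata statement, so they keep the explicit BEGIN/COMMIT script and this commit only restructures how the statement list is built. A minimal sketch of that shared pattern, assuming the standard cloud.google.com/go/bigquery client (the function name and wiring are illustrative):

package bqsync

import (
	"context"
	"fmt"
	"strings"

	"cloud.google.com/go/bigquery"
)

// runAtomically sends a multi-statement script as one BigQuery job; the
// explicit transaction keeps the data insert and metadata write all-or-nothing.
func runAtomically(ctx context.Context, client *bigquery.Client, insertStmt, metadataStmt string) error {
	stmts := []string{
		"BEGIN TRANSACTION;",
		insertStmt,
		metadataStmt,
		"COMMIT TRANSACTION;",
	}
	// Read executes the joined script and waits for it to finish
	if _, err := client.Query(strings.Join(stmts, "\n")).Read(ctx); err != nil {
		return fmt.Errorf("failed to execute statements in a transaction: %w", err)
	}
	return nil
}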
