From 2aa4aa2fdcd3b2c52ef471b6735e41399772910a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?= <philip@peerdb.io>
Date: Fri, 22 Dec 2023 15:29:55 +0000
Subject: [PATCH] BQ: avoid generating SQL while holding metadata lock

---
 flow/connectors/bigquery/bigquery.go          | 33 ++++++++++---------
 .../bigquery/merge_statement_generator.go     |  6 ++--
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 2749566ced..f97c726294 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -782,23 +782,11 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 		return nil, fmt.Errorf("couldn't get tablename to unchanged cols mapping: %w", err)
 	}
 
-	stmts := []string{}
 	// append all the statements to one list
 	c.logger.Info(fmt.Sprintf("merge raw records to corresponding tables: %s %s %v",
 		c.datasetID, rawTableName, distinctTableNames))
 
-	release, err := c.grabJobsUpdateLock()
-	if err != nil {
-		return nil, fmt.Errorf("failed to grab lock: %v", err)
-	}
-
-	defer func() {
-		err := release()
-		if err != nil {
-			c.logger.Error("failed to release lock", slog.Any("error", err))
-		}
-	}()
-
+	stmts := make([]string, 0, len(distinctTableNames)*3+3)
 	stmts = append(stmts, "BEGIN TRANSACTION;")
 
 	for _, tableName := range distinctTableNames {
@@ -817,8 +805,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 			},
 		}
 		// normalize anything between last normalized batch id to last sync batchid
-		mergeStmts := mergeGen.generateMergeStmts()
-		stmts = append(stmts, mergeStmts...)
+		createTemp, mergeStmt, dropTemp := mergeGen.generateMergeStmts()
+		stmts = append(stmts, createTemp, mergeStmt, dropTemp)
 	}
 	// update metadata to make the last normalized batch id to the recent last sync batch id.
 	updateMetadataStmt := fmt.Sprintf(
@@ -826,12 +814,25 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
 		c.datasetID, MirrorJobsTable, syncBatchID, req.FlowJobName)
 	stmts = append(stmts, updateMetadataStmt)
 	stmts = append(stmts, "COMMIT TRANSACTION;")
+	mergeQuery := strings.Join(stmts, "\n")
+
+	release, err := c.grabJobsUpdateLock()
+	if err != nil {
+		return nil, fmt.Errorf("failed to grab lock: %v", err)
+	}
+
+	defer func() {
+		err := release()
+		if err != nil {
+			c.logger.Error("failed to release lock", slog.Any("error", err))
+		}
+	}()
 
 	// put this within a transaction
 	// TODO - not truncating rows in staging table as of now.
 	// err = c.truncateTable(staging...)
 
-	_, err = c.client.Query(strings.Join(stmts, "\n")).Read(c.ctx)
+	_, err = c.client.Query(mergeQuery).Read(c.ctx)
 	if err != nil {
 		return nil, fmt.Errorf("failed to execute statements %s in a transaction: %v", strings.Join(stmts, "\n"), err)
 	}
diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_statement_generator.go
index 22161c434b..8a560ee627 100644
--- a/flow/connectors/bigquery/merge_statement_generator.go
+++ b/flow/connectors/bigquery/merge_statement_generator.go
@@ -30,8 +30,8 @@ type mergeStmtGenerator struct {
 	peerdbCols *protos.PeerDBColumns
 }
 
-// GenerateMergeStmt generates a merge statements.
-func (m *mergeStmtGenerator) generateMergeStmts() []string {
+// GenerateMergeStmt returns 3 strings to create temp, merge, drop temp
+func (m *mergeStmtGenerator) generateMergeStmts() (string, string, string) {
 	// return an empty array for now
 	flattenedCTE := m.generateFlattenedCTE()
 	deDupedCTE := m.generateDeDupedCTE()
@@ -45,7 +45,7 @@ func (m *mergeStmtGenerator) generateMergeStmts() []string {
 
 	dropTempTableStmt := fmt.Sprintf("DROP TABLE %s;", tempTable)
 
-	return []string{createTempTableStmt, mergeStmt, dropTempTableStmt}
+	return createTempTableStmt, mergeStmt, dropTempTableStmt
 }
 
 // generateFlattenedCTE generates a flattened CTE.