Merge branch 'main' into normalize-split
serprex authored Jan 25, 2024
2 parents bcea38b + 94cb719 · commit 7daabd0
Showing 33 changed files with 340 additions and 281 deletions.
12 changes: 6 additions & 6 deletions flow/cmd/peer_data.go
@@ -8,11 +8,9 @@ import (

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"google.golang.org/protobuf/proto"

connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
)

@@ -33,17 +31,19 @@ func (h *FlowRequestHandler) getPGPeerConfig(ctx context.Context, peerName strin
return &pgPeerConfig, nil
}

func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*pgxpool.Pool, error) {
func (h *FlowRequestHandler) getPoolForPGPeer(ctx context.Context, peerName string) (*connpostgres.SSHWrappedPostgresPool, error) {
pgPeerConfig, err := h.getPGPeerConfig(ctx, peerName)
if err != nil {
return nil, err
}
connStr := utils.GetPGConnectionString(pgPeerConfig)
peerPool, err := pgxpool.New(ctx, connStr)

pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig)
if err != nil {
slog.Error("Failed to create postgres pool", slog.Any("error", err))
return nil, err
}
return peerPool, nil

return pool, nil
}

func (h *FlowRequestHandler) GetSchemas(
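Note on the change above: a direct pgxpool.New is replaced by PeerDB's SSH-aware constructor, presumably so peers reached through an SSH tunnel share one code path. A minimal sketch of the new call pattern; the constructor and types come from the diff, while the Close call is an assumption based on usual pool semantics:

    // Sketch only: acquire an SSH-aware Postgres pool for a peer config.
    pool, err := connpostgres.NewSSHWrappedPostgresPoolFromConfig(ctx, pgPeerConfig)
    if err != nil {
        return nil, fmt.Errorf("create postgres pool: %w", err)
    }
    defer pool.Close() // assumption: closing tears down the pool and any tunnel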
15 changes: 7 additions & 8 deletions flow/cmd/validate_mirror.go
@@ -12,21 +12,20 @@ import (
func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
pgPeer, err := connpostgres.NewPostgresConnector(ctx, req.ConnectionConfigs.Source.GetPostgresConfig())
sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, fmt.Errorf("source peer config is nil")
}

pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
if err != nil {
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, fmt.Errorf("failed to create postgres connector: %v", err)
}

defer pgPeer.Close()

sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, fmt.Errorf("source peer config is nil")
}

// Check permissions of postgres peer
err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User)
if err != nil {
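Note on the reordering above: GetPostgresConfig() is a generated oneof getter that returns nil when the source peer is not a Postgres peer, so the nil check has to run before the config reaches NewPostgresConnector. A sketch of the resulting guard pattern, names as in the diff:

    // Sketch: validate the peer config before constructing the connector.
    sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
    if sourcePeerConfig == nil {
        return nil, fmt.Errorf("source peer config is nil")
    }
    pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)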
65 changes: 38 additions & 27 deletions flow/connectors/bigquery/bigquery.go
@@ -260,11 +260,11 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string,
for _, addedColumn := range schemaDelta.AddedColumns {
dstDatasetTable, _ := c.convertToDatasetTable(schemaDelta.DstTableName)
query := c.client.Query(fmt.Sprintf(
"ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS `%s` %s", dstDatasetTable.dataset,
"ALTER TABLE %s ADD COLUMN IF NOT EXISTS `%s` %s",
dstDatasetTable.table, addedColumn.ColumnName,
qValueKindToBigQueryType(addedColumn.ColumnType)))
query.DefaultProjectID = c.projectID
query.DefaultDatasetID = c.datasetID
query.DefaultDatasetID = dstDatasetTable.dataset
_, err := query.Read(c.ctx)
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
@@ -312,7 +312,7 @@ func (c *BigQueryConnector) SetupMetadataTables() error {
}

func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) {
query := fmt.Sprintf("SELECT offset FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
query := fmt.Sprintf("SELECT offset FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
q.DefaultDatasetID = c.datasetID
@@ -339,8 +339,7 @@ func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) {

func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) error {
query := fmt.Sprintf(
"UPDATE %s.%s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'",
c.datasetID,
"UPDATE %s SET offset = GREATEST(offset, %d) WHERE mirror_job_name = '%s'",
MirrorJobsTable,
lastOffset,
jobName,
@@ -357,8 +356,8 @@ func (c *BigQueryConnector) SetLastOffset(jobName string, lastOffset int64) erro
}

func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
query := fmt.Sprintf("SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name = '%s'",
c.datasetID, MirrorJobsTable, jobName)
query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'",
MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DisableQueryCache = true
q.DefaultProjectID = c.projectID
@@ -385,8 +384,8 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
}

func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) {
query := fmt.Sprintf("SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'",
c.datasetID, MirrorJobsTable, jobName)
query := fmt.Sprintf("SELECT normalize_batch_id FROM %s WHERE mirror_job_name = '%s'",
MirrorJobsTable, jobName)
q := c.client.Query(query)
q.DisableQueryCache = true
q.DefaultProjectID = c.projectID
@@ -416,9 +415,9 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn
rawTableName := c.getRawTableName(flowJobName)

// Prepare the query to retrieve distinct tables in that batch
query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s.%s
query := fmt.Sprintf(`SELECT DISTINCT _peerdb_destination_table_name FROM %s
WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d`,
c.datasetID, rawTableName, normalizeBatchID, syncBatchID)
rawTableName, normalizeBatchID, syncBatchID)
// Run the query
q := c.client.Query(query)
q.DefaultProjectID = c.projectID
@@ -459,10 +458,10 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync
// where a placeholder value for unchanged cols can be set in DeleteRecord if there is no backfill
// we don't want these particular DeleteRecords to be used in the update statement
query := fmt.Sprintf(`SELECT _peerdb_destination_table_name,
array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s.%s
array_agg(DISTINCT _peerdb_unchanged_toast_columns) as unchanged_toast_columns FROM %s
WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d AND _peerdb_record_type != 2
GROUP BY _peerdb_destination_table_name`,
c.datasetID, rawTableName, normalizeBatchID, syncBatchID)
rawTableName, normalizeBatchID, syncBatchID)
// Run the query
q := c.client.Query(query)
q.DefaultDatasetID = c.datasetID
@@ -587,6 +586,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
dstDatasetTable, _ := c.convertToDatasetTable(tableName)
mergeGen := &mergeStmtGenerator{
rawDatasetTable: &datasetTable{
project: c.projectID,
dataset: c.datasetID,
table: rawTableName,
},
@@ -609,7 +609,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
i+1, len(mergeStmts), tableName))
q := c.client.Query(mergeStmt)
q.DefaultProjectID = c.projectID
q.DefaultDatasetID = c.datasetID
q.DefaultDatasetID = dstDatasetTable.dataset
_, err = q.Read(c.ctx)
if err != nil {
return nil, fmt.Errorf("failed to execute merge statement %s: %v", mergeStmt, err)
@@ -618,8 +618,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
}
// update metadata to make the last normalized batch id to the recent last sync batch id.
updateMetadataStmt := fmt.Sprintf(
"UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
c.datasetID, MirrorJobsTable, req.SyncBatchID, req.FlowJobName)
"UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
MirrorJobsTable, req.SyncBatchID, req.FlowJobName)

query := c.client.Query(updateMetadataStmt)
query.DefaultProjectID = c.projectID
@@ -719,12 +719,12 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec

// create the job in the metadata table
jobStatement := fmt.Sprintf(
"INSERT INTO %s.%s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);",
c.datasetID, MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID)
"INSERT INTO %s (mirror_job_name,offset,sync_batch_id) VALUES ('%s',%d,%d);",
MirrorJobsTable, jobName, lastSyncedCheckpointID, batchID)
if hasJob {
jobStatement = fmt.Sprintf(
"UPDATE %s.%s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';",
c.datasetID, MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName)
"UPDATE %s SET offset=GREATEST(offset,%d),sync_batch_id=%d WHERE mirror_job_name = '%s';",
MirrorJobsTable, lastSyncedCheckpointID, batchID, jobName)
}

return jobStatement, nil
@@ -733,8 +733,8 @@ func (c *BigQueryConnector) getUpdateMetadataStmt(jobName string, lastSyncedChec
// metadataHasJob checks if the metadata table has the given job.
func (c *BigQueryConnector) metadataHasJob(jobName string) (bool, error) {
checkStmt := fmt.Sprintf(
"SELECT COUNT(*) FROM %s.%s WHERE mirror_job_name = '%s'",
c.datasetID, MirrorJobsTable, jobName)
"SELECT COUNT(*) FROM %s WHERE mirror_job_name = '%s'",
MirrorJobsTable, jobName)

q := c.client.Query(checkStmt)
q.DefaultProjectID = c.projectID
@@ -804,13 +804,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(

// convert the column names and types to bigquery types
columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.ColumnNames)+2)
utils.IterColumns(tableSchema, func(colName, genericColType string) {
for i, colName := range tableSchema.ColumnNames {
genericColType := tableSchema.ColumnTypes[i]
columns = append(columns, &bigquery.FieldSchema{
Name: colName,
Type: qValueKindToBigQueryType(genericColType),
Repeated: qvalue.QValueKind(genericColType).IsArray(),
})
})
}

if req.SoftDeleteColName != "" {
columns = append(columns, &bigquery.FieldSchema{
@@ -872,7 +873,7 @@ func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error {
}

// deleting job from metadata table
query := fmt.Sprintf("DELETE FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName)
query := fmt.Sprintf("DELETE FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName)
queryHandler := c.client.Query(query)
queryHandler.DefaultProjectID = c.projectID
queryHandler.DefaultDatasetID = c.datasetID
@@ -902,7 +903,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos
dstDatasetTable.string()))

if req.SoftDeleteColName != nil {
allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")

c.logger.InfoContext(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string()))
@@ -1011,12 +1012,16 @@ func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFro
}

type datasetTable struct {
project string
dataset string
table string
}

func (d *datasetTable) string() string {
return fmt.Sprintf("%s.%s", d.dataset, d.table)
if d.project == "" {
return fmt.Sprintf("%s.%s", d.dataset, d.table)
}
return fmt.Sprintf("%s.%s.%s", d.project, d.dataset, d.table)
}

func (c *BigQueryConnector) convertToDatasetTable(tableName string) (*datasetTable, error) {
@@ -1031,6 +1036,12 @@ func (c *BigQueryConnector) convertToDatasetTab
dataset: parts[0],
table: parts[1],
}, nil
} else if len(parts) == 3 {
return &datasetTable{
project: parts[0],
dataset: parts[1],
table: parts[2],
}, nil
} else {
return nil, fmt.Errorf("invalid BigQuery table name: %s", tableName)
}
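Note on the change above: with the new project field, a table identifier may be dataset.table or project.dataset.table, and string() prefixes the project only when one was parsed. A quick round-trip illustration, with made-up names:

    // Sketch: convertToDatasetTable / string() round-trip (names hypothetical).
    dt, _ := c.convertToDatasetTable("my_dataset.my_table")
    fmt.Println(dt.string()) // my_dataset.my_table

    dt, _ = c.convertToDatasetTable("my-project.my_dataset.my_table")
    fmt.Println(dt.string()) // my-project.my_dataset.my_table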
24 changes: 10 additions & 14 deletions flow/connectors/bigquery/merge_stmt_generator.go
@@ -34,7 +34,7 @@ type mergeStmtGenerator struct {
func (m *mergeStmtGenerator) generateFlattenedCTE() string {
// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
// statement.
flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3)
flattenedProjs := make([]string, 0, len(m.normalizedTableSchema.ColumnNames)+3)

for i, colName := range m.normalizedTableSchema.ColumnNames {
colType := m.normalizedTableSchema.ColumnTypes[i]
@@ -93,9 +93,9 @@
)

// normalize anything between last normalized batch id to last sync batchid
return fmt.Sprintf(`WITH _f AS
(SELECT %s FROM %s WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND
_peerdb_destination_table_name='%s')`,
return fmt.Sprintf("WITH _f AS "+
"(SELECT %s FROM `%s` WHERE _peerdb_batch_id>%d AND _peerdb_batch_id<=%d AND "+
"_peerdb_destination_table_name='%s')",
strings.Join(flattenedProjs, ","), m.rawDatasetTable.string(), m.normalizeBatchID,
m.syncBatchID, m.dstTableName)
}
@@ -124,7 +124,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
// generateMergeStmt generates a merge statement.
func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) string {
// comma separated list of column names
columnCount := utils.TableSchemaColumns(m.normalizedTableSchema)
columnCount := len(m.normalizedTableSchema.ColumnNames)
backtickColNames := make([]string, 0, columnCount)
shortBacktickColNames := make([]string, 0, columnCount)
pureColNames := make([]string, 0, columnCount)
@@ -169,15 +169,11 @@ func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) s
}
}

return fmt.Sprintf(`
MERGE %s _t USING(%s,%s) _d
ON %s
WHEN NOT MATCHED AND _d._rt!=2 THEN
INSERT (%s) VALUES(%s)
%s
WHEN MATCHED AND _d._rt=2 THEN
%s;
`, m.dstDatasetTable.string(), m.generateFlattenedCTE(), m.generateDeDupedCTE(),
return fmt.Sprintf("MERGE `%s` _t USING(%s,%s) _d"+
" ON %s WHEN NOT MATCHED AND _d._rt!=2 THEN "+
"INSERT (%s) VALUES(%s) "+
"%s WHEN MATCHED AND _d._rt=2 THEN %s;",
m.dstDatasetTable.table, m.generateFlattenedCTE(), m.generateDeDupedCTE(),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)
}

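Note on the rewrite above: besides flattening the statement onto one line, the generated SQL now backtick-quotes the table reference. That matters once names can be project-qualified, because BigQuery project IDs commonly contain hyphens, which GoogleSQL accepts only inside quoted identifiers. A small illustration with hypothetical values:

    // Sketch: hyphenated project IDs must be quoted in generated SQL.
    table := "my-project.my_dataset.events_raw" // hypothetical name
    stmt := fmt.Sprintf("SELECT COUNT(*) FROM `%s`", table)
    // => SELECT COUNT(*) FROM `my-project.my_dataset.events_raw`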
8 changes: 4 additions & 4 deletions flow/connectors/bigquery/qrep.go
@@ -112,10 +112,10 @@ func (c *BigQueryConnector) createMetadataInsertStatement(
partitionJSON := string(pbytes)

insertMetadataStmt := fmt.Sprintf(
"INSERT INTO %s._peerdb_query_replication_metadata"+
"INSERT INTO _peerdb_query_replication_metadata"+
"(flowJobName, partitionID, syncPartition, syncStartTime, syncFinishTime) "+
"VALUES ('%s', '%s', JSON '%s', TIMESTAMP('%s'), CURRENT_TIMESTAMP());",
c.datasetID, jobName, partition.PartitionId,
jobName, partition.PartitionId,
partitionJSON, startTime.Format(time.RFC3339))

return insertMetadataStmt, nil
@@ -170,8 +170,8 @@ func (c *BigQueryConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e

func (c *BigQueryConnector) isPartitionSynced(partitionID string) (bool, error) {
queryString := fmt.Sprintf(
"SELECT COUNT(*) FROM %s._peerdb_query_replication_metadata WHERE partitionID = '%s';",
c.datasetID, partitionID,
"SELECT COUNT(*) FROM _peerdb_query_replication_metadata WHERE partitionID = '%s';",
partitionID,
)

query := c.client.Query(queryString)
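Note on the changes above: the dataset prefix is dropped from the SQL because name resolution moves to the client; with DefaultProjectID and DefaultDatasetID set on the query, BigQuery expands the bare table name itself. A self-contained sketch of that mechanism using the cloud.google.com/go/bigquery client (function and parameter names are illustrative, not from the repo):

    package example

    import (
        "context"

        "cloud.google.com/go/bigquery"
    )

    // isPartitionSyncedSketch counts metadata rows for a partition. The SQL
    // names the table without a dataset; the client fills in the defaults, so
    // the query reads <projectID>.<datasetID>._peerdb_query_replication_metadata.
    func isPartitionSyncedSketch(ctx context.Context, client *bigquery.Client,
        projectID, datasetID, partitionID string,
    ) (bool, error) {
        q := client.Query(
            "SELECT COUNT(*) FROM _peerdb_query_replication_metadata WHERE partitionID = @p")
        q.DefaultProjectID = projectID
        q.DefaultDatasetID = datasetID
        q.Parameters = []bigquery.QueryParameter{{Name: "p", Value: partitionID}}

        it, err := q.Read(ctx)
        if err != nil {
            return false, err
        }
        var row []bigquery.Value
        if err := it.Next(&row); err != nil {
            return false, err
        }
        return row[0].(int64) > 0, nil
    }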
10 changes: 6 additions & 4 deletions flow/connectors/bigquery/qrep_avro_sync.go
@@ -58,6 +58,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
stagingTable := fmt.Sprintf("%s_%s_staging", rawTableName, strconv.FormatInt(syncBatchID, 10))
numRecords, err := s.writeToStage(strconv.FormatInt(syncBatchID, 10), rawTableName, avroSchema,
&datasetTable{
project: s.connector.projectID,
dataset: s.connector.datasetID,
table: stagingTable,
}, stream, req.FlowJobName)
@@ -67,8 +68,8 @@

bqClient := s.connector.client
datasetID := s.connector.datasetID
insertStmt := fmt.Sprintf("INSERT INTO `%s.%s` SELECT * FROM `%s.%s`;",
datasetID, rawTableName, datasetID, stagingTable)
insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT * FROM `%s`;",
rawTableName, stagingTable)

lastCP, err := req.Records.GetLastCheckpoint()
if err != nil {
@@ -171,6 +172,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
// create a staging table name with partitionID replace hyphens with underscores
dstDatasetTable, _ := s.connector.convertToDatasetTable(dstTableName)
stagingDatasetTable := &datasetTable{
project: s.connector.projectID,
dataset: dstDatasetTable.dataset,
table: fmt.Sprintf("%s_%s_staging", dstDatasetTable.table,
strings.ReplaceAll(partition.PartitionId, "-", "_")),
@@ -198,7 +200,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
}
// Insert the records from the staging table into the destination table
insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT %s FROM `%s`;",
dstDatasetTable.string(), selector, stagingDatasetTable.string())
dstTableName, selector, stagingDatasetTable.string())

insertMetadataStmt, err := s.connector.createMetadataInsertStatement(partition, flowJobName, startTime)
if err != nil {
@@ -229,7 +231,7 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords(
flowLog)
}

slog.Info(fmt.Sprintf("loaded stage into %s", dstDatasetTable.string()), flowLog)
slog.Info(fmt.Sprintf("loaded stage into %s", dstTableName), flowLog)
return numRecords, nil
}

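Note on the changes above: staging tables are now project-qualified as well, and their names embed the partition ID with hyphens swapped for underscores, presumably to keep the generated identifier safe to use unquoted. A tiny worked example with a made-up partition ID:

    // Sketch: staging table naming (partition ID hypothetical).
    partitionID := "3f2a0c1e-9b7d-4e21"
    staging := fmt.Sprintf("%s_%s_staging", "events",
        strings.ReplaceAll(partitionID, "-", "_"))
    // => events_3f2a0c1e_9b7d_4e21_staging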
(Diffs for the remaining 28 of the 33 changed files are not shown here.)
