Commit
Merge branch 'main' into lua-q
serprex authored May 10, 2024
2 parents 82c9920 + b80d434 commit 4cd5efc
Showing 21 changed files with 112 additions and 165 deletions.
22 changes: 6 additions & 16 deletions flow/cmd/handler.go
@@ -142,27 +142,19 @@ func (h *FlowRequestHandler) CreateCDCFlow(

if req.ConnectionConfigs.SoftDeleteColName == "" {
req.ConnectionConfigs.SoftDeleteColName = "_PEERDB_IS_DELETED"
} else {
// make them all uppercase
req.ConnectionConfigs.SoftDeleteColName = strings.ToUpper(req.ConnectionConfigs.SoftDeleteColName)
}

if req.ConnectionConfigs.SyncedAtColName == "" {
req.ConnectionConfigs.SyncedAtColName = "_PEERDB_SYNCED_AT"
} else {
// make them all uppercase
req.ConnectionConfigs.SyncedAtColName = strings.ToUpper(req.ConnectionConfigs.SyncedAtColName)
}

if req.CreateCatalogEntry {
err := h.createCdcJobEntry(ctx, req, workflowID)
if err != nil {
slog.Error("unable to create flow job entry", slog.Any("error", err))
return nil, fmt.Errorf("unable to create flow job entry: %w", err)
}
err := h.createCdcJobEntry(ctx, req, workflowID)
if err != nil {
slog.Error("unable to create flow job entry", slog.Any("error", err))
return nil, fmt.Errorf("unable to create flow job entry: %w", err)
}

err := h.updateFlowConfigInCatalog(ctx, cfg)
err = h.updateFlowConfigInCatalog(ctx, cfg)
if err != nil {
slog.Error("unable to update flow config in catalog", slog.Any("error", err))
return nil, fmt.Errorf("unable to update flow config in catalog: %w", err)
@@ -258,10 +250,8 @@ func (h *FlowRequestHandler) CreateQRepFlow(

if req.QrepConfig.SyncedAtColName == "" {
cfg.SyncedAtColName = "_PEERDB_SYNCED_AT"
} else {
// make them all uppercase
cfg.SyncedAtColName = strings.ToUpper(req.QrepConfig.SyncedAtColName)
}

_, err := h.temporalClient.ExecuteWorkflow(ctx, workflowOptions, workflowFn, cfg, state)
if err != nil {
slog.Error("unable to start QRepFlow workflow",
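The net effect of these hunks: user-supplied soft-delete and synced-at column names are now preserved verbatim instead of being uppercased, and the catalog job entry is created unconditionally rather than only when req.CreateCatalogEntry is set. A minimal sketch of the new default-or-passthrough behavior (defaultColName is a hypothetical helper, not part of the diff):

// Hypothetical helper: only empty values get a default;
// user-provided names are no longer uppercased.
func defaultColName(userValue, fallback string) string {
	if userValue == "" {
		return fallback // e.g. "_PEERDB_IS_DELETED" or "_PEERDB_SYNCED_AT"
	}
	return userValue // previously: strings.ToUpper(userValue)
}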
28 changes: 15 additions & 13 deletions flow/cmd/validate_mirror.go
@@ -17,7 +17,7 @@ import (
func (h *FlowRequestHandler) ValidateCDCMirror(
ctx context.Context, req *protos.CreateCDCFlowRequest,
) (*protos.ValidateCDCMirrorResponse, error) {
if req.CreateCatalogEntry && !req.ConnectionConfigs.Resync {
if !req.ConnectionConfigs.Resync {
mirrorExists, existCheckErr := h.CheckIfMirrorNameExists(ctx, req.ConnectionConfigs.FlowJobName)
if existCheckErr != nil {
slog.Error("/validatecdc failed to check if mirror name exists", slog.Any("error", existCheckErr))
@@ -46,7 +46,9 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
if sourcePeerConfig == nil {
slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
return nil, errors.New("source peer config is nil")
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, errors.New("source peer config is nil")
}

pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
@@ -103,17 +105,17 @@ func (h *FlowRequestHandler) ValidateCDCMirror(
}

pubName := req.ConnectionConfigs.PublicationName
if pubName != "" {
err = pgPeer.CheckSourceTables(ctx, sourceTables, pubName)
if err != nil {
displayErr := fmt.Errorf("provided source tables invalidated: %v", err)
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

err = pgPeer.CheckSourceTables(ctx, sourceTables, pubName)
if err != nil {
displayErr := fmt.Errorf("provided source tables invalidated: %v", err)
slog.Error(displayErr.Error())
h.alerter.LogNonFlowWarning(ctx, telemetry.CreateMirror, req.ConnectionConfigs.FlowJobName,
fmt.Sprint(displayErr),
)
return &protos.ValidateCDCMirrorResponse{
Ok: false,
}, displayErr
}

return &protos.ValidateCDCMirrorResponse{
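Three behavioral changes here: the duplicate-mirror-name check now runs for any non-resync request (not only when a catalog entry is created), validation failures return an explicit response with Ok: false alongside the error instead of a bare nil, and CheckSourceTables is always invoked, since the empty-publication-name guard moved into the connector (see flow/connectors/postgres/validate.go below). The failure-return pattern now used throughout this handler, as in the diff:

// On any validation failure, return an explicit non-OK response body
// together with the error instead of a bare nil.
return &protos.ValidateCDCMirrorResponse{
	Ok: false,
}, displayErr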
18 changes: 14 additions & 4 deletions flow/connectors/clickhouse/qrep_avro_sync.go
@@ -47,10 +47,15 @@ func (s *ClickhouseAvroSyncMethod) CopyStageToDestination(ctx context.Context, a
if err != nil {
return err
}

sessionTokenPart := ""
if creds.AWS.SessionToken != "" {
sessionTokenPart = fmt.Sprintf(", '%s'", creds.AWS.SessionToken)
}
//nolint:gosec
query := fmt.Sprintf("INSERT INTO %s SELECT * FROM s3('%s','%s','%s', '%s', 'Avro')",
query := fmt.Sprintf("INSERT INTO %s SELECT * FROM s3('%s','%s','%s'%s, 'Avro')",
s.config.DestinationTableIdentifier, avroFileUrl,
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, creds.AWS.SessionToken)
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, sessionTokenPart)

_, err = s.connector.database.ExecContext(ctx, query)

@@ -137,10 +142,15 @@ func (s *ClickhouseAvroSyncMethod) SyncQRepRecords(
selector = append(selector, "`"+colName+"`")
}
selectorStr := strings.Join(selector, ",")

sessionTokenPart := ""
if creds.AWS.SessionToken != "" {
sessionTokenPart = fmt.Sprintf(", '%s'", creds.AWS.SessionToken)
}
//nolint:gosec
query := fmt.Sprintf("INSERT INTO %s(%s) SELECT %s FROM s3('%s','%s','%s', '%s', 'Avro')",
query := fmt.Sprintf("INSERT INTO %s(%s) SELECT %s FROM s3('%s','%s','%s'%s, 'Avro')",
config.DestinationTableIdentifier, selectorStr, selectorStr, avroFileUrl,
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, creds.AWS.SessionToken)
creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, sessionTokenPart)

_, err = s.connector.database.ExecContext(ctx, query)
if err != nil {
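Both hunks apply the same fix: the AWS session token becomes an optional trailing credential argument to ClickHouse's s3() table function, so static credentials without a session token no longer pass an empty string. A condensed sketch of the query assembly, assuming the variable names from the diff:

// s3(url, access_key_id, secret_access_key[, session_token], format)
sessionTokenPart := ""
if creds.AWS.SessionToken != "" {
	sessionTokenPart = fmt.Sprintf(", '%s'", creds.AWS.SessionToken)
}
//nolint:gosec
query := fmt.Sprintf("INSERT INTO %s SELECT * FROM s3('%s','%s','%s'%s, 'Avro')",
	s.config.DestinationTableIdentifier, avroFileUrl,
	creds.AWS.AccessKeyID, creds.AWS.SecretAccessKey, sessionTokenPart)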
14 changes: 4 additions & 10 deletions flow/connectors/connelasticsearch/elasticsearch.go
@@ -121,25 +121,23 @@ func (esc *ElasticsearchConnector) SyncRecords(ctx context.Context,
req *model.SyncRecordsRequest[model.RecordItems],
) (*model.SyncResponse, error) {
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
// atomics for counts will be unnecessary in other destinations, using a mutex instead
var recordCountsUpdateMutex sync.Mutex
// we're taking a mutex anyway, avoid atomic
var lastSeenLSN atomic.Int64
var numRecords atomic.Int64
var numRecords int64

// no I don't like this either
esBulkIndexerCache := make(map[string]esutil.BulkIndexer)
bulkIndexersHaveShutdown := false
// true if we saw errors while closing
cacheCloser := func() bool {
closeHasErrors := false
if bulkIndexersHaveShutdown {
if !bulkIndexersHaveShutdown {
for _, esBulkIndexer := range maps.Values(esBulkIndexerCache) {
err := esBulkIndexer.Close(context.Background())
if err != nil {
esc.logger.Error("[es] failed to close bulk indexer", slog.Any("error", err))
closeHasErrors = true
}
numRecords += int64(esBulkIndexer.Stats().NumFlushed)
}
bulkIndexersHaveShutdown = true
}
@@ -237,9 +235,6 @@ func (esc *ElasticsearchConnector) SyncRecords(ctx context.Context,

OnSuccess: func(_ context.Context, _ esutil.BulkIndexerItem, _ esutil.BulkIndexerResponseItem) {
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
numRecords.Add(1)
recordCountsUpdateMutex.Lock()
defer recordCountsUpdateMutex.Unlock()
record.PopulateCountMap(tableNameRowsMapping)
},
// OnFailure is called for each failed operation, log and let parent handle
@@ -284,7 +279,6 @@ func (esc *ElasticsearchConnector) SyncRecords(ctx context.Context,
esc.logger.Error("[es] failed to close bulk indexer(s)")
return nil, errors.New("[es] failed to close bulk indexer(s)")
}
bulkIndexersHaveShutdown = true
if len(bulkIndexErrors) > 0 {
for _, err := range bulkIndexErrors {
esc.logger.Error("[es] failed to index record", slog.Any("err", err))
@@ -299,7 +293,7 @@ func (esc *ElasticsearchConnector) SyncRecords(ctx context.Context,
return &model.SyncResponse{
CurrentSyncBatchID: req.SyncBatchID,
LastSyncedCheckpointID: lastCheckpoint,
NumRecordsSynced: numRecords.Load(),
NumRecordsSynced: numRecords,
TableNameRowsMapping: tableNameRowsMapping,
TableSchemaDeltas: req.Records.SchemaDeltas,
}, nil
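The per-callback bookkeeping in OnSuccess is slimmed down: the synced-record count now comes from each bulk indexer's flushed-document statistics at close time rather than an atomic increment per record, and the shutdown guard's inverted condition (closing only if bulkIndexersHaveShutdown, where the negation was intended) is fixed, with the flag now set inside the closer. A sketch of the counting pattern, using the go-elasticsearch esutil API as in the diff:

// Close every cached indexer once; its Stats() tell us how many
// documents were actually flushed, which becomes the synced count.
for _, esBulkIndexer := range maps.Values(esBulkIndexerCache) {
	if err := esBulkIndexer.Close(context.Background()); err != nil {
		closeHasErrors = true
	}
	numRecords += int64(esBulkIndexer.Stats().NumFlushed)
}
bulkIndexersHaveShutdown = true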
2 changes: 1 addition & 1 deletion flow/connectors/postgres/postgres.go
@@ -62,11 +62,11 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig)
// create a separate connection pool for non-replication queries as replication connections cannot
// be used for extended query protocol, i.e. prepared statements
connConfig, err := pgx.ParseConfig(connectionString)
replConfig := connConfig.Copy()
if err != nil {
return nil, fmt.Errorf("failed to parse connection string: %w", err)
}

replConfig := connConfig.Copy()
runtimeParams := connConfig.Config.RuntimeParams
runtimeParams["idle_in_transaction_session_timeout"] = "0"
runtimeParams["statement_timeout"] = "0"
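A small but real ordering fix: Copy() was invoked on the parsed config before the parse error was checked, so a malformed connection string would leave connConfig nil and the Copy() call would panic instead of surfacing the error. The corrected sequence, as in the diff:

connConfig, err := pgx.ParseConfig(connectionString)
if err != nil {
	return nil, fmt.Errorf("failed to parse connection string: %w", err)
}
// Only safe after the error check: on failure connConfig is nil.
replConfig := connConfig.Copy()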
33 changes: 18 additions & 15 deletions flow/connectors/postgres/validate.go
@@ -34,28 +34,31 @@ func (c *PostgresConnector) CheckSourceTables(ctx context.Context,
}

tableStr := strings.Join(tableArr, ",")
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)

if pubName != "" {
// Check if publication exists
err := c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname=$1", pubName).Scan(nil)
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("publication does not exist: %s", pubName)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}
return fmt.Errorf("error while checking for publication existence: %w", err)
}

// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
// Check if tables belong to publication
var pubTableCount int
err = c.conn.QueryRow(ctx, fmt.Sprintf(`
with source_table_components (sname, tname) as (values %s)
select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables
INNER JOIN source_table_components stc
ON schemaname=stc.sname and tablename=stc.tname where pubname=$1;`, tableStr), pubName).Scan(&pubTableCount)
if err != nil {
return err
}
if err != nil {
return err
}

if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
if pubTableCount != len(tableNames) {
return errors.New("not all tables belong to publication")
}
}

return nil
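The publication existence and membership checks now run only when a publication name was actually provided, which is what lets validate_mirror.go above drop its own pubName guard. A self-contained sketch of the guard (checkPublication is a hypothetical distillation, not code from the diff):

// Sketch: skip all publication validation when no publication
// name was provided (it is optional for CDC mirrors).
func checkPublication(pubName string, pubTableCount, wantedTables int) error {
	if pubName == "" {
		return nil // nothing to validate
	}
	if pubTableCount != wantedTables {
		return errors.New("not all tables belong to publication")
	}
	return nil
}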
1 change: 1 addition & 0 deletions flow/connectors/snowflake/snowflake.go
@@ -44,6 +44,7 @@ const (
SELECT _PEERDB_UID,_PEERDB_TIMESTAMP,TO_VARIANT(PARSE_JSON(_PEERDB_DATA)) %s,_PEERDB_RECORD_TYPE,
_PEERDB_MATCH_DATA,_PEERDB_BATCH_ID,_PEERDB_UNCHANGED_TOAST_COLUMNS
FROM _PEERDB_INTERNAL.%s WHERE _PEERDB_BATCH_ID = %d AND
_PEERDB_DATA != '' AND
_PEERDB_DESTINATION_TABLE_NAME = ? ), FLATTENED AS
(SELECT _PEERDB_UID,_PEERDB_TIMESTAMP,_PEERDB_RECORD_TYPE,_PEERDB_MATCH_DATA,_PEERDB_BATCH_ID,
_PEERDB_UNCHANGED_TOAST_COLUMNS,%s
6 changes: 3 additions & 3 deletions flow/e2e/bigquery/peer_flow_bq_test.go
@@ -1110,8 +1110,8 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_IUD_Same_Batch() {
Source: e2e.GeneratePostgresPeer(),
CdcStagingPath: connectionGen.CdcStagingPath,
SoftDelete: true,
SoftDeleteColName: "_PEERDB_IS_DELETED",
SyncedAtColName: "_PEERDB_SYNCED_AT",
SoftDeleteColName: "_custom_deleted",
SyncedAtColName: "_custom_synced",
MaxBatchSize: 100,
}

@@ -1141,7 +1141,7 @@ func (s PeerFlowE2ETestSuiteBQ) Test_Soft_Delete_IUD_Same_Batch() {
e2e.EnvWaitForEqualTables(env, s, "normalizing tx", "test_softdel_iud", "id,c1,c2,t")
e2e.EnvWaitFor(s.t, env, 3*time.Minute, "checking soft delete", func() bool {
newerSyncedAtQuery := fmt.Sprintf(
"SELECT COUNT(*) FROM `%s.%s` WHERE _PEERDB_IS_DELETED",
"SELECT COUNT(*) FROM `%s.%s` WHERE _custom_deleted",
s.bqHelper.Config.DatasetId, dstTableName)
numNewRows, err := s.bqHelper.RunInt64Query(newerSyncedAtQuery)
e2e.EnvNoError(s.t, env, err)
27 changes: 24 additions & 3 deletions nexus/analyzer/src/lib.rs
@@ -11,7 +11,7 @@ use pt::{
peerdb_peers::{
peer::Config, BigqueryConfig, ClickhouseConfig, DbType, EventHubConfig, GcpServiceAccount,
KafkaConfig, MongoConfig, Peer, PostgresConfig, PubSubConfig, S3Config, SnowflakeConfig,
SqlServerConfig,
SqlServerConfig, SshConfig,
},
};
use qrep::process_options;
@@ -300,6 +300,11 @@ impl StatementAnalyzer for PeerDDLAnalyzer {
_ => None,
};

let sync_interval: Option<u64> = match raw_options.remove("sync_interval") {
Some(Expr::Value(ast::Value::Number(n, _))) => Some(n.parse::<u64>()?),
_ => None,
};

let soft_delete_col_name: Option<String> = match raw_options
.remove("soft_delete_col_name")
{
@@ -347,6 +352,7 @@ impl StatementAnalyzer for PeerDDLAnalyzer {
push_batch_size,
push_parallelism,
max_batch_size,
sync_interval,
resync,
soft_delete_col_name,
synced_at_col_name,
@@ -646,6 +652,19 @@ fn parse_db_options(db_type: DbType, with_options: &[SqlOption]) -> anyhow::Resu
Config::MongoConfig(mongo_config)
}
DbType::Postgres => {
let ssh_fields: Option<SshConfig> = match opts.get("ssh_config") {
Some(ssh_config) => {
let ssh_config_str = ssh_config.to_string();
if ssh_config_str.is_empty() {
None
} else {
serde_json::from_str(&ssh_config_str)
.context("failed to deserialize ssh_config")?
}
}
None => None,
};

let postgres_config = PostgresConfig {
host: opts.get("host").context("no host specified")?.to_string(),
port: opts
@@ -667,8 +686,9 @@ fn parse_db_options(db_type: DbType, with_options: &[SqlOption]) -> anyhow::Resu
.to_string(),
metadata_schema: opts.get("metadata_schema").map(|s| s.to_string()),
transaction_snapshot: "".to_string(),
ssh_config: None,
ssh_config: ssh_fields,
};

Config::PostgresConfig(postgres_config)
}
DbType::S3 => {
@@ -744,7 +764,8 @@ fn parse_db_options(db_type: DbType, with_options: &[SqlOption]) -> anyhow::Resu
.unwrap_or_default(),
disable_tls: opts
.get("disable_tls")
.map(|s| s.parse::<bool>().unwrap_or_default()).unwrap_or_default(),
.map(|s| s.parse::<bool>().unwrap_or_default())
.unwrap_or_default(),
endpoint: opts.get("endpoint").map(|s| s.to_string()),
};
Config::ClickhouseConfig(clickhouse_config)
5 changes: 4 additions & 1 deletion nexus/analyzer/src/qrep.rs
@@ -196,7 +196,10 @@ pub fn process_options(

// If mode is upsert, we need unique key columns
if opts.get("mode") == Some(&Value::String(String::from("upsert")))
&& opts.get("unique_key_columns").map(|ukc| ukc == &Value::Array(Vec::new())).unwrap_or(true)
&& opts
.get("unique_key_columns")
.map(|ukc| ukc == &Value::Array(Vec::new()))
.unwrap_or(true)
{
anyhow::bail!("For upsert mode, unique_key_columns must be specified");
}
