Skip to content

Commit

Permalink
Refactor jobMetadataExists in postgres/snowflake connectors
Browse files Browse the repository at this point in the history
Use QueryRow instead of Query,
& introduce a Tx variant which can reuse existing connection
  • Loading branch information
serprex committed Dec 14, 2023
1 parent 46208f1 commit c6431b3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
19 changes: 11 additions & 8 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,19 @@ func (c *PostgresConnector) getLastNormalizeBatchID(jobName string) (int64, erro
}

func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) {
rows, err := c.pool.Query(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName)
var result pgtype.Bool
err := c.pool.QueryRow(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("failed to check if job exists: %w", err)
return false, fmt.Errorf("error reading result row: %w", err)
}
defer rows.Close()
return result.Bool, nil
}

func (c *PostgresConnector) jobMetadataExistsTx(tx pgx.Tx, jobName string) (bool, error) {
var result pgtype.Bool
rows.Next()
err = rows.Scan(&result)
err := tx.QueryRow(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("error reading result row: %w", err)
}
Expand All @@ -440,7 +443,7 @@ func (c *PostgresConnector) majorVersionCheck(majorVersion int) (bool, error) {

func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64, syncBatchID int64,
syncRecordsTx pgx.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
}
Expand All @@ -466,7 +469,7 @@ func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64,

func (c *PostgresConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64,
normalizeRecordsTx pgx.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
jobMetadataExists, err := c.jobMetadataExistsTx(normalizeRecordsTx, flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
}
Expand Down
26 changes: 11 additions & 15 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,15 +734,8 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error {
}

func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) (bool, error) {
rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier)
if err != nil {
return false, err
}

// this query is guaranteed to return exactly one row
var result pgtype.Bool
rows.Next()
err = rows.Scan(&result)
err := c.database.QueryRowContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier).Scan(&result)
if err != nil {
return false, fmt.Errorf("error while reading result row: %w", err)
}
Expand Down Expand Up @@ -929,15 +922,18 @@ func parseTableName(tableName string) (*tableNameComponents, error) {
}

func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) {
rows, err := c.database.QueryContext(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName)
var result pgtype.Bool
err := c.database.QueryRowContext(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("failed to check if job exists: %w", err)
return false, fmt.Errorf("error reading result row: %w", err)
}

return result.Bool, nil
}
func (c *SnowflakeConnector) jobMetadataExistsTx(tx *sql.Tx, jobName string) (bool, error) {
var result pgtype.Bool
rows.Next()
err = rows.Scan(&result)
err := tx.QueryRowContext(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("error reading result row: %w", err)
}
Expand All @@ -946,7 +942,7 @@ func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) {

func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64,
syncBatchID int64, syncRecordsTx *sql.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
}
Expand Down

0 comments on commit c6431b3

Please sign in to comment.