Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor jobMetadataExists in postgres/snowflake connectors #826

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions flow/connectors/postgres/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,19 @@ func (c *PostgresConnector) getLastNormalizeBatchID(jobName string) (int64, erro
}

func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) {
rows, err := c.pool.Query(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName)
var result pgtype.Bool
err := c.pool.QueryRow(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("failed to check if job exists: %w", err)
return false, fmt.Errorf("error reading result row: %w", err)
}
defer rows.Close()
return result.Bool, nil
}

func (c *PostgresConnector) jobMetadataExistsTx(tx pgx.Tx, jobName string) (bool, error) {
var result pgtype.Bool
rows.Next()
err = rows.Scan(&result)
err := tx.QueryRow(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("error reading result row: %w", err)
}
Expand All @@ -440,7 +443,7 @@ func (c *PostgresConnector) majorVersionCheck(majorVersion int) (bool, error) {

func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64, syncBatchID int64,
syncRecordsTx pgx.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
}
Expand All @@ -466,7 +469,7 @@ func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64,

func (c *PostgresConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64,
normalizeRecordsTx pgx.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
jobMetadataExists, err := c.jobMetadataExistsTx(normalizeRecordsTx, flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
}
Expand Down
26 changes: 11 additions & 15 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,15 +731,8 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error {
}

func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) (bool, error) {
rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier)
if err != nil {
return false, err
}

// this query is guaranteed to return exactly one row
var result pgtype.Bool
rows.Next()
err = rows.Scan(&result)
err := c.database.QueryRowContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier).Scan(&result)
if err != nil {
return false, fmt.Errorf("error while reading result row: %w", err)
}
Expand Down Expand Up @@ -926,15 +919,18 @@ func parseTableName(tableName string) (*tableNameComponents, error) {
}

func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) {
rows, err := c.database.QueryContext(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName)
var result pgtype.Bool
err := c.database.QueryRowContext(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("failed to check if job exists: %w", err)
return false, fmt.Errorf("error reading result row: %w", err)
}

return result.Bool, nil
}
func (c *SnowflakeConnector) jobMetadataExistsTx(tx *sql.Tx, jobName string) (bool, error) {
var result pgtype.Bool
rows.Next()
err = rows.Scan(&result)
err := tx.QueryRowContext(c.ctx,
fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result)
if err != nil {
return false, fmt.Errorf("error reading result row: %w", err)
}
Expand All @@ -943,7 +939,7 @@ func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) {

func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64,
syncBatchID int64, syncRecordsTx *sql.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
}
Expand Down
Loading