diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 871fd7403b..fd4a162e15 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -412,16 +412,19 @@ func (c *PostgresConnector) getLastNormalizeBatchID(jobName string) (int64, erro } func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) { - rows, err := c.pool.Query(c.ctx, - fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) + var result pgtype.Bool + err := c.pool.QueryRow(c.ctx, + fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { - return false, fmt.Errorf("failed to check if job exists: %w", err) + return false, fmt.Errorf("error reading result row: %w", err) } - defer rows.Close() + return result.Bool, nil +} +func (c *PostgresConnector) jobMetadataExistsTx(tx pgx.Tx, jobName string) (bool, error) { var result pgtype.Bool - rows.Next() - err = rows.Scan(&result) + err := tx.QueryRow(c.ctx, + fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { return false, fmt.Errorf("error reading result row: %w", err) } @@ -440,7 +443,7 @@ func (c *PostgresConnector) majorVersionCheck(majorVersion int) (bool, error) { func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64, syncBatchID int64, syncRecordsTx pgx.Tx) error { - jobMetadataExists, err := c.jobMetadataExists(flowJobName) + jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName) if err != nil { return fmt.Errorf("failed to get sync status for flow job: %w", err) } @@ -466,7 +469,7 @@ func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64, func (c *PostgresConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64, normalizeRecordsTx pgx.Tx) error { - jobMetadataExists, err := c.jobMetadataExists(flowJobName) + jobMetadataExists, err := c.jobMetadataExistsTx(normalizeRecordsTx, flowJobName) if err != nil { return fmt.Errorf("failed to get sync status for flow job: %w", err) } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index b1dcbb3c62..9920cdad36 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -731,15 +731,8 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { } func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) (bool, error) { - rows, err := c.database.QueryContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier) - if err != nil { - return false, err - } - - // this query is guaranteed to return exactly one row var result pgtype.Bool - rows.Next() - err = rows.Scan(&result) + err := c.database.QueryRowContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier).Scan(&result) if err != nil { return false, fmt.Errorf("error while reading result row: %w", err) } @@ -926,15 +919,18 @@ func parseTableName(tableName string) (*tableNameComponents, error) { } func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { - rows, err := c.database.QueryContext(c.ctx, - fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) + var result pgtype.Bool + err := c.database.QueryRowContext(c.ctx, + fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { - return false, fmt.Errorf("failed to check if job exists: %w", err) + return false, fmt.Errorf("error reading result row: %w", err) } - + return result.Bool, nil +} +func (c *SnowflakeConnector) jobMetadataExistsTx(tx *sql.Tx, jobName string) (bool, error) { var result pgtype.Bool - rows.Next() - err = rows.Scan(&result) + err := tx.QueryRowContext(c.ctx, + fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { return false, fmt.Errorf("error reading result row: %w", err) } @@ -943,7 +939,7 @@ func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64, syncBatchID int64, syncRecordsTx *sql.Tx) error { - jobMetadataExists, err := c.jobMetadataExists(flowJobName) + jobMetadataExists, err := c.jobMetadataExistsTx(syncRecordsTx, flowJobName) if err != nil { return fmt.Errorf("failed to get sync status for flow job: %w", err) }