Skip to content

Commit

Permalink
making all Snowflake DB access use a context (#876)
Browse files Browse the repository at this point in the history
Closes #766

Co-authored-by: Kaushik Iska <[email protected]>
  • Loading branch information
heavycrystal and iskakaushik authored Dec 22, 2023
1 parent 45a4205 commit 60044de
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
4 changes: 2 additions & 2 deletions flow/connectors/snowflake/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType
LIMIT 0
`, tableName)

rows, err := c.database.Query(queryString)
rows, err := c.database.QueryContext(c.ctx, queryString)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
Expand Down Expand Up @@ -294,7 +294,7 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnIn
WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s'
`, components.tableIdentifier, components.schemaIdentifier)

rows, err := c.database.Query(queryString)
rows, err := c.database.QueryContext(c.ctx, queryString)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
Expand Down
16 changes: 7 additions & 9 deletions flow/connectors/snowflake/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ func (s *SnowflakeAvroSyncMethod) SyncQRepRecords(
s.connector.logger.Info("sync function called and schema acquired", partitionLog)

err = s.addMissingColumns(
config.FlowJobName,
schema,
dstTableSchema,
dstTableName,
Expand Down Expand Up @@ -152,7 +151,6 @@ func (s *SnowflakeAvroSyncMethod) SyncQRepRecords(
}

func (s *SnowflakeAvroSyncMethod) addMissingColumns(
flowJobName string,
schema *model.QRecordSchema,
dstTableSchema []*sql.ColumnType,
dstTableName string,
Expand Down Expand Up @@ -197,7 +195,7 @@ func (s *SnowflakeAvroSyncMethod) addMissingColumns(
s.connector.logger.Info(fmt.Sprintf("altering destination table %s with command `%s`",
dstTableName, alterTableCmd), partitionLog)

if _, err := tx.Exec(alterTableCmd); err != nil {
if _, err := tx.ExecContext(s.connector.ctx, alterTableCmd); err != nil {
return fmt.Errorf("failed to alter destination table: %w", err)
}
}
Expand Down Expand Up @@ -290,7 +288,7 @@ func (s *SnowflakeAvroSyncMethod) putFileToStage(avroFile *avro.AvroFile, stage
shutdown <- struct{}{}
}()

if _, err := s.connector.database.Exec(putCmd); err != nil {
if _, err := s.connector.database.ExecContext(s.connector.ctx, putCmd); err != nil {
return fmt.Errorf("failed to put file to stage: %w", err)
}

Expand Down Expand Up @@ -395,7 +393,7 @@ func (s *SnowflakeAvroSyncMethod) insertMetadata(
return fmt.Errorf("failed to create metadata insert statement: %v", err)
}

if _, err := s.connector.database.Exec(insertMetadataStmt); err != nil {
if _, err := s.connector.database.ExecContext(s.connector.ctx, insertMetadataStmt); err != nil {
s.connector.logger.Error("failed to execute metadata insert statement "+insertMetadataStmt,
slog.Any("error", err), partitionLog)
return fmt.Errorf("failed to execute metadata insert statement: %v", err)
Expand Down Expand Up @@ -434,7 +432,7 @@ func (s *SnowflakeAvroWriteHandler) HandleAppendMode(
copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
s.dstTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
s.connector.logger.Info("running copy command: " + copyCmd)
_, err := s.connector.database.Exec(copyCmd)
_, err := s.connector.database.ExecContext(s.connector.ctx, copyCmd)
if err != nil {
return fmt.Errorf("failed to run COPY INTO command: %w", err)
}
Expand Down Expand Up @@ -518,15 +516,15 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
//nolint:gosec
createTempTableCmd := fmt.Sprintf("CREATE TEMPORARY TABLE %s AS SELECT * FROM %s LIMIT 0",
tempTableName, s.dstTableName)
if _, err := s.connector.database.Exec(createTempTableCmd); err != nil {
if _, err := s.connector.database.ExecContext(s.connector.ctx, createTempTableCmd); err != nil {
return fmt.Errorf("failed to create temp table: %w", err)
}
s.connector.logger.Info("created temp table " + tempTableName)

//nolint:gosec
copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s",
tempTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ","))
_, err = s.connector.database.Exec(copyCmd)
_, err = s.connector.database.ExecContext(s.connector.ctx, copyCmd)
if err != nil {
return fmt.Errorf("failed to run COPY INTO command: %w", err)
}
Expand All @@ -538,7 +536,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode(
}

startTime := time.Now()
rows, err := s.connector.database.Exec(mergeCmd)
rows, err := s.connector.database.ExecContext(s.connector.ctx, mergeCmd)
if err != nil {
return fmt.Errorf("failed to merge data into destination table '%s': %w", mergeCmd, err)
}
Expand Down
5 changes: 3 additions & 2 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,9 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string,
return fmt.Errorf("failed to convert column type %s to snowflake type: %w",
addedColumn.ColumnType, err)
}
_, err = tableSchemaModifyTx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s",
schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype))
_, err = tableSchemaModifyTx.ExecContext(c.ctx,
fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s",
schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype))
if err != nil {
return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName,
schemaDelta.DstTableName, err)
Expand Down

0 comments on commit 60044de

Please sign in to comment.