From 60044de384a0f7244410233b54f659d356290f2f Mon Sep 17 00:00:00 2001 From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com> Date: Fri, 22 Dec 2023 18:28:34 +0530 Subject: [PATCH] making all Snowflake DB access use a context (#876) Closes #766 Co-authored-by: Kaushik Iska --- flow/connectors/snowflake/qrep.go | 4 ++-- flow/connectors/snowflake/qrep_avro_sync.go | 16 +++++++--------- flow/connectors/snowflake/snowflake.go | 5 +++-- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index 98d20b63ff..def870c183 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -84,7 +84,7 @@ func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType LIMIT 0 `, tableName) - rows, err := c.database.Query(queryString) + rows, err := c.database.QueryContext(c.ctx, queryString) if err != nil { return nil, fmt.Errorf("failed to execute query: %w", err) } @@ -294,7 +294,7 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) (*model.ColumnIn WHERE UPPER(table_name) = '%s' AND UPPER(table_schema) = '%s' `, components.tableIdentifier, components.schemaIdentifier) - rows, err := c.database.Query(queryString) + rows, err := c.database.QueryContext(c.ctx, queryString) if err != nil { return nil, fmt.Errorf("failed to execute query: %w", err) } diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 7184898ae3..30834c2554 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -112,7 +112,6 @@ func (s *SnowflakeAvroSyncMethod) SyncQRepRecords( s.connector.logger.Info("sync function called and schema acquired", partitionLog) err = s.addMissingColumns( - config.FlowJobName, schema, dstTableSchema, dstTableName, @@ -152,7 +151,6 @@ func (s *SnowflakeAvroSyncMethod) SyncQRepRecords( } func (s *SnowflakeAvroSyncMethod) addMissingColumns( - flowJobName string, schema *model.QRecordSchema, dstTableSchema []*sql.ColumnType, dstTableName string, @@ -197,7 +195,7 @@ func (s *SnowflakeAvroSyncMethod) addMissingColumns( s.connector.logger.Info(fmt.Sprintf("altering destination table %s with command `%s`", dstTableName, alterTableCmd), partitionLog) - if _, err := tx.Exec(alterTableCmd); err != nil { + if _, err := tx.ExecContext(s.connector.ctx, alterTableCmd); err != nil { return fmt.Errorf("failed to alter destination table: %w", err) } } @@ -290,7 +288,7 @@ func (s *SnowflakeAvroSyncMethod) putFileToStage(avroFile *avro.AvroFile, stage shutdown <- struct{}{} }() - if _, err := s.connector.database.Exec(putCmd); err != nil { + if _, err := s.connector.database.ExecContext(s.connector.ctx, putCmd); err != nil { return fmt.Errorf("failed to put file to stage: %w", err) } @@ -395,7 +393,7 @@ func (s *SnowflakeAvroSyncMethod) insertMetadata( return fmt.Errorf("failed to create metadata insert statement: %v", err) } - if _, err := s.connector.database.Exec(insertMetadataStmt); err != nil { + if _, err := s.connector.database.ExecContext(s.connector.ctx, insertMetadataStmt); err != nil { s.connector.logger.Error("failed to execute metadata insert statement "+insertMetadataStmt, slog.Any("error", err), partitionLog) return fmt.Errorf("failed to execute metadata insert statement: %v", err) @@ -434,7 +432,7 @@ func (s *SnowflakeAvroWriteHandler) HandleAppendMode( copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s", s.dstTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ",")) s.connector.logger.Info("running copy command: " + copyCmd) - _, err := s.connector.database.Exec(copyCmd) + _, err := s.connector.database.ExecContext(s.connector.ctx, copyCmd) if err != nil { return fmt.Errorf("failed to run COPY INTO command: %w", err) } @@ -518,7 +516,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( //nolint:gosec createTempTableCmd := fmt.Sprintf("CREATE TEMPORARY TABLE %s AS SELECT * FROM %s LIMIT 0", tempTableName, s.dstTableName) - if _, err := s.connector.database.Exec(createTempTableCmd); err != nil { + if _, err := s.connector.database.ExecContext(s.connector.ctx, createTempTableCmd); err != nil { return fmt.Errorf("failed to create temp table: %w", err) } s.connector.logger.Info("created temp table " + tempTableName) @@ -526,7 +524,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( //nolint:gosec copyCmd := fmt.Sprintf("COPY INTO %s(%s) FROM (SELECT %s FROM @%s) %s", tempTableName, copyInfo.columnsSQL, copyInfo.transformationSQL, s.stage, strings.Join(s.copyOpts, ",")) - _, err = s.connector.database.Exec(copyCmd) + _, err = s.connector.database.ExecContext(s.connector.ctx, copyCmd) if err != nil { return fmt.Errorf("failed to run COPY INTO command: %w", err) } @@ -538,7 +536,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( } startTime := time.Now() - rows, err := s.connector.database.Exec(mergeCmd) + rows, err := s.connector.database.ExecContext(s.connector.ctx, mergeCmd) if err != nil { return fmt.Errorf("failed to merge data into destination table '%s': %w", mergeCmd, err) } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 8cd8240f11..e61830db88 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -480,8 +480,9 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, return fmt.Errorf("failed to convert column type %s to snowflake type: %w", addedColumn.ColumnType, err) } - _, err = tableSchemaModifyTx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", - schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype)) + _, err = tableSchemaModifyTx.ExecContext(c.ctx, + fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", + schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype)) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, schemaDelta.DstTableName, err)