diff --git a/flow/cmd/api.go b/flow/cmd/api.go index 5b010916db..15765ee189 100644 --- a/flow/cmd/api.go +++ b/flow/cmd/api.go @@ -13,6 +13,8 @@ import ( "github.com/google/uuid" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/client" "google.golang.org/grpc" @@ -37,6 +39,93 @@ type APIServerParams struct { GatewayPort uint16 } +type RecryptItem struct { + options []byte + id int32 +} + +func recryptDatabase(ctx context.Context, catalogPool *pgxpool.Pool) { + newKeyID := peerdbenv.PeerDBCurrentEncKeyID() + keys := peerdbenv.PeerDBEncKeys() + if newKeyID == "" { + if len(keys) == 0 { + slog.Warn("Encryption disabled. This is not recommended.") + } else { + slog.Warn("Encryption disabled, decrypting any currently encrypted configs. This is not recommended.") + } + } + + key, err := keys.Get(newKeyID) + if err != nil { + slog.Warn("recrypt failed to find key, skipping", slog.Any("error", err)) + return + } + + tx, err := catalogPool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + slog.Warn("recrypt failed to start transaction, skipping", slog.Any("error", err)) + return + } + defer shared.RollbackTx(tx, slog.Default()) + + rows, err := tx.Query(ctx, "SELECT id, options, enc_key_id FROM peers WHERE enc_key_id <> $1 FOR UPDATE", newKeyID) + if err != nil { + slog.Warn("recrypt failed to query, skipping", slog.Any("error", err)) + return + } + var todo []RecryptItem + var id int32 + var options []byte + var oldKeyID string + for rows.Next() { + if err := rows.Scan(&id, &options, &oldKeyID); err != nil { + slog.Warn("recrypt failed to scan, skipping", slog.Any("error", err)) + continue + } + + oldKey, err := keys.Get(oldKeyID) + if err != nil { + slog.Warn("recrypt failed to find key, skipping", slog.Any("error", err), slog.String("enc_key_id", oldKeyID)) + continue + } + + if oldKey != nil { + options, err = oldKey.Decrypt(options) + if err != nil { + slog.Warn("recrypt failed to decrypt, skipping", slog.Any("error", err), slog.Int64("id", int64(id))) + continue + } + } + + if key != nil { + options, err = key.Encrypt(options) + if err != nil { + slog.Warn("recrypt failed to encrypt, skipping", slog.Any("error", err)) + continue + } + } + + slog.Info("recrypting peer", slog.Int64("id", int64(id)), slog.String("oldKey", oldKeyID), slog.String("newKey", newKeyID)) + todo = append(todo, RecryptItem{id: id, options: options}) + } + if err := rows.Err(); err != nil { + slog.Warn("recrypt iteration failed, skipping", slog.Any("error", err)) + return + } + + for _, item := range todo { + if _, err := tx.Exec(ctx, "UPDATE peers SET options = $2, enc_key_id = $3 WHERE id = $1", item.id, item.options, newKeyID); err != nil { + slog.Warn("recrypt failed to update, ignoring", slog.Any("error", err), slog.Int64("id", int64(item.id))) + return + } + } + + if err := tx.Commit(ctx); err != nil { + slog.Warn("recrypt failed to commit transaction, skipping", slog.Any("error", err)) + } + slog.Info("recrypt finished") +} + // setupGRPCGatewayServer sets up the grpc-gateway mux func setupGRPCGatewayServer(args *APIServerParams) (*http.Server, error) { conn, err := grpc.NewClient( @@ -117,13 +206,13 @@ func APIMain(ctx context.Context, args *APIServerParams) error { grpcServer := grpc.NewServer() - catalogConn, err := peerdbenv.GetCatalogConnectionPoolFromEnv(ctx) + catalogPool, err := peerdbenv.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { return fmt.Errorf("unable to get catalog connection pool: %w", err) } taskQueue := peerdbenv.PeerFlowTaskQueueName(shared.PeerFlowTaskQueue) - flowHandler := NewFlowRequestHandler(tc, catalogConn, taskQueue) + flowHandler := NewFlowRequestHandler(tc, catalogPool, taskQueue) err = killExistingScheduleFlows(ctx, tc, args.TemporalNamespace, taskQueue) if err != nil { @@ -168,13 +257,15 @@ func APIMain(ctx context.Context, args *APIServerParams) error { slog.Info(fmt.Sprintf("Starting API gateway on port %d", args.GatewayPort)) go func() { - if err := gateway.ListenAndServe(); err != nil { + if err := gateway.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("failed to serve http: %v", err) } }() - <-ctx.Done() + // somewhat unrelated here, but needed a process which isn't replicated + go recryptDatabase(ctx, catalogPool) + <-ctx.Done() grpcServer.GracefulStop() slog.Info("Server has been shut down gracefully. Exiting...") diff --git a/flow/shared/postgres.go b/flow/shared/postgres.go index b01d057806..03915af55d 100644 --- a/flow/shared/postgres.go +++ b/flow/shared/postgres.go @@ -101,8 +101,7 @@ func GetMajorVersion(ctx context.Context, conn *pgx.Conn) (PGVersion, error) { func RollbackTx(tx pgx.Tx, logger log.Logger) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - err := tx.Rollback(ctx) - if err != nil && err != pgx.ErrTxClosed { + if err := tx.Rollback(ctx); err != nil && err != pgx.ErrTxClosed { logger.Error("error while rolling back transaction", slog.Any("error", err)) } }