diff --git a/flow/cmd/snapshot_worker.go b/flow/cmd/snapshot_worker.go index aca936a628..772a489594 100644 --- a/flow/cmd/snapshot_worker.go +++ b/flow/cmd/snapshot_worker.go @@ -25,7 +25,7 @@ type SnapshotWorkerOptions struct { TemporalKey string } -func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) error { +func SnapshotWorkerMain(opts *SnapshotWorkerOptions) (worker.Worker, error) { clientOptions := client.Options{ HostPort: opts.TemporalHostPort, Namespace: opts.TemporalNamespace, @@ -35,7 +35,7 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err if opts.TemporalCert != "" && opts.TemporalKey != "" { certs, err := Base64DecodeCertAndKey(opts.TemporalCert, opts.TemporalKey) if err != nil { - return fmt.Errorf("unable to process certificate and key: %w", err) + return nil, fmt.Errorf("unable to process certificate and key: %w", err) } connOptions := client.ConnectionOptions{ @@ -49,13 +49,13 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err c, err := client.Dial(clientOptions) if err != nil { - return fmt.Errorf("unable to create Temporal client: %w", err) + return nil, fmt.Errorf("unable to create Temporal client: %w", err) } defer c.Close() taskQueue, queueErr := shared.GetPeerFlowTaskQueueName(shared.SnapshotFlowTaskQueueID) if queueErr != nil { - return queueErr + return nil, queueErr } w := worker.New(c, taskQueue, worker.Options{ @@ -64,12 +64,12 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err conn, err := utils.GetCatalogConnectionPoolFromEnv(context.Background()) if err != nil { - return fmt.Errorf("unable to create catalog connection pool: %w", err) + return nil, fmt.Errorf("unable to create catalog connection pool: %w", err) } alerter, err := alerting.NewAlerter(conn) if err != nil { - return fmt.Errorf("unable to create alerter: %w", err) + return nil, fmt.Errorf("unable to create alerter: %w", err) } w.RegisterWorkflow(peerflow.SnapshotFlowWorkflow) @@ -78,10 +78,5 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err Alerter: alerter, }) - err = w.Run(worker.InterruptCh()) - if err != nil { - return fmt.Errorf("worker run error: %w", err) - } - - return nil + return w, nil } diff --git a/flow/cmd/worker.go b/flow/cmd/worker.go index 17e818d491..fd1eb063b0 100644 --- a/flow/cmd/worker.go +++ b/flow/cmd/worker.go @@ -7,9 +7,7 @@ import ( "log" "log/slog" "os" - "os/signal" "runtime" - "syscall" "github.com/grafana/pyroscope-go" "go.temporal.io/sdk/client" @@ -74,22 +72,11 @@ func setupPyroscope(opts *WorkerOptions) { } } -func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error { +func WorkerMain(opts *WorkerOptions) (worker.Worker, error) { if opts.EnableProfiling { setupPyroscope(opts) } - go func() { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGQUIT) - buf := make([]byte, 1<<20) - for { - <-sigs - stacklen := runtime.Stack(buf, true) - log.Printf("=== received SIGQUIT ===\n*** goroutine dump...\n%s\n*** end\n", buf[:stacklen]) - } - }() - clientOptions := client.Options{ HostPort: opts.TemporalHostPort, Namespace: opts.TemporalNamespace, @@ -100,7 +87,7 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error { slog.Info("Using temporal certificate/key for authentication") certs, err := Base64DecodeCertAndKey(opts.TemporalCert, opts.TemporalKey) if err != nil { - return fmt.Errorf("unable to process certificate and key: %w", err) + return nil, fmt.Errorf("unable to process certificate and key: %w", err) } connOptions := client.ConnectionOptions{ TLS: &tls.Config{ @@ -113,19 +100,19 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error { conn, err := utils.GetCatalogConnectionPoolFromEnv(context.Background()) if err != nil { - return fmt.Errorf("unable to create catalog connection pool: %w", err) + return nil, fmt.Errorf("unable to create catalog connection pool: %w", err) } c, err := client.Dial(clientOptions) if err != nil { - return fmt.Errorf("unable to create Temporal client: %w", err) + return nil, fmt.Errorf("unable to create Temporal client: %w", err) } slog.Info("Created temporal client") defer c.Close() taskQueue, queueErr := shared.GetPeerFlowTaskQueueName(shared.PeerFlowTaskQueueID) if queueErr != nil { - return queueErr + return nil, queueErr } w := worker.New(c, taskQueue, worker.Options{ @@ -135,7 +122,7 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error { alerter, err := alerting.NewAlerter(conn) if err != nil { - return fmt.Errorf("unable to create alerter: %w", err) + return nil, fmt.Errorf("unable to create alerter: %w", err) } w.RegisterActivity(&activities.FlowableActivity{ @@ -144,10 +131,5 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error { CdcCache: make(map[string]connectors.CDCPullConnector), }) - err = w.Run(end) - if err != nil { - return fmt.Errorf("worker run error: %w", err) - } - - return nil + return w, nil } diff --git a/flow/e2e/main_test.go b/flow/e2e/main_test.go index d19745f295..e972c36079 100644 --- a/flow/e2e/main_test.go +++ b/flow/e2e/main_test.go @@ -13,38 +13,31 @@ import ( ) func TestMain(m *testing.M) { - end := make(chan interface{}) - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - group, _ := errgroup.WithContext(ctx) - group.Go(func() error { - return cmd.WorkerMain(end, &cmd.WorkerOptions{ - TemporalHostPort: "localhost:7233", - EnableProfiling: false, - PyroscopeServer: "", - TemporalNamespace: "default", - TemporalCert: "", - TemporalKey: "", - }) + peerWorker, peerErr := cmd.WorkerMain(end, &cmd.WorkerOptions{ + TemporalHostPort: "localhost:7233", + EnableProfiling: false, + PyroscopeServer: "", + TemporalNamespace: "default", + TemporalCert: "", + TemporalKey: "", }) - group.Go(func() error { - return cmd.SnapshotWorkerMain(end, &cmd.SnapshotWorkerOptions{ - TemporalHostPort: "localhost:7233", - TemporalNamespace: "default", - TemporalCert: "", - TemporalKey: "", - }) + if peerErr != nil { + panic(peerErr) + } else if err := peerWorker.Start(); err != nil { + panic(err) + } + + snapWorker, snapErr := cmd.SnapshotWorkerMain(end, &cmd.SnapshotWorkerOptions{ + TemporalHostPort: "localhost:7233", + TemporalNamespace: "default", + TemporalCert: "", + TemporalKey: "", }) - exitcode := m.Run() - end <- os.Interrupt - close(end) - go func() { - err := group.Wait() - if err != nil { - //nolint:forbidigo - fmt.Printf("%+v\n", err) - } - }() - time.Sleep(time.Second) - cancel() - os.Exit(exitcode) + if snapErr != nil { + panic(snapErr) + } else if err := snapWorker.Start(); err != nil { + panic(err) + } + + os.Exit(m.Run()) } diff --git a/flow/main.go b/flow/main.go index bbb741ea55..41a1d93a98 100644 --- a/flow/main.go +++ b/flow/main.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "os/signal" + "runtime" "syscall" "github.com/urfave/cli/v3" @@ -68,7 +69,7 @@ func main() { Name: "worker", Action: func(ctx context.Context, clicmd *cli.Command) error { temporalHostPort := clicmd.String("temporal-host-port") - return cmd.WorkerMain(worker.InterruptCh(), &cmd.WorkerOptions{ + w, err := cmd.WorkerMain(&cmd.WorkerOptions{ TemporalHostPort: temporalHostPort, EnableProfiling: clicmd.Bool("enable-profiling"), PyroscopeServer: clicmd.String("pyroscope-server-address"), @@ -76,6 +77,10 @@ func main() { TemporalCert: clicmd.String("temporal-cert"), TemporalKey: clicmd.String("temporal-key"), }) + if err != nil { + return err + } + return w.Run(worker.InterruptCh()) }, Flags: []cli.Flag{ temporalHostPortFlag, @@ -90,12 +95,16 @@ func main() { Name: "snapshot-worker", Action: func(ctx context.Context, clicmd *cli.Command) error { temporalHostPort := clicmd.String("temporal-host-port") - return cmd.SnapshotWorkerMain(worker.InterruptCh(), &cmd.SnapshotWorkerOptions{ + w, err := cmd.SnapshotWorkerMain(&cmd.SnapshotWorkerOptions{ TemporalHostPort: temporalHostPort, TemporalNamespace: clicmd.String("temporal-namespace"), TemporalCert: clicmd.String("temporal-cert"), TemporalKey: clicmd.String("temporal-key"), }) + if err != nil { + return err + } + return w.Run(worker.InterruptCh()) }, Flags: []cli.Flag{ temporalHostPortFlag, @@ -138,6 +147,17 @@ func main() { }, } + go func() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGQUIT) + buf := make([]byte, 1<<20) + for { + <-sigs + stacklen := runtime.Stack(buf, true) + log.Printf("=== received SIGQUIT ===\n*** goroutine dump...\n%s\n*** end\n", buf[:stacklen]) + } + }() + if err := app.Run(appCtx, os.Args); err != nil { log.Printf("error running app: %+v", err) }