Skip to content

Commit

Permalink
Don't use Run in main_test
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Feb 29, 2024
1 parent bffc0a3 commit 9964ffe
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 71 deletions.
19 changes: 7 additions & 12 deletions flow/cmd/snapshot_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type SnapshotWorkerOptions struct {
TemporalKey string
}

func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) error {
func SnapshotWorkerMain(opts *SnapshotWorkerOptions) (worker.Worker, error) {
clientOptions := client.Options{
HostPort: opts.TemporalHostPort,
Namespace: opts.TemporalNamespace,
Expand All @@ -35,7 +35,7 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err
if opts.TemporalCert != "" && opts.TemporalKey != "" {
certs, err := Base64DecodeCertAndKey(opts.TemporalCert, opts.TemporalKey)
if err != nil {
return fmt.Errorf("unable to process certificate and key: %w", err)
return nil, fmt.Errorf("unable to process certificate and key: %w", err)
}

connOptions := client.ConnectionOptions{
Expand All @@ -49,13 +49,13 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err

c, err := client.Dial(clientOptions)
if err != nil {
return fmt.Errorf("unable to create Temporal client: %w", err)
return nil, fmt.Errorf("unable to create Temporal client: %w", err)
}
defer c.Close()

taskQueue, queueErr := shared.GetPeerFlowTaskQueueName(shared.SnapshotFlowTaskQueueID)
if queueErr != nil {
return queueErr
return nil, queueErr
}

w := worker.New(c, taskQueue, worker.Options{
Expand All @@ -64,12 +64,12 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err

conn, err := utils.GetCatalogConnectionPoolFromEnv(context.Background())
if err != nil {
return fmt.Errorf("unable to create catalog connection pool: %w", err)
return nil, fmt.Errorf("unable to create catalog connection pool: %w", err)
}

alerter, err := alerting.NewAlerter(conn)
if err != nil {
return fmt.Errorf("unable to create alerter: %w", err)
return nil, fmt.Errorf("unable to create alerter: %w", err)
}

w.RegisterWorkflow(peerflow.SnapshotFlowWorkflow)
Expand All @@ -78,10 +78,5 @@ func SnapshotWorkerMain(end <-chan interface{}, opts *SnapshotWorkerOptions) err
Alerter: alerter,
})

err = w.Run(worker.InterruptCh())
if err != nil {
return fmt.Errorf("worker run error: %w", err)
}

return nil
return w, nil
}
32 changes: 7 additions & 25 deletions flow/cmd/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"log"
"log/slog"
"os"
"os/signal"
"runtime"
"syscall"

"github.com/grafana/pyroscope-go"
"go.temporal.io/sdk/client"
Expand Down Expand Up @@ -74,22 +72,11 @@ func setupPyroscope(opts *WorkerOptions) {
}
}

func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error {
func WorkerMain(opts *WorkerOptions) (worker.Worker, error) {
if opts.EnableProfiling {
setupPyroscope(opts)
}

go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGQUIT)
buf := make([]byte, 1<<20)
for {
<-sigs
stacklen := runtime.Stack(buf, true)
log.Printf("=== received SIGQUIT ===\n*** goroutine dump...\n%s\n*** end\n", buf[:stacklen])
}
}()

clientOptions := client.Options{
HostPort: opts.TemporalHostPort,
Namespace: opts.TemporalNamespace,
Expand All @@ -100,7 +87,7 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error {
slog.Info("Using temporal certificate/key for authentication")
certs, err := Base64DecodeCertAndKey(opts.TemporalCert, opts.TemporalKey)
if err != nil {
return fmt.Errorf("unable to process certificate and key: %w", err)
return nil, fmt.Errorf("unable to process certificate and key: %w", err)
}
connOptions := client.ConnectionOptions{
TLS: &tls.Config{
Expand All @@ -113,19 +100,19 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error {

conn, err := utils.GetCatalogConnectionPoolFromEnv(context.Background())
if err != nil {
return fmt.Errorf("unable to create catalog connection pool: %w", err)
return nil, fmt.Errorf("unable to create catalog connection pool: %w", err)
}

c, err := client.Dial(clientOptions)
if err != nil {
return fmt.Errorf("unable to create Temporal client: %w", err)
return nil, fmt.Errorf("unable to create Temporal client: %w", err)
}
slog.Info("Created temporal client")
defer c.Close()

taskQueue, queueErr := shared.GetPeerFlowTaskQueueName(shared.PeerFlowTaskQueueID)
if queueErr != nil {
return queueErr
return nil, queueErr
}

w := worker.New(c, taskQueue, worker.Options{
Expand All @@ -135,7 +122,7 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error {

alerter, err := alerting.NewAlerter(conn)
if err != nil {
return fmt.Errorf("unable to create alerter: %w", err)
return nil, fmt.Errorf("unable to create alerter: %w", err)
}

w.RegisterActivity(&activities.FlowableActivity{
Expand All @@ -144,10 +131,5 @@ func WorkerMain(end <-chan interface{}, opts *WorkerOptions) error {
CdcCache: make(map[string]connectors.CDCPullConnector),
})

err = w.Run(end)
if err != nil {
return fmt.Errorf("worker run error: %w", err)
}

return nil
return w, nil
}
57 changes: 25 additions & 32 deletions flow/e2e/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,31 @@ import (
)

func TestMain(m *testing.M) {
end := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
group, _ := errgroup.WithContext(ctx)
group.Go(func() error {
return cmd.WorkerMain(end, &cmd.WorkerOptions{
TemporalHostPort: "localhost:7233",
EnableProfiling: false,
PyroscopeServer: "",
TemporalNamespace: "default",
TemporalCert: "",
TemporalKey: "",
})
peerWorker, peerErr := cmd.WorkerMain(end, &cmd.WorkerOptions{
TemporalHostPort: "localhost:7233",
EnableProfiling: false,
PyroscopeServer: "",
TemporalNamespace: "default",
TemporalCert: "",
TemporalKey: "",
})
group.Go(func() error {
return cmd.SnapshotWorkerMain(end, &cmd.SnapshotWorkerOptions{
TemporalHostPort: "localhost:7233",
TemporalNamespace: "default",
TemporalCert: "",
TemporalKey: "",
})
if peerErr != nil {
panic(peerErr)
} else if err := peerWorker.Start(); err != nil {
panic(err)
}

snapWorker, snapErr := cmd.SnapshotWorkerMain(end, &cmd.SnapshotWorkerOptions{
TemporalHostPort: "localhost:7233",
TemporalNamespace: "default",
TemporalCert: "",
TemporalKey: "",
})
exitcode := m.Run()
end <- os.Interrupt
close(end)
go func() {
err := group.Wait()
if err != nil {
//nolint:forbidigo
fmt.Printf("%+v\n", err)
}
}()
time.Sleep(time.Second)
cancel()
os.Exit(exitcode)
if snapErr != nil {
panic(snapErr)
} else if err := snapWorker.Start(); err != nil {
panic(err)
}

os.Exit(m.Run())
}
24 changes: 22 additions & 2 deletions flow/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log/slog"
"os"
"os/signal"
"runtime"
"syscall"

"github.com/urfave/cli/v3"
Expand Down Expand Up @@ -68,14 +69,18 @@ func main() {
Name: "worker",
Action: func(ctx context.Context, clicmd *cli.Command) error {
temporalHostPort := clicmd.String("temporal-host-port")
return cmd.WorkerMain(worker.InterruptCh(), &cmd.WorkerOptions{
w, err := cmd.WorkerMain(&cmd.WorkerOptions{
TemporalHostPort: temporalHostPort,
EnableProfiling: clicmd.Bool("enable-profiling"),
PyroscopeServer: clicmd.String("pyroscope-server-address"),
TemporalNamespace: clicmd.String("temporal-namespace"),
TemporalCert: clicmd.String("temporal-cert"),
TemporalKey: clicmd.String("temporal-key"),
})
if err != nil {
return err
}
return w.Run(worker.InterruptCh())
},
Flags: []cli.Flag{
temporalHostPortFlag,
Expand All @@ -90,12 +95,16 @@ func main() {
Name: "snapshot-worker",
Action: func(ctx context.Context, clicmd *cli.Command) error {
temporalHostPort := clicmd.String("temporal-host-port")
return cmd.SnapshotWorkerMain(worker.InterruptCh(), &cmd.SnapshotWorkerOptions{
w, err := cmd.SnapshotWorkerMain(&cmd.SnapshotWorkerOptions{
TemporalHostPort: temporalHostPort,
TemporalNamespace: clicmd.String("temporal-namespace"),
TemporalCert: clicmd.String("temporal-cert"),
TemporalKey: clicmd.String("temporal-key"),
})
if err != nil {
return err
}
return w.Run(worker.InterruptCh())
},
Flags: []cli.Flag{
temporalHostPortFlag,
Expand Down Expand Up @@ -138,6 +147,17 @@ func main() {
},
}

go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGQUIT)
buf := make([]byte, 1<<20)
for {
<-sigs
stacklen := runtime.Stack(buf, true)
log.Printf("=== received SIGQUIT ===\n*** goroutine dump...\n%s\n*** end\n", buf[:stacklen])
}
}()

if err := app.Run(appCtx, os.Args); err != nil {
log.Printf("error running app: %+v", err)
}
Expand Down

0 comments on commit 9964ffe

Please sign in to comment.