From 16dd84cd81ace0acccf4ebf58a84112aeb48f70c Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Thu, 28 Nov 2024 22:49:38 +0100 Subject: [PATCH] refactor(stopFlag): completly remove stopFlag for `context` `context` is Go package used mainly to signal cancelation. `stopFlag` package inside of gemini was used the same way. It was just adding overhead and maintainability issues in the long run, as it was only a wrapper for `context.Context` and all at once calling the `cancel` funcs at the end of the program. This removal makes code easier to follow for everybody working on it. `context` package is well known in Go community and can serve many purposes beyond just canceling the background workers (`jobs` in gemini terminology). Hard kill in gemini was used to signal the immadiet stoppage for gemini, without pritting the results, and soft on the contrary could be triggers in a few ways: - gemini validation failed - gemini mutation failed - SIG(INT,TERM) was sent There is no reason to have HARD kill in the application, if we need hard kill, SIGKILL can be sent and everything would be stopped, there will be no need to the the cleanup as kernel will ensure that happens. `context.Context` works as a soft kill, and if something happens bad in gemini (validation fails or mutation fail) `globalStatus.HasError()` would stop all other `goroutines` from continuing if `failFast` CLI flag is passed, so the set of `softKill` is not needed. Signed-off-by: Dusan Malusev --- cmd/gemini/root.go | 66 ++--- pkg/generators/generator.go | 3 +- pkg/generators/generator_test.go | 8 +- pkg/generators/generators.go | 2 +- pkg/generators/partition.go | 10 +- pkg/jobs/jobs.go | 49 ++-- pkg/jobs/pump.go | 15 +- pkg/stop/flag.go | 221 ---------------- pkg/stop/flag_test.go | 429 ------------------------------- 9 files changed, 72 insertions(+), 731 deletions(-) delete mode 100644 pkg/stop/flag.go delete mode 100644 pkg/stop/flag_test.go diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 278a102d..6d2a22d0 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -22,24 +22,13 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" "strconv" "strings" + "syscall" "text/tabwriter" "time" - "github.com/scylladb/gemini/pkg/auth" - "github.com/scylladb/gemini/pkg/builders" - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/jobs" - "github.com/scylladb/gemini/pkg/realrandom" - "github.com/scylladb/gemini/pkg/replication" - "github.com/scylladb/gemini/pkg/store" - "github.com/scylladb/gemini/pkg/typedef" - "github.com/scylladb/gemini/pkg/utils" - - "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" - "github.com/gocql/gocql" "github.com/hailocab/go-hostpool" "github.com/pkg/errors" @@ -50,6 +39,17 @@ import ( "golang.org/x/exp/rand" "golang.org/x/net/context" "gonum.org/v1/gonum/stat/distuv" + + "github.com/scylladb/gemini/pkg/auth" + "github.com/scylladb/gemini/pkg/builders" + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/jobs" + "github.com/scylladb/gemini/pkg/realrandom" + "github.com/scylladb/gemini/pkg/replication" + "github.com/scylladb/gemini/pkg/status" + "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/typedef" + "github.com/scylladb/gemini/pkg/utils" ) var ( @@ -137,8 +137,11 @@ func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Sc } func run(_ *cobra.Command, _ []string) error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGABRT, syscall.SIGTERM, syscall.SIGINT) + defer cancel() + logger := createLogger(level) - globalStatus := status.NewGlobalStatus(1000) + globalStatus := status.NewGlobalStatus(int32(maxErrorsToStore)) defer utils.IgnoreError(logger.Sync) if err := validateSeed(seed); err != nil { @@ -242,7 +245,7 @@ func run(_ *cobra.Command, _ []string) error { if dropSchema && mode != jobs.ReadMode { for _, stmt := range generators.GetDropKeyspace(schema) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { + if err = st.Mutate(ctx, typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to drop schema") } } @@ -250,7 +253,7 @@ func run(_ *cobra.Command, _ []string) error { testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema) if err = st.Create( - context.Background(), + ctx, typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType), typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to create keyspace") @@ -263,11 +266,7 @@ func run(_ *cobra.Command, _ []string) error { } } - ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2) - stopFlag := stop.NewFlag("main") - warmupStopFlag := stop.NewFlag("warmup") - stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) - pump := jobs.NewPump(stopFlag, logger) + pump := jobs.NewPump(ctx, logger) distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, normalDistMean, normalDistSigma) if err != nil { @@ -281,10 +280,9 @@ func run(_ *cobra.Command, _ []string) error { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) go func() { - defer done() for { select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return case <-ticker.C: sp.Set(" Running Gemini... %v", globalStatus) @@ -293,20 +291,24 @@ func run(_ *cobra.Command, _ []string) error { }() } - if warmup > 0 && !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, warmupStopFlag, failFast, verbose); err != nil { + if warmup > 0 { + warmupCtx, warmupCancel := context.WithTimeout(ctx, warmup) + defer warmupCancel() + + jobsList := jobs.ListFromMode(jobs.WarmupMode, concurrency) + if err = jobsList.Run(warmupCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { logger.Error("warmup encountered an error", zap.Error(err)) - stopFlag.SetHard(true) } } - if !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(mode, duration, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil { - logger.Debug("error detected", zap.Error(err)) - } + jobsCtx, jobsCancel := context.WithTimeout(ctx, duration) + defer jobsCancel() + + jobsList := jobs.ListFromMode(mode, concurrency) + if err = jobsList.Run(jobsCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { + logger.Debug("error detected", zap.Error(err)) } + logger.Info("test finished") globalStatus.PrintResult(outFile, schema, version) if globalStatus.HasErrors() { diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index 87a63876..c0ee9214 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -16,6 +16,7 @@ package generators import ( "context" + "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/exp/rand" @@ -26,7 +27,7 @@ import ( // TokenIndex represents the position of a token in the token ring. // A token index is translated to a token by a generators. If the generators -// preserves the exact position, then the token index becomes the token; +// preserve the exact position, then the token index becomes the token; // otherwise token index represents an approximation of the token. // // We use a token index approach, because our generators actually generate diff --git a/pkg/generators/generator_test.go b/pkg/generators/generator_test.go index 3a46551c..6568d756 100644 --- a/pkg/generators/generator_test.go +++ b/pkg/generators/generator_test.go @@ -15,13 +15,13 @@ package generators_test import ( + "context" "sync/atomic" "testing" "go.uber.org/zap" "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" ) @@ -32,7 +32,7 @@ func TestGenerator(t *testing.T) { PartitionKeys: generators.CreatePkColumns(1, "pk"), } var current uint64 - cfg := &generators.Config{ + cfg := generators.Config{ PartitionsRangeConfig: typedef.PartitionRangeConfig{ MaxStringLength: 10, MinStringLength: 0, @@ -47,8 +47,8 @@ func TestGenerator(t *testing.T) { } logger, _ := zap.NewDevelopment() generator := generators.NewGenerator(table, cfg, logger) - generator.Start(stop.NewFlag("main_test")) - for i := uint64(0); i < cfg.PartitionsCount; i++ { + generator.Start(context.Background()) + for i := range cfg.PartitionsCount { atomic.StoreUint64(¤t, i) v := generator.Get() n := generator.Get() diff --git a/pkg/generators/generators.go b/pkg/generators/generators.go index 23c88de0..95a06d84 100644 --- a/pkg/generators/generators.go +++ b/pkg/generators/generators.go @@ -25,9 +25,9 @@ import ( ) type Generators struct { - Generators []Generator wg *sync.WaitGroup cancel context.CancelFunc + Generators []Generator } func New( diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index 839b50ca..eb418a72 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -15,9 +15,10 @@ package generators import ( - "go.uber.org/multierr" "sync/atomic" + "go.uber.org/multierr" + "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/typedef" ) @@ -113,12 +114,11 @@ func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { } func (s *Partition) Close() error { - for !s.closed.CompareAndSwap(false, true) { + if s.closed.CompareAndSwap(false, true) { + close(s.values) + close(s.oldValues) } - close(s.values) - close(s.oldValues) - return nil } diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 5274c28c..55ddf2cb 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -28,7 +28,6 @@ import ( "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/joberror" "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/store" "github.com/scylladb/gemini/pkg/typedef" ) @@ -53,10 +52,9 @@ var ( ) type List struct { - name string - jobs []job - duration time.Duration - workers uint64 + name string + jobs []job + workers uint64 } type job struct { @@ -72,16 +70,16 @@ type job struct { *generators.Generator, *status.GlobalStatus, *zap.Logger, - *stop.Flag, bool, bool, ) error name string } -func ListFromMode(mode string, duration time.Duration, workers uint64) List { +func ListFromMode(mode string, workers uint64) List { jobs := make([]job, 0, 2) name := "work cycle" + switch mode { case WriteMode: jobs = append(jobs, mutate) @@ -93,11 +91,11 @@ func ListFromMode(mode string, duration time.Duration, workers uint64) List { default: jobs = append(jobs, mutate, validate) } + return List{ - name: name, - jobs: jobs, - duration: duration, - workers: workers, + name: name, + jobs: jobs, + workers: workers, } } @@ -111,16 +109,10 @@ func (l List) Run( globalStatus *status.GlobalStatus, logger *zap.Logger, seed uint64, - stopFlag *stop.Flag, failFast, verbose bool, ) error { logger = logger.Named(l.name) - ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop) g, gCtx := errgroup.WithContext(ctx) - time.AfterFunc(l.duration, func() { - logger.Info("jobs time is up, begins jobs completion") - stopFlag.SetSoft(true) - }) partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() logger.Info("start jobs") @@ -131,7 +123,7 @@ func (l List) Run( jobF := l.jobs[idx].function r := rand.New(rand.NewSource(seed)) g.Go(func() error { - return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, stopFlag, failFast, verbose) + return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, failFast, verbose) }) } } @@ -154,7 +146,6 @@ func mutationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, verbose bool, ) error { schemaConfig := &schemaCfg @@ -164,11 +155,8 @@ func mutationJob( logger.Info("ending mutation loop") }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): logger.Debug("mutation job terminated") return nil case hb := <-pump: @@ -187,7 +175,6 @@ func mutationJob( } } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -207,7 +194,6 @@ func validationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -218,11 +204,8 @@ func validationJob( }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return nil case hb := <-pump: time.Sleep(hb) @@ -262,7 +245,6 @@ func validationJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -282,7 +264,6 @@ func warmupJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -292,10 +273,13 @@ func warmupJob( logger.Info("ending warmup loop") }() for { - if stopFlag.IsHardOrSoft() { + select { + case <-ctx.Done(): logger.Debug("warmup job terminated") return nil + default: } + // Do we care about errors during warmup? err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger) if err != nil { @@ -303,7 +287,6 @@ func warmupJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } diff --git a/pkg/jobs/pump.go b/pkg/jobs/pump.go index c929f8ce..4baf6a98 100644 --- a/pkg/jobs/pump.go +++ b/pkg/jobs/pump.go @@ -15,15 +15,14 @@ package jobs import ( + "context" "time" - "github.com/scylladb/gemini/pkg/stop" - "go.uber.org/zap" "golang.org/x/exp/rand" ) -func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { +func NewPump(ctx context.Context, logger *zap.Logger) <-chan time.Duration { pump := make(chan time.Duration, 10000) logger = logger.Named("Pump") go func() { @@ -32,8 +31,14 @@ func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { close(pump) logger.Debug("pump channel closed") }() - for !stopFlag.IsHardOrSoft() { - pump <- newHeartBeat() + + for { + select { + case <-ctx.Done(): + return + default: + pump <- newHeartBeat() + } } }() diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go deleted file mode 100644 index 54c6e1f2..00000000 --- a/pkg/stop/flag.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop - -import ( - "context" - "fmt" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - - "go.uber.org/zap" -) - -const ( - SignalNoop uint32 = iota - SignalSoftStop - SignalHardStop -) - -type SignalChannel chan uint32 - -var closedChan = createClosedChan() - -func createClosedChan() SignalChannel { - ch := make(SignalChannel) - close(ch) - return ch -} - -type SyncList[T any] struct { - children []T - childrenLock sync.RWMutex -} - -func (f *SyncList[T]) Append(el T) { - f.childrenLock.Lock() - defer f.childrenLock.Unlock() - f.children = append(f.children, el) -} - -func (f *SyncList[T]) Get() []T { - f.childrenLock.RLock() - defer f.childrenLock.RUnlock() - return f.children -} - -type logger interface { - Debug(msg string, fields ...zap.Field) -} - -type Flag struct { - name string - log logger - ch atomic.Pointer[SignalChannel] - parent *Flag - children SyncList[*Flag] - stopHandlers SyncList[func(signal uint32)] - val atomic.Uint32 -} - -func (s *Flag) Name() string { - return s.name -} - -func (s *Flag) closeChannel() { - ch := s.ch.Swap(&closedChan) - if ch != &closedChan { - close(*ch) - } -} - -func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool { - s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal))) - s.closeChannel() - out := s.val.CompareAndSwap(SignalNoop, signal) - if !out { - return false - } - - for _, handler := range s.stopHandlers.Get() { - handler(signal) - } - - for _, child := range s.children.Get() { - child.sendSignal(signal, sendToParent) - } - if sendToParent && s.parent != nil { - s.parent.sendSignal(signal, sendToParent) - } - return out -} - -func (s *Flag) SetHard(sendToParent bool) bool { - return s.sendSignal(SignalHardStop, sendToParent) -} - -func (s *Flag) SetSoft(sendToParent bool) bool { - return s.sendSignal(SignalSoftStop, sendToParent) -} - -func (s *Flag) CreateChild(name string) *Flag { - child := newFlag(name, s) - s.children.Append(child) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - child.sendSignal(val, false) - } - return child -} - -func (s *Flag) SignalChannel() SignalChannel { - return *s.ch.Load() -} - -func (s *Flag) IsSoft() bool { - return s.val.Load() == SignalSoftStop -} - -func (s *Flag) IsHard() bool { - return s.val.Load() == SignalHardStop -} - -func (s *Flag) IsHardOrSoft() bool { - return s.val.Load() != SignalNoop -} - -func (s *Flag) AddHandler(handler func(signal uint32)) { - s.stopHandlers.Append(handler) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - handler(val) - } -} - -func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) { - s.AddHandler(func(signal uint32) { - switch expectedSignal { - case SignalNoop: - handler() - default: - if signal == expectedSignal { - handler() - } - } - }) -} - -func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context { - ctx, cancel := context.WithCancel(ctx) - s.AddHandler2(cancel, expectedSignal) - return ctx -} - -func (s *Flag) SetLogger(log logger) { - s.log = log -} - -func NewFlag(name string) *Flag { - return newFlag(name, nil) -} - -func newFlag(name string, parent *Flag) *Flag { - out := Flag{ - name: name, - parent: parent, - log: zap.NewNop(), - } - ch := make(SignalChannel) - out.ch.Store(&ch) - return &out -} - -func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { - graceful := make(chan os.Signal, 1) - signal.Notify(graceful, syscall.SIGTERM, syscall.SIGINT) - go func() { - sig := <-graceful - switch sig { - case syscall.SIGINT: - for i := range flags { - flags[i].SetSoft(true) - } - logger.Info("Get SIGINT signal, begin soft stop.") - default: - for i := range flags { - flags[i].SetHard(true) - } - logger.Info("Get SIGTERM signal, begin hard stop.") - } - }() -} - -func GetStateName(state uint32) string { - switch state { - case SignalSoftStop: - return "soft" - case SignalHardStop: - return "hard" - case SignalNoop: - return "no-signal" - default: - panic(fmt.Sprintf("unexpected signal %d", state)) - } -} diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go deleted file mode 100644 index 81a4b344..00000000 --- a/pkg/stop/flag_test.go +++ /dev/null @@ -1,429 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop_test - -import ( - "context" - "errors" - "fmt" - "reflect" - "runtime" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/scylladb/gemini/pkg/stop" -) - -func TestHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHard, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func TestSoftStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } -} - -func TestSoftOrHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } - - workersDone.Store(uint32(0)) - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() != nil { - t.Error("Error:SetHard function apply hardStopHandler after SetSoft") - } - - testFlag, ctx, workersDone = initVars() - workersDone.Store(uint32(0)) - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) { - testFlagOut := stop.NewFlag("main_test") - ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop) - workersDone = &atomic.Uint32{} - return testFlagOut, ctx, workersDone -} - -func testSignals( - t *testing.T, - workersDone *atomic.Uint32, - workers int, - checkFunc func() bool, - setFunc func(propagation bool) bool, -) { - t.Helper() - for i := 0; i != workers; i++ { - go func() { - for { - if checkFunc() { - workersDone.Add(1) - return - } - time.Sleep(10 * time.Millisecond) - } - }() - } - time.Sleep(200 * time.Millisecond) - setFunc(false) - - for i := 0; i != 10; i++ { - time.Sleep(100 * time.Millisecond) - if workersDone.Load() == uint32(workers) { - break - } - } - - setFuncName := runtime.FuncForPC(reflect.ValueOf(setFunc).Pointer()).Name() - setFuncName, _ = strings.CutSuffix(setFuncName, "-fm") - _, setFuncName, _ = strings.Cut(setFuncName, ".(") - setFuncName = strings.ReplaceAll(setFuncName, ").", ".") - checkFuncName := runtime.FuncForPC(reflect.ValueOf(checkFunc).Pointer()).Name() - checkFuncName, _ = strings.CutSuffix(checkFuncName, "-fm") - _, checkFuncName, _ = strings.Cut(checkFuncName, ".(") - checkFuncName = strings.ReplaceAll(checkFuncName, ").", ".") - - if workersDone.Load() != uint32(workers) { - t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc()) - } -} - -func TestSendToParent(t *testing.T) { - t.Parallel() - tcases := []tCase{ - { - testName: "parent-hard-true", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-hard-false", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "parent-soft-false", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalNoop, - }, - { - testName: "child11-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child11-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalNoop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalNoop, - child2Signal: stop.SignalNoop, - }, - } - for id := range tcases { - tcase := tcases[id] - t.Run(tcase.testName, func(t *testing.T) { - t.Parallel() - if err := tcase.runTest(); err != nil { - t.Error(err) - } - }) - } -} - -// nolint: govet -type parentChildInfo struct { - parent *stop.Flag - parentSignal uint32 - child1 *stop.Flag - child1Signal uint32 - child11 *stop.Flag - child11Signal uint32 - child12 *stop.Flag - child12Signal uint32 - child2 *stop.Flag - child2Signal uint32 -} - -func (t *parentChildInfo) getFlag(flagName string) *stop.Flag { - switch flagName { - case "parent": - return t.parent - case "child1": - return t.child1 - case "child2": - return t.child2 - case "child11": - return t.child11 - case "child12": - return t.child12 - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 { - switch flagName { - case "parent": - return t.parentSignal - case "child1": - return t.child1Signal - case "child2": - return t.child2Signal - case "child11": - return t.child11Signal - case "child12": - return t.child12Signal - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error { - var err error - flagName := flag.Name() - state := t.getFlagHandlerState(flagName) - if state != expectedState { - err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState))) - } - flagState := getFlagState(flag) - if stop.GetStateName(expectedState) != flagState { - err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState))) - } - return err -} - -type tCase struct { - testName string - parentSignal uint32 - child1Signal uint32 - child11Signal uint32 - child12Signal uint32 - child2Signal uint32 -} - -func (t *tCase) runTest() error { - chunk := strings.Split(t.testName, "-") - if len(chunk) != 3 { - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - flagName := chunk[0] - signalTypeName := chunk[1] - sendToParentName := chunk[2] - - var sendToParent bool - switch sendToParentName { - case "true": - sendToParent = true - case "false": - sendToParent = false - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - runt := newParentChildInfo() - flag := runt.getFlag(flagName) - switch signalTypeName { - case "soft": - flag.SetSoft(sendToParent) - case "hard": - flag.SetHard(sendToParent) - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - var err error - err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal)) - err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal)) - return err -} - -func newParentChildInfo() *parentChildInfo { - parent := stop.NewFlag("parent") - child1 := parent.CreateChild("child1") - out := parentChildInfo{ - parent: parent, - child1: child1, - child11: child1.CreateChild("child11"), - child12: child1.CreateChild("child12"), - child2: parent.CreateChild("child2"), - } - - out.parent.AddHandler(func(signal uint32) { - out.parentSignal = signal - }) - out.child1.AddHandler(func(signal uint32) { - out.child1Signal = signal - }) - out.child11.AddHandler(func(signal uint32) { - out.child11Signal = signal - }) - out.child12.AddHandler(func(signal uint32) { - out.child12Signal = signal - }) - out.child2.AddHandler(func(signal uint32) { - out.child2Signal = signal - }) - return &out -} - -func getFlagState(flag *stop.Flag) string { - switch { - case flag.IsSoft(): - return "soft" - case flag.IsHard(): - return "hard" - default: - return "no-signal" - } -} - -func TestSignalChannel(t *testing.T) { - t.Parallel() - t.Run("single-no-signal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - select { - case <-flag.SignalChannel(): - t.Error("should not get the signal") - case <-time.Tick(200 * time.Millisecond): - } - }) - - t.Run("single-beforehand", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - flag.SetSoft(true) - <-flag.SignalChannel() - }) - - t.Run("single-normal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - go func() { - time.Sleep(200 * time.Millisecond) - flag.SetSoft(true) - }() - <-flag.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - parent.SetSoft(true) - <-child.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - parent.SetSoft(true) - child := parent.CreateChild("child") - <-child.SignalChannel() - }) - - t.Run("parent-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - parent.SetSoft(true) - }() - <-child.SignalChannel() - }) - - t.Run("child-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - child.SetSoft(true) - <-parent.SignalChannel() - }) - - t.Run("child-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - child.SetSoft(true) - }() - <-parent.SignalChannel() - }) -}