From 94c08605bde98367b49cfc8e00abcfa29b45d062 Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Wed, 27 Nov 2024 17:24:08 +0100 Subject: [PATCH 1/5] refactor(generator): generators use context.Context instead of stopflag In process of removing the `stopFlag` from gemini's codebase, first step is to migrate the `Value Generators` for patitions. Using context with generators make a lot more sense then the custom built, `stopFlag`. `context` is built-in package in Go, and this is it's usecase - cancelation propagation to background task. Signed-off-by: Dusan Malusev --- cmd/gemini/generators.go | 58 ------------------- cmd/gemini/root.go | 8 ++- pkg/generators/generator.go | 107 +++++++++++++++++++++-------------- pkg/generators/generators.go | 84 +++++++++++++++++++++++++++ pkg/generators/partition.go | 45 +++++++-------- pkg/generators/utils.go | 18 ++---- pkg/jobs/gen_check_stmt.go | 20 +++---- pkg/jobs/gen_mutate_stmt.go | 2 +- pkg/jobs/jobs.go | 10 ++-- 9 files changed, 192 insertions(+), 160 deletions(-) delete mode 100644 cmd/gemini/generators.go create mode 100644 pkg/generators/generators.go diff --git a/cmd/gemini/generators.go b/cmd/gemini/generators.go deleted file mode 100644 index 96a1ffc..0000000 --- a/cmd/gemini/generators.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package main - -import ( - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/typedef" - - "go.uber.org/zap" -) - -func createGenerators( - schema *typedef.Schema, - schemaConfig typedef.SchemaConfig, - seed, distributionSize uint64, - logger *zap.Logger, -) (generators.Generators, error) { - partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() - - gs := make([]*generators.Generator, 0, len(schema.Tables)) - - for id := range schema.Tables { - table := schema.Tables[id] - pkVariations := table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) - - distFunc, err := createDistributionFunc(partitionKeyDistribution, distributionSize, seed, stdDistMean, oneStdDev) - if err != nil { - return nil, err - } - - tablePartConfig := &generators.Config{ - PartitionsRangeConfig: partitionRangeConfig, - PartitionsCount: distributionSize, - PartitionsDistributionFunc: distFunc, - Seed: seed, - PkUsedBufferSize: pkBufferReuseSize, - } - g := generators.NewGenerator(table, tablePartConfig, logger.Named("generators")) - if pkVariations < 2^32 { - // Low partition key variation can lead to having staled partitions - // Let's detect and mark them before running test - g.FindAndMarkStalePartitions() - } - gs = append(gs, g) - } - return gs, nil -} diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index a90d7d5..4a3510d 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -269,12 +269,14 @@ func run(_ *cobra.Command, _ []string) error { stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) pump := jobs.NewPump(stopFlag, logger) - gens, err := createGenerators(schema, schemaConfig, intSeed, partitionCount, logger) + distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, stdDistMean, oneStdDev) if err != nil { - return err + return errors.Wrapf(err, "Faile to create distribution function: %s", partitionKeyDistribution) } - gens.StartAll(stopFlag) + gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) + defer utils.IgnoreError(gens.Close) + if !nonInteractive { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index 500fc75..87a6387 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -15,12 +15,12 @@ package generators import ( + "context" "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/exp/rand" "github.com/scylladb/gemini/pkg/routingkey" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" ) @@ -37,7 +37,7 @@ type TokenIndex uint64 type DistributionFunc func() TokenIndex -type GeneratorInterface interface { +type Interface interface { Get() *typedef.ValueWithToken GetOld() *typedef.ValueWithToken GiveOld(_ *typedef.ValueWithToken) @@ -64,14 +64,6 @@ func (g *Generator) PartitionCount() uint64 { return g.partitionCount } -type Generators []*Generator - -func (g Generators) StartAll(stopFlag *stop.Flag) { - for _, gen := range g { - gen.Start(stopFlag) - } -} - type Config struct { PartitionsDistributionFunc DistributionFunc PartitionsRangeConfig typedef.PartitionRangeConfig @@ -80,9 +72,9 @@ type Config struct { PkUsedBufferSize uint64 } -func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator { +func NewGenerator(table *typedef.Table, config Config, logger *zap.Logger) Generator { wakeUpSignal := make(chan struct{}) - return &Generator{ + return Generator{ partitions: NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal), partitionCount: config.PartitionsCount, table: table, @@ -135,39 +127,33 @@ func (g *Generator) ReleaseToken(token uint64) { g.GetPartitionForToken(TokenIndex(token)).releaseToken(token) } -func (g *Generator) Start(stopFlag *stop.Flag) { - go func() { - g.logger.Info("starting partition key generation loop") - defer g.partitions.CloseAll() - for { - g.fillAllPartitions(stopFlag) - select { - case <-stopFlag.SignalChannel(): - g.logger.Debug("stopping partition key generation loop", - zap.Uint64("keys_created", g.cntCreated), - zap.Uint64("keys_emitted", g.cntEmitted)) - return - case <-g.wakeUpSignal: - } +func (g *Generator) Start(ctx context.Context) { + defer g.partitions.Close() + g.logger.Info("starting partition key generation loop") + for { + g.fillAllPartitions(ctx) + select { + case <-ctx.Done(): + g.logger.Debug("stopping partition key generation loop", + zap.Uint64("keys_created", g.cntCreated), + zap.Uint64("keys_emitted", g.cntEmitted)) + return + case <-g.wakeUpSignal: } - }() + } } func (g *Generator) FindAndMarkStalePartitions() { r := rand.New(rand.NewSource(10)) - nonStale := make([]bool, g.partitionCount) - for n := uint64(0); n < g.partitionCount*100; n++ { - values := CreatePartitionKeyValues(g.table, r, &g.partitionsConfig) - token, err := g.routingKeyCreator.GetHash(g.table, values) + + for range g.partitionCount * 100 { + token, _, err := g.createPartitionKeyValues(r) if err != nil { - g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error()) + g.logger.Panic("failed to get primary key hash", zap.Error(err)) } - nonStale[g.shardOf(token)] = true - } - for idx, val := range nonStale { - if !val { - g.partitions[idx].MarkStale() + if err = g.partition(token).MarkStale(); err != nil { + g.logger.Panic("failed to mark partition as stale", zap.Error(err)) } } } @@ -175,7 +161,7 @@ func (g *Generator) FindAndMarkStalePartitions() { // fillAllPartitions guarantees that each partition was tested to be full // at least once since the function started and before it ended. // In other words no partition will be starved. -func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { +func (g *Generator) fillAllPartitions(ctx context.Context) { pFilled := make([]bool, len(g.partitions)) allFilled := func() bool { for idx, filled := range pFilled { @@ -188,22 +174,30 @@ func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { } return true } - for !stopFlag.IsHardOrSoft() { - values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig) - token, err := g.routingKeyCreator.GetHash(g.table, values) + + for { + select { + case <-ctx.Done(): + return + default: + } + + token, values, err := g.createPartitionKeyValues() if err != nil { - g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error()) + g.logger.Panic("failed to get primary key hash", zap.Error(err)) } g.cntCreated++ - idx := token % g.partitionCount - partition := g.partitions[idx] + + partition := g.partition(token) if partition.Stale() || partition.inFlight.Has(token) { continue } + select { case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}: g.cntEmitted++ default: + idx := g.shardOf(token) if !pFilled[idx] { pFilled[idx] = true if allFilled() { @@ -217,3 +211,28 @@ func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { func (g *Generator) shardOf(token uint64) int { return int(token % g.partitionCount) } + +func (g *Generator) partition(token uint64) *Partition { + return g.partitions[g.shardOf(token)] +} + +func (g *Generator) createPartitionKeyValues(r ...*rand.Rand) (uint64, []any, error) { + rnd := g.r + + if len(r) > 0 && r[0] != nil { + rnd = r[0] + } + + values := make([]any, 0, g.table.PartitionKeysLenValues()) + + for _, pk := range g.table.PartitionKeys { + values = append(values, pk.Type.GenValue(rnd, &g.partitionsConfig)...) + } + + token, err := g.routingKeyCreator.GetHash(g.table, values) + if err != nil { + return 0, nil, errors.Wrap(err, "failed to get primary key hash") + } + + return token, values, nil +} diff --git a/pkg/generators/generators.go b/pkg/generators/generators.go new file mode 100644 index 0000000..23c88de --- /dev/null +++ b/pkg/generators/generators.go @@ -0,0 +1,84 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generators + +import ( + "context" + "math" + "sync" + + "go.uber.org/zap" + + "github.com/scylladb/gemini/pkg/typedef" +) + +type Generators struct { + Generators []Generator + wg *sync.WaitGroup + cancel context.CancelFunc +} + +func New( + ctx context.Context, + schema *typedef.Schema, + distFunc DistributionFunc, + partitionRangeConfig typedef.PartitionRangeConfig, + seed, distributionSize, pkBufferReuseSize uint64, + logger *zap.Logger, +) *Generators { + gs := make([]Generator, 0, len(schema.Tables)) + + cfg := Config{ + PartitionsRangeConfig: partitionRangeConfig, + PartitionsCount: distributionSize, + PartitionsDistributionFunc: distFunc, + Seed: seed, + PkUsedBufferSize: pkBufferReuseSize, + } + + wg := new(sync.WaitGroup) + wg.Add(len(schema.Tables)) + + ctx, cancel := context.WithCancel(ctx) + + for _, table := range schema.Tables { + g := NewGenerator(table, cfg, logger.Named("generators-"+table.Name)) + go func() { + defer wg.Done() + g.Start(ctx) + }() + + if table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) < math.MaxUint32 { + // Low partition key variation can lead to having staled partitions + // Let's detect and mark them before running test + g.FindAndMarkStalePartitions() + } + + gs = append(gs, g) + } + + return &Generators{ + Generators: gs, + wg: wg, + cancel: cancel, + } +} + +func (g *Generators) Close() error { + g.cancel() + g.wg.Wait() + + return nil +} diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index e70d46c..839b50c 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -15,7 +15,8 @@ package generators import ( - "sync" + "go.uber.org/multierr" + "sync/atomic" "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/typedef" @@ -26,18 +27,17 @@ type Partition struct { oldValues chan *typedef.ValueWithToken inFlight inflight.InFlight wakeUpSignal chan<- struct{} // wakes up generator - closed bool - lock sync.RWMutex - isStale bool + closed atomic.Bool + isStale atomic.Bool } -func (s *Partition) MarkStale() { - s.isStale = true - s.Close() +func (s *Partition) MarkStale() error { + s.isStale.Store(true) + return s.Close() } func (s *Partition) Stale() bool { - return s.isStale + return s.isStale.Load() } // get returns a new value and ensures that it's corresponding token @@ -103,39 +103,34 @@ func (s *Partition) pick() *typedef.ValueWithToken { } func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { - s.lock.RLock() - if s.closed { + if s.closed.Load() { // Since only giveOld could have been potentially called after partition is closed // we need to protect it against writing to closed channel return nil } - defer s.lock.RUnlock() + return s.oldValues } -func (s *Partition) Close() { - s.lock.RLock() - if s.closed { - s.lock.RUnlock() - return - } - s.lock.RUnlock() - s.lock.Lock() - if s.closed { - return +func (s *Partition) Close() error { + for !s.closed.CompareAndSwap(false, true) { } - s.closed = true + close(s.values) close(s.oldValues) - s.lock.Unlock() + + return nil } type Partitions []*Partition -func (p Partitions) CloseAll() { +func (p Partitions) Close() error { + var err error for _, part := range p { - part.Close() + err = multierr.Append(err, part.Close()) } + + return err } func NewPartitions(count, pkBufferSize int, wakeUpSignal chan struct{}) Partitions { diff --git a/pkg/generators/utils.go b/pkg/generators/utils.go index e5758dc..cd08ee5 100644 --- a/pkg/generators/utils.go +++ b/pkg/generators/utils.go @@ -14,27 +14,17 @@ package generators -import ( - "golang.org/x/exp/rand" - - "github.com/scylladb/gemini/pkg/typedef" -) - -func CreatePartitionKeyValues(table *typedef.Table, r *rand.Rand, g *typedef.PartitionRangeConfig) []any { - values := make([]any, 0, table.PartitionKeysLenValues()) - for _, pk := range table.PartitionKeys { - values = append(values, pk.Type.GenValue(r, g)...) - } - return values -} +import "github.com/scylladb/gemini/pkg/typedef" func CreatePkColumns(cnt int, prefix string) typedef.Columns { - var cols typedef.Columns + cols := make(typedef.Columns, 0, cnt) + for i := 0; i < cnt; i++ { cols = append(cols, &typedef.ColumnDef{ Name: GenColumnName(prefix, i), Type: typedef.TYPE_INT, }) } + return cols } diff --git a/pkg/jobs/gen_check_stmt.go b/pkg/jobs/gen_check_stmt.go index 88ad77f..e7b5533 100644 --- a/pkg/jobs/gen_check_stmt.go +++ b/pkg/jobs/gen_check_stmt.go @@ -28,7 +28,7 @@ import ( func GenCheckStmt( s *typedef.Schema, table *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, rnd *rand.Rand, p *typedef.PartitionRangeConfig, ) *typedef.Stmt { @@ -112,7 +112,7 @@ func GenCheckStmt( func genSinglePartitionQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, ) *typedef.Stmt { t.RLock() defer t.RUnlock() @@ -142,7 +142,7 @@ func genSinglePartitionQuery( func genSinglePartitionQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum int, @@ -181,7 +181,7 @@ func genSinglePartitionQueryMv( func genMultiplePartitionQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, numQueryPKs int, ) *typedef.Stmt { t.RLock() @@ -221,7 +221,7 @@ func genMultiplePartitionQuery( func genMultiplePartitionQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, numQueryPKs int, @@ -272,7 +272,7 @@ func genMultiplePartitionQueryMv( func genClusteringRangeQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, maxClusteringRels int, @@ -319,7 +319,7 @@ func genClusteringRangeQuery( func genClusteringRangeQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, maxClusteringRels int, @@ -372,7 +372,7 @@ func genClusteringRangeQueryMv( func genMultiplePartitionClusteringRangeQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, numQueryPKs, maxClusteringRels int, @@ -433,7 +433,7 @@ func genMultiplePartitionClusteringRangeQuery( func genMultiplePartitionClusteringRangeQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, numQueryPKs, maxClusteringRels int, @@ -514,7 +514,7 @@ func genMultiplePartitionClusteringRangeQueryMv( func genSingleIndexQuery( s *typedef.Schema, t *typedef.Table, - _ generators.GeneratorInterface, + _ generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, idxCount int, diff --git a/pkg/jobs/gen_mutate_stmt.go b/pkg/jobs/gen_mutate_stmt.go index 6f65caa..e3baad6 100644 --- a/pkg/jobs/gen_mutate_stmt.go +++ b/pkg/jobs/gen_mutate_stmt.go @@ -26,7 +26,7 @@ import ( "github.com/scylladb/gemini/pkg/typedef" ) -func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.GeneratorInterface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) { +func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) { t.RLock() defer t.RUnlock() diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 3730b7d..5274c28 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -107,7 +107,7 @@ func (l List) Run( schemaConfig typedef.SchemaConfig, s store.Store, pump <-chan time.Duration, - generators []*generators.Generator, + generators *generators.Generators, globalStatus *status.GlobalStatus, logger *zap.Logger, seed uint64, @@ -124,19 +124,19 @@ func (l List) Run( partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() logger.Info("start jobs") - for j := range schema.Tables { - gen := generators[j] - table := schema.Tables[j] + for j, table := range schema.Tables { + generator := &generators.Generators[j] for i := 0; i < int(l.workers); i++ { for idx := range l.jobs { jobF := l.jobs[idx].function r := rand.New(rand.NewSource(seed)) g.Go(func() error { - return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, gen, globalStatus, logger, stopFlag, failFast, verbose) + return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, stopFlag, failFast, verbose) }) } } } + return g.Wait() } From 0d05a8e3bd628fb3adeea192393f96e1c612571c Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Thu, 28 Nov 2024 20:31:54 +0100 Subject: [PATCH 2/5] fix(dist-func): use normalDistMean and normalDistSigma from CLI args Signed-off-by: Dusan Malusev --- cmd/gemini/root.go | 4 ++-- pkg/generators/generator.go | 1 + pkg/generators/generator_test.go | 6 +++--- pkg/generators/partition.go | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 4a3510d..278a102 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -269,14 +269,14 @@ func run(_ *cobra.Command, _ []string) error { stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) pump := jobs.NewPump(stopFlag, logger) - distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, stdDistMean, oneStdDev) + distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, normalDistMean, normalDistSigma) if err != nil { return errors.Wrapf(err, "Faile to create distribution function: %s", partitionKeyDistribution) } gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) defer utils.IgnoreError(gens.Close) - + if !nonInteractive { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index 87a6387..46c4b73 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -16,6 +16,7 @@ package generators import ( "context" + "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/exp/rand" diff --git a/pkg/generators/generator_test.go b/pkg/generators/generator_test.go index 3a46551..c66c6b0 100644 --- a/pkg/generators/generator_test.go +++ b/pkg/generators/generator_test.go @@ -15,13 +15,13 @@ package generators_test import ( + "context" "sync/atomic" "testing" "go.uber.org/zap" "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" ) @@ -32,7 +32,7 @@ func TestGenerator(t *testing.T) { PartitionKeys: generators.CreatePkColumns(1, "pk"), } var current uint64 - cfg := &generators.Config{ + cfg := generators.Config{ PartitionsRangeConfig: typedef.PartitionRangeConfig{ MaxStringLength: 10, MinStringLength: 0, @@ -47,7 +47,7 @@ func TestGenerator(t *testing.T) { } logger, _ := zap.NewDevelopment() generator := generators.NewGenerator(table, cfg, logger) - generator.Start(stop.NewFlag("main_test")) + generator.Start(context.Background()) for i := uint64(0); i < cfg.PartitionsCount; i++ { atomic.StoreUint64(¤t, i) v := generator.Get() diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index 839b50c..138e228 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -15,9 +15,10 @@ package generators import ( - "go.uber.org/multierr" "sync/atomic" + "go.uber.org/multierr" + "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/typedef" ) From f6b78ac487004ff1ac4248df4f1e77cb46181c2e Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Wed, 27 Nov 2024 17:24:08 +0100 Subject: [PATCH 3/5] refactor(generator): generators use context.Context instead of stopflag In process of removing the `stopFlag` from gemini's codebase, first step is to migrate the `Value Generators` for patitions. Using context with generators make a lot more sense then the custom built, `stopFlag`. `context` is built-in package in Go, and this is it's usecase - cancelation propagation to background task. Signed-off-by: Dusan Malusev --- cmd/gemini/root.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 278a102..fe3cd2f 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -277,6 +277,9 @@ func run(_ *cobra.Command, _ []string) error { gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) defer utils.IgnoreError(gens.Close) + gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) + defer utils.IgnoreError(gens.Close) + if !nonInteractive { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) From 69be9e9e232a5cd63edda900cd54cc563760d066 Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Thu, 28 Nov 2024 20:31:54 +0100 Subject: [PATCH 4/5] fix(dist-func): use normalDistMean and normalDistSigma from CLI args Signed-off-by: Dusan Malusev --- cmd/gemini/root.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index fe3cd2f..278a102 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -277,9 +277,6 @@ func run(_ *cobra.Command, _ []string) error { gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) defer utils.IgnoreError(gens.Close) - gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) - defer utils.IgnoreError(gens.Close) - if !nonInteractive { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) From 5f374f271f471ec15b32e05a09a51162b777d513 Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Thu, 28 Nov 2024 22:49:38 +0100 Subject: [PATCH 5/5] refactor(stopFlag): completly remove stopFlag for `context` `context` is Go package used mainly to signal cancelation. `stopFlag` package inside of gemini was used the same way. It was just adding overhead and maintainability issues in the long run, as it was only a wrapper for `context.Context` and all at once calling the `cancel` funcs at the end of the program. This removal makes code easier to follow for everybody working on it. `context` package is well known in Go community and can serve many purposes beyond just canceling the background workers (`jobs` in gemini terminology). Hard kill in gemini was used to signal the immadiet stoppage for gemini, without pritting the results, and soft on the contrary could be triggers in a few ways: - gemini validation failed - gemini mutation failed - SIG(INT,TERM) was sent There is no reason to have HARD kill in the application, if we need hard kill, SIGKILL can be sent and everything would be stopped, there will be no need to the the cleanup as kernel will ensure that happens. `context.Context` works as a soft kill, and if something happens bad in gemini (validation fails or mutation fail) `globalStatus.HasError()` would stop all other `goroutines` from continuing if `failFast` CLI flag is passed, so the set of `softKill` is not needed. Signed-off-by: Dusan Malusev --- cmd/gemini/root.go | 66 +++--- pkg/generators/generator.go | 2 +- pkg/generators/generators.go | 2 +- pkg/generators/partition.go | 7 +- pkg/jobs/jobs.go | 49 ++-- pkg/jobs/pump.go | 15 +- pkg/stop/flag.go | 221 ------------------ pkg/stop/flag_test.go | 429 ----------------------------------- 8 files changed, 65 insertions(+), 726 deletions(-) delete mode 100644 pkg/stop/flag.go delete mode 100644 pkg/stop/flag_test.go diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 278a102..6d2a22d 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -22,24 +22,13 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" "strconv" "strings" + "syscall" "text/tabwriter" "time" - "github.com/scylladb/gemini/pkg/auth" - "github.com/scylladb/gemini/pkg/builders" - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/jobs" - "github.com/scylladb/gemini/pkg/realrandom" - "github.com/scylladb/gemini/pkg/replication" - "github.com/scylladb/gemini/pkg/store" - "github.com/scylladb/gemini/pkg/typedef" - "github.com/scylladb/gemini/pkg/utils" - - "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" - "github.com/gocql/gocql" "github.com/hailocab/go-hostpool" "github.com/pkg/errors" @@ -50,6 +39,17 @@ import ( "golang.org/x/exp/rand" "golang.org/x/net/context" "gonum.org/v1/gonum/stat/distuv" + + "github.com/scylladb/gemini/pkg/auth" + "github.com/scylladb/gemini/pkg/builders" + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/jobs" + "github.com/scylladb/gemini/pkg/realrandom" + "github.com/scylladb/gemini/pkg/replication" + "github.com/scylladb/gemini/pkg/status" + "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/typedef" + "github.com/scylladb/gemini/pkg/utils" ) var ( @@ -137,8 +137,11 @@ func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Sc } func run(_ *cobra.Command, _ []string) error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGABRT, syscall.SIGTERM, syscall.SIGINT) + defer cancel() + logger := createLogger(level) - globalStatus := status.NewGlobalStatus(1000) + globalStatus := status.NewGlobalStatus(int32(maxErrorsToStore)) defer utils.IgnoreError(logger.Sync) if err := validateSeed(seed); err != nil { @@ -242,7 +245,7 @@ func run(_ *cobra.Command, _ []string) error { if dropSchema && mode != jobs.ReadMode { for _, stmt := range generators.GetDropKeyspace(schema) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { + if err = st.Mutate(ctx, typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to drop schema") } } @@ -250,7 +253,7 @@ func run(_ *cobra.Command, _ []string) error { testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema) if err = st.Create( - context.Background(), + ctx, typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType), typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to create keyspace") @@ -263,11 +266,7 @@ func run(_ *cobra.Command, _ []string) error { } } - ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2) - stopFlag := stop.NewFlag("main") - warmupStopFlag := stop.NewFlag("warmup") - stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) - pump := jobs.NewPump(stopFlag, logger) + pump := jobs.NewPump(ctx, logger) distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, normalDistMean, normalDistSigma) if err != nil { @@ -281,10 +280,9 @@ func run(_ *cobra.Command, _ []string) error { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) go func() { - defer done() for { select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return case <-ticker.C: sp.Set(" Running Gemini... %v", globalStatus) @@ -293,20 +291,24 @@ func run(_ *cobra.Command, _ []string) error { }() } - if warmup > 0 && !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, warmupStopFlag, failFast, verbose); err != nil { + if warmup > 0 { + warmupCtx, warmupCancel := context.WithTimeout(ctx, warmup) + defer warmupCancel() + + jobsList := jobs.ListFromMode(jobs.WarmupMode, concurrency) + if err = jobsList.Run(warmupCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { logger.Error("warmup encountered an error", zap.Error(err)) - stopFlag.SetHard(true) } } - if !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(mode, duration, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil { - logger.Debug("error detected", zap.Error(err)) - } + jobsCtx, jobsCancel := context.WithTimeout(ctx, duration) + defer jobsCancel() + + jobsList := jobs.ListFromMode(mode, concurrency) + if err = jobsList.Run(jobsCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { + logger.Debug("error detected", zap.Error(err)) } + logger.Info("test finished") globalStatus.PrintResult(outFile, schema, version) if globalStatus.HasErrors() { diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index 46c4b73..c0ee921 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -27,7 +27,7 @@ import ( // TokenIndex represents the position of a token in the token ring. // A token index is translated to a token by a generators. If the generators -// preserves the exact position, then the token index becomes the token; +// preserve the exact position, then the token index becomes the token; // otherwise token index represents an approximation of the token. // // We use a token index approach, because our generators actually generate diff --git a/pkg/generators/generators.go b/pkg/generators/generators.go index 23c88de..95a06d8 100644 --- a/pkg/generators/generators.go +++ b/pkg/generators/generators.go @@ -25,9 +25,9 @@ import ( ) type Generators struct { - Generators []Generator wg *sync.WaitGroup cancel context.CancelFunc + Generators []Generator } func New( diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index 138e228..eb418a7 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -114,12 +114,11 @@ func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { } func (s *Partition) Close() error { - for !s.closed.CompareAndSwap(false, true) { + if s.closed.CompareAndSwap(false, true) { + close(s.values) + close(s.oldValues) } - close(s.values) - close(s.oldValues) - return nil } diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 5274c28..55ddf2c 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -28,7 +28,6 @@ import ( "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/joberror" "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/store" "github.com/scylladb/gemini/pkg/typedef" ) @@ -53,10 +52,9 @@ var ( ) type List struct { - name string - jobs []job - duration time.Duration - workers uint64 + name string + jobs []job + workers uint64 } type job struct { @@ -72,16 +70,16 @@ type job struct { *generators.Generator, *status.GlobalStatus, *zap.Logger, - *stop.Flag, bool, bool, ) error name string } -func ListFromMode(mode string, duration time.Duration, workers uint64) List { +func ListFromMode(mode string, workers uint64) List { jobs := make([]job, 0, 2) name := "work cycle" + switch mode { case WriteMode: jobs = append(jobs, mutate) @@ -93,11 +91,11 @@ func ListFromMode(mode string, duration time.Duration, workers uint64) List { default: jobs = append(jobs, mutate, validate) } + return List{ - name: name, - jobs: jobs, - duration: duration, - workers: workers, + name: name, + jobs: jobs, + workers: workers, } } @@ -111,16 +109,10 @@ func (l List) Run( globalStatus *status.GlobalStatus, logger *zap.Logger, seed uint64, - stopFlag *stop.Flag, failFast, verbose bool, ) error { logger = logger.Named(l.name) - ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop) g, gCtx := errgroup.WithContext(ctx) - time.AfterFunc(l.duration, func() { - logger.Info("jobs time is up, begins jobs completion") - stopFlag.SetSoft(true) - }) partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() logger.Info("start jobs") @@ -131,7 +123,7 @@ func (l List) Run( jobF := l.jobs[idx].function r := rand.New(rand.NewSource(seed)) g.Go(func() error { - return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, stopFlag, failFast, verbose) + return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, failFast, verbose) }) } } @@ -154,7 +146,6 @@ func mutationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, verbose bool, ) error { schemaConfig := &schemaCfg @@ -164,11 +155,8 @@ func mutationJob( logger.Info("ending mutation loop") }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): logger.Debug("mutation job terminated") return nil case hb := <-pump: @@ -187,7 +175,6 @@ func mutationJob( } } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -207,7 +194,6 @@ func validationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -218,11 +204,8 @@ func validationJob( }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return nil case hb := <-pump: time.Sleep(hb) @@ -262,7 +245,6 @@ func validationJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -282,7 +264,6 @@ func warmupJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -292,10 +273,13 @@ func warmupJob( logger.Info("ending warmup loop") }() for { - if stopFlag.IsHardOrSoft() { + select { + case <-ctx.Done(): logger.Debug("warmup job terminated") return nil + default: } + // Do we care about errors during warmup? err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger) if err != nil { @@ -303,7 +287,6 @@ func warmupJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } diff --git a/pkg/jobs/pump.go b/pkg/jobs/pump.go index c929f8c..4baf6a9 100644 --- a/pkg/jobs/pump.go +++ b/pkg/jobs/pump.go @@ -15,15 +15,14 @@ package jobs import ( + "context" "time" - "github.com/scylladb/gemini/pkg/stop" - "go.uber.org/zap" "golang.org/x/exp/rand" ) -func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { +func NewPump(ctx context.Context, logger *zap.Logger) <-chan time.Duration { pump := make(chan time.Duration, 10000) logger = logger.Named("Pump") go func() { @@ -32,8 +31,14 @@ func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { close(pump) logger.Debug("pump channel closed") }() - for !stopFlag.IsHardOrSoft() { - pump <- newHeartBeat() + + for { + select { + case <-ctx.Done(): + return + default: + pump <- newHeartBeat() + } } }() diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go deleted file mode 100644 index 54c6e1f..0000000 --- a/pkg/stop/flag.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop - -import ( - "context" - "fmt" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - - "go.uber.org/zap" -) - -const ( - SignalNoop uint32 = iota - SignalSoftStop - SignalHardStop -) - -type SignalChannel chan uint32 - -var closedChan = createClosedChan() - -func createClosedChan() SignalChannel { - ch := make(SignalChannel) - close(ch) - return ch -} - -type SyncList[T any] struct { - children []T - childrenLock sync.RWMutex -} - -func (f *SyncList[T]) Append(el T) { - f.childrenLock.Lock() - defer f.childrenLock.Unlock() - f.children = append(f.children, el) -} - -func (f *SyncList[T]) Get() []T { - f.childrenLock.RLock() - defer f.childrenLock.RUnlock() - return f.children -} - -type logger interface { - Debug(msg string, fields ...zap.Field) -} - -type Flag struct { - name string - log logger - ch atomic.Pointer[SignalChannel] - parent *Flag - children SyncList[*Flag] - stopHandlers SyncList[func(signal uint32)] - val atomic.Uint32 -} - -func (s *Flag) Name() string { - return s.name -} - -func (s *Flag) closeChannel() { - ch := s.ch.Swap(&closedChan) - if ch != &closedChan { - close(*ch) - } -} - -func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool { - s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal))) - s.closeChannel() - out := s.val.CompareAndSwap(SignalNoop, signal) - if !out { - return false - } - - for _, handler := range s.stopHandlers.Get() { - handler(signal) - } - - for _, child := range s.children.Get() { - child.sendSignal(signal, sendToParent) - } - if sendToParent && s.parent != nil { - s.parent.sendSignal(signal, sendToParent) - } - return out -} - -func (s *Flag) SetHard(sendToParent bool) bool { - return s.sendSignal(SignalHardStop, sendToParent) -} - -func (s *Flag) SetSoft(sendToParent bool) bool { - return s.sendSignal(SignalSoftStop, sendToParent) -} - -func (s *Flag) CreateChild(name string) *Flag { - child := newFlag(name, s) - s.children.Append(child) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - child.sendSignal(val, false) - } - return child -} - -func (s *Flag) SignalChannel() SignalChannel { - return *s.ch.Load() -} - -func (s *Flag) IsSoft() bool { - return s.val.Load() == SignalSoftStop -} - -func (s *Flag) IsHard() bool { - return s.val.Load() == SignalHardStop -} - -func (s *Flag) IsHardOrSoft() bool { - return s.val.Load() != SignalNoop -} - -func (s *Flag) AddHandler(handler func(signal uint32)) { - s.stopHandlers.Append(handler) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - handler(val) - } -} - -func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) { - s.AddHandler(func(signal uint32) { - switch expectedSignal { - case SignalNoop: - handler() - default: - if signal == expectedSignal { - handler() - } - } - }) -} - -func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context { - ctx, cancel := context.WithCancel(ctx) - s.AddHandler2(cancel, expectedSignal) - return ctx -} - -func (s *Flag) SetLogger(log logger) { - s.log = log -} - -func NewFlag(name string) *Flag { - return newFlag(name, nil) -} - -func newFlag(name string, parent *Flag) *Flag { - out := Flag{ - name: name, - parent: parent, - log: zap.NewNop(), - } - ch := make(SignalChannel) - out.ch.Store(&ch) - return &out -} - -func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { - graceful := make(chan os.Signal, 1) - signal.Notify(graceful, syscall.SIGTERM, syscall.SIGINT) - go func() { - sig := <-graceful - switch sig { - case syscall.SIGINT: - for i := range flags { - flags[i].SetSoft(true) - } - logger.Info("Get SIGINT signal, begin soft stop.") - default: - for i := range flags { - flags[i].SetHard(true) - } - logger.Info("Get SIGTERM signal, begin hard stop.") - } - }() -} - -func GetStateName(state uint32) string { - switch state { - case SignalSoftStop: - return "soft" - case SignalHardStop: - return "hard" - case SignalNoop: - return "no-signal" - default: - panic(fmt.Sprintf("unexpected signal %d", state)) - } -} diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go deleted file mode 100644 index 81a4b34..0000000 --- a/pkg/stop/flag_test.go +++ /dev/null @@ -1,429 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop_test - -import ( - "context" - "errors" - "fmt" - "reflect" - "runtime" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/scylladb/gemini/pkg/stop" -) - -func TestHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHard, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func TestSoftStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } -} - -func TestSoftOrHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } - - workersDone.Store(uint32(0)) - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() != nil { - t.Error("Error:SetHard function apply hardStopHandler after SetSoft") - } - - testFlag, ctx, workersDone = initVars() - workersDone.Store(uint32(0)) - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) { - testFlagOut := stop.NewFlag("main_test") - ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop) - workersDone = &atomic.Uint32{} - return testFlagOut, ctx, workersDone -} - -func testSignals( - t *testing.T, - workersDone *atomic.Uint32, - workers int, - checkFunc func() bool, - setFunc func(propagation bool) bool, -) { - t.Helper() - for i := 0; i != workers; i++ { - go func() { - for { - if checkFunc() { - workersDone.Add(1) - return - } - time.Sleep(10 * time.Millisecond) - } - }() - } - time.Sleep(200 * time.Millisecond) - setFunc(false) - - for i := 0; i != 10; i++ { - time.Sleep(100 * time.Millisecond) - if workersDone.Load() == uint32(workers) { - break - } - } - - setFuncName := runtime.FuncForPC(reflect.ValueOf(setFunc).Pointer()).Name() - setFuncName, _ = strings.CutSuffix(setFuncName, "-fm") - _, setFuncName, _ = strings.Cut(setFuncName, ".(") - setFuncName = strings.ReplaceAll(setFuncName, ").", ".") - checkFuncName := runtime.FuncForPC(reflect.ValueOf(checkFunc).Pointer()).Name() - checkFuncName, _ = strings.CutSuffix(checkFuncName, "-fm") - _, checkFuncName, _ = strings.Cut(checkFuncName, ".(") - checkFuncName = strings.ReplaceAll(checkFuncName, ").", ".") - - if workersDone.Load() != uint32(workers) { - t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc()) - } -} - -func TestSendToParent(t *testing.T) { - t.Parallel() - tcases := []tCase{ - { - testName: "parent-hard-true", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-hard-false", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "parent-soft-false", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalNoop, - }, - { - testName: "child11-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child11-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalNoop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalNoop, - child2Signal: stop.SignalNoop, - }, - } - for id := range tcases { - tcase := tcases[id] - t.Run(tcase.testName, func(t *testing.T) { - t.Parallel() - if err := tcase.runTest(); err != nil { - t.Error(err) - } - }) - } -} - -// nolint: govet -type parentChildInfo struct { - parent *stop.Flag - parentSignal uint32 - child1 *stop.Flag - child1Signal uint32 - child11 *stop.Flag - child11Signal uint32 - child12 *stop.Flag - child12Signal uint32 - child2 *stop.Flag - child2Signal uint32 -} - -func (t *parentChildInfo) getFlag(flagName string) *stop.Flag { - switch flagName { - case "parent": - return t.parent - case "child1": - return t.child1 - case "child2": - return t.child2 - case "child11": - return t.child11 - case "child12": - return t.child12 - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 { - switch flagName { - case "parent": - return t.parentSignal - case "child1": - return t.child1Signal - case "child2": - return t.child2Signal - case "child11": - return t.child11Signal - case "child12": - return t.child12Signal - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error { - var err error - flagName := flag.Name() - state := t.getFlagHandlerState(flagName) - if state != expectedState { - err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState))) - } - flagState := getFlagState(flag) - if stop.GetStateName(expectedState) != flagState { - err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState))) - } - return err -} - -type tCase struct { - testName string - parentSignal uint32 - child1Signal uint32 - child11Signal uint32 - child12Signal uint32 - child2Signal uint32 -} - -func (t *tCase) runTest() error { - chunk := strings.Split(t.testName, "-") - if len(chunk) != 3 { - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - flagName := chunk[0] - signalTypeName := chunk[1] - sendToParentName := chunk[2] - - var sendToParent bool - switch sendToParentName { - case "true": - sendToParent = true - case "false": - sendToParent = false - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - runt := newParentChildInfo() - flag := runt.getFlag(flagName) - switch signalTypeName { - case "soft": - flag.SetSoft(sendToParent) - case "hard": - flag.SetHard(sendToParent) - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - var err error - err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal)) - err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal)) - return err -} - -func newParentChildInfo() *parentChildInfo { - parent := stop.NewFlag("parent") - child1 := parent.CreateChild("child1") - out := parentChildInfo{ - parent: parent, - child1: child1, - child11: child1.CreateChild("child11"), - child12: child1.CreateChild("child12"), - child2: parent.CreateChild("child2"), - } - - out.parent.AddHandler(func(signal uint32) { - out.parentSignal = signal - }) - out.child1.AddHandler(func(signal uint32) { - out.child1Signal = signal - }) - out.child11.AddHandler(func(signal uint32) { - out.child11Signal = signal - }) - out.child12.AddHandler(func(signal uint32) { - out.child12Signal = signal - }) - out.child2.AddHandler(func(signal uint32) { - out.child2Signal = signal - }) - return &out -} - -func getFlagState(flag *stop.Flag) string { - switch { - case flag.IsSoft(): - return "soft" - case flag.IsHard(): - return "hard" - default: - return "no-signal" - } -} - -func TestSignalChannel(t *testing.T) { - t.Parallel() - t.Run("single-no-signal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - select { - case <-flag.SignalChannel(): - t.Error("should not get the signal") - case <-time.Tick(200 * time.Millisecond): - } - }) - - t.Run("single-beforehand", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - flag.SetSoft(true) - <-flag.SignalChannel() - }) - - t.Run("single-normal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - go func() { - time.Sleep(200 * time.Millisecond) - flag.SetSoft(true) - }() - <-flag.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - parent.SetSoft(true) - <-child.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - parent.SetSoft(true) - child := parent.CreateChild("child") - <-child.SignalChannel() - }) - - t.Run("parent-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - parent.SetSoft(true) - }() - <-child.SignalChannel() - }) - - t.Run("child-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - child.SetSoft(true) - <-parent.SignalChannel() - }) - - t.Run("child-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - child.SetSoft(true) - }() - <-parent.SignalChannel() - }) -}