diff --git a/cmd/gemini/generators.go b/cmd/gemini/generators.go deleted file mode 100644 index 96a1ffc3..00000000 --- a/cmd/gemini/generators.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package main - -import ( - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/typedef" - - "go.uber.org/zap" -) - -func createGenerators( - schema *typedef.Schema, - schemaConfig typedef.SchemaConfig, - seed, distributionSize uint64, - logger *zap.Logger, -) (generators.Generators, error) { - partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() - - gs := make([]*generators.Generator, 0, len(schema.Tables)) - - for id := range schema.Tables { - table := schema.Tables[id] - pkVariations := table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) - - distFunc, err := createDistributionFunc(partitionKeyDistribution, distributionSize, seed, stdDistMean, oneStdDev) - if err != nil { - return nil, err - } - - tablePartConfig := &generators.Config{ - PartitionsRangeConfig: partitionRangeConfig, - PartitionsCount: distributionSize, - PartitionsDistributionFunc: distFunc, - Seed: seed, - PkUsedBufferSize: pkBufferReuseSize, - } - g := generators.NewGenerator(table, tablePartConfig, logger.Named("generators")) - if pkVariations < 2^32 { - // Low partition key variation can lead to having staled partitions - // Let's detect and mark them before running test - g.FindAndMarkStalePartitions() - } - gs = append(gs, g) - } - return gs, nil -} diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index a90d7d5c..6d2a22d0 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -22,24 +22,13 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" "strconv" "strings" + "syscall" "text/tabwriter" "time" - "github.com/scylladb/gemini/pkg/auth" - "github.com/scylladb/gemini/pkg/builders" - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/jobs" - "github.com/scylladb/gemini/pkg/realrandom" - "github.com/scylladb/gemini/pkg/replication" - "github.com/scylladb/gemini/pkg/store" - "github.com/scylladb/gemini/pkg/typedef" - "github.com/scylladb/gemini/pkg/utils" - - "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" - "github.com/gocql/gocql" "github.com/hailocab/go-hostpool" "github.com/pkg/errors" @@ -50,6 +39,17 @@ import ( "golang.org/x/exp/rand" "golang.org/x/net/context" "gonum.org/v1/gonum/stat/distuv" + + "github.com/scylladb/gemini/pkg/auth" + "github.com/scylladb/gemini/pkg/builders" + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/jobs" + "github.com/scylladb/gemini/pkg/realrandom" + "github.com/scylladb/gemini/pkg/replication" + "github.com/scylladb/gemini/pkg/status" + "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/typedef" + "github.com/scylladb/gemini/pkg/utils" ) var ( @@ -137,8 +137,11 @@ func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Sc } func run(_ *cobra.Command, _ []string) error { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGABRT, syscall.SIGTERM, syscall.SIGINT) + defer cancel() + logger := createLogger(level) - globalStatus := status.NewGlobalStatus(1000) + globalStatus := status.NewGlobalStatus(int32(maxErrorsToStore)) defer utils.IgnoreError(logger.Sync) if err := validateSeed(seed); err != nil { @@ -242,7 +245,7 @@ func run(_ *cobra.Command, _ []string) error { if dropSchema && mode != jobs.ReadMode { for _, stmt := range generators.GetDropKeyspace(schema) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { + if err = st.Mutate(ctx, typedef.SimpleStmt(stmt, typedef.DropKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to drop schema") } } @@ -250,7 +253,7 @@ func run(_ *cobra.Command, _ []string) error { testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema) if err = st.Create( - context.Background(), + ctx, typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType), typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to create keyspace") @@ -263,26 +266,23 @@ func run(_ *cobra.Command, _ []string) error { } } - ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2) - stopFlag := stop.NewFlag("main") - warmupStopFlag := stop.NewFlag("warmup") - stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) - pump := jobs.NewPump(stopFlag, logger) + pump := jobs.NewPump(ctx, logger) - gens, err := createGenerators(schema, schemaConfig, intSeed, partitionCount, logger) + distFunc, err := createDistributionFunc(partitionKeyDistribution, partitionCount, intSeed, normalDistMean, normalDistSigma) if err != nil { - return err + return errors.Wrapf(err, "Faile to create distribution function: %s", partitionKeyDistribution) } - gens.StartAll(stopFlag) + + gens := generators.New(ctx, schema, distFunc, schemaConfig.GetPartitionRangeConfig(), intSeed, partitionCount, pkBufferReuseSize, logger) + defer utils.IgnoreError(gens.Close) if !nonInteractive { sp := createSpinner(interactive()) ticker := time.NewTicker(time.Second) go func() { - defer done() for { select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return case <-ticker.C: sp.Set(" Running Gemini... %v", globalStatus) @@ -291,20 +291,24 @@ func run(_ *cobra.Command, _ []string) error { }() } - if warmup > 0 && !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, warmupStopFlag, failFast, verbose); err != nil { + if warmup > 0 { + warmupCtx, warmupCancel := context.WithTimeout(ctx, warmup) + defer warmupCancel() + + jobsList := jobs.ListFromMode(jobs.WarmupMode, concurrency) + if err = jobsList.Run(warmupCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { logger.Error("warmup encountered an error", zap.Error(err)) - stopFlag.SetHard(true) } } - if !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(mode, duration, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil { - logger.Debug("error detected", zap.Error(err)) - } + jobsCtx, jobsCancel := context.WithTimeout(ctx, duration) + defer jobsCancel() + + jobsList := jobs.ListFromMode(mode, concurrency) + if err = jobsList.Run(jobsCtx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, failFast, verbose); err != nil { + logger.Debug("error detected", zap.Error(err)) } + logger.Info("test finished") globalStatus.PrintResult(outFile, schema, version) if globalStatus.HasErrors() { diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index 500fc759..c0ee9214 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -15,18 +15,19 @@ package generators import ( + "context" + "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/exp/rand" "github.com/scylladb/gemini/pkg/routingkey" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" ) // TokenIndex represents the position of a token in the token ring. // A token index is translated to a token by a generators. If the generators -// preserves the exact position, then the token index becomes the token; +// preserve the exact position, then the token index becomes the token; // otherwise token index represents an approximation of the token. // // We use a token index approach, because our generators actually generate @@ -37,7 +38,7 @@ type TokenIndex uint64 type DistributionFunc func() TokenIndex -type GeneratorInterface interface { +type Interface interface { Get() *typedef.ValueWithToken GetOld() *typedef.ValueWithToken GiveOld(_ *typedef.ValueWithToken) @@ -64,14 +65,6 @@ func (g *Generator) PartitionCount() uint64 { return g.partitionCount } -type Generators []*Generator - -func (g Generators) StartAll(stopFlag *stop.Flag) { - for _, gen := range g { - gen.Start(stopFlag) - } -} - type Config struct { PartitionsDistributionFunc DistributionFunc PartitionsRangeConfig typedef.PartitionRangeConfig @@ -80,9 +73,9 @@ type Config struct { PkUsedBufferSize uint64 } -func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator { +func NewGenerator(table *typedef.Table, config Config, logger *zap.Logger) Generator { wakeUpSignal := make(chan struct{}) - return &Generator{ + return Generator{ partitions: NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal), partitionCount: config.PartitionsCount, table: table, @@ -135,39 +128,33 @@ func (g *Generator) ReleaseToken(token uint64) { g.GetPartitionForToken(TokenIndex(token)).releaseToken(token) } -func (g *Generator) Start(stopFlag *stop.Flag) { - go func() { - g.logger.Info("starting partition key generation loop") - defer g.partitions.CloseAll() - for { - g.fillAllPartitions(stopFlag) - select { - case <-stopFlag.SignalChannel(): - g.logger.Debug("stopping partition key generation loop", - zap.Uint64("keys_created", g.cntCreated), - zap.Uint64("keys_emitted", g.cntEmitted)) - return - case <-g.wakeUpSignal: - } +func (g *Generator) Start(ctx context.Context) { + defer g.partitions.Close() + g.logger.Info("starting partition key generation loop") + for { + g.fillAllPartitions(ctx) + select { + case <-ctx.Done(): + g.logger.Debug("stopping partition key generation loop", + zap.Uint64("keys_created", g.cntCreated), + zap.Uint64("keys_emitted", g.cntEmitted)) + return + case <-g.wakeUpSignal: } - }() + } } func (g *Generator) FindAndMarkStalePartitions() { r := rand.New(rand.NewSource(10)) - nonStale := make([]bool, g.partitionCount) - for n := uint64(0); n < g.partitionCount*100; n++ { - values := CreatePartitionKeyValues(g.table, r, &g.partitionsConfig) - token, err := g.routingKeyCreator.GetHash(g.table, values) + + for range g.partitionCount * 100 { + token, _, err := g.createPartitionKeyValues(r) if err != nil { - g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error()) + g.logger.Panic("failed to get primary key hash", zap.Error(err)) } - nonStale[g.shardOf(token)] = true - } - for idx, val := range nonStale { - if !val { - g.partitions[idx].MarkStale() + if err = g.partition(token).MarkStale(); err != nil { + g.logger.Panic("failed to mark partition as stale", zap.Error(err)) } } } @@ -175,7 +162,7 @@ func (g *Generator) FindAndMarkStalePartitions() { // fillAllPartitions guarantees that each partition was tested to be full // at least once since the function started and before it ended. // In other words no partition will be starved. -func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { +func (g *Generator) fillAllPartitions(ctx context.Context) { pFilled := make([]bool, len(g.partitions)) allFilled := func() bool { for idx, filled := range pFilled { @@ -188,22 +175,30 @@ func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { } return true } - for !stopFlag.IsHardOrSoft() { - values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig) - token, err := g.routingKeyCreator.GetHash(g.table, values) + + for { + select { + case <-ctx.Done(): + return + default: + } + + token, values, err := g.createPartitionKeyValues() if err != nil { - g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error()) + g.logger.Panic("failed to get primary key hash", zap.Error(err)) } g.cntCreated++ - idx := token % g.partitionCount - partition := g.partitions[idx] + + partition := g.partition(token) if partition.Stale() || partition.inFlight.Has(token) { continue } + select { case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}: g.cntEmitted++ default: + idx := g.shardOf(token) if !pFilled[idx] { pFilled[idx] = true if allFilled() { @@ -217,3 +212,28 @@ func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { func (g *Generator) shardOf(token uint64) int { return int(token % g.partitionCount) } + +func (g *Generator) partition(token uint64) *Partition { + return g.partitions[g.shardOf(token)] +} + +func (g *Generator) createPartitionKeyValues(r ...*rand.Rand) (uint64, []any, error) { + rnd := g.r + + if len(r) > 0 && r[0] != nil { + rnd = r[0] + } + + values := make([]any, 0, g.table.PartitionKeysLenValues()) + + for _, pk := range g.table.PartitionKeys { + values = append(values, pk.Type.GenValue(rnd, &g.partitionsConfig)...) + } + + token, err := g.routingKeyCreator.GetHash(g.table, values) + if err != nil { + return 0, nil, errors.Wrap(err, "failed to get primary key hash") + } + + return token, values, nil +} diff --git a/pkg/generators/generator_test.go b/pkg/generators/generator_test.go index 3a46551c..c66c6b04 100644 --- a/pkg/generators/generator_test.go +++ b/pkg/generators/generator_test.go @@ -15,13 +15,13 @@ package generators_test import ( + "context" "sync/atomic" "testing" "go.uber.org/zap" "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" ) @@ -32,7 +32,7 @@ func TestGenerator(t *testing.T) { PartitionKeys: generators.CreatePkColumns(1, "pk"), } var current uint64 - cfg := &generators.Config{ + cfg := generators.Config{ PartitionsRangeConfig: typedef.PartitionRangeConfig{ MaxStringLength: 10, MinStringLength: 0, @@ -47,7 +47,7 @@ func TestGenerator(t *testing.T) { } logger, _ := zap.NewDevelopment() generator := generators.NewGenerator(table, cfg, logger) - generator.Start(stop.NewFlag("main_test")) + generator.Start(context.Background()) for i := uint64(0); i < cfg.PartitionsCount; i++ { atomic.StoreUint64(¤t, i) v := generator.Get() diff --git a/pkg/generators/generators.go b/pkg/generators/generators.go new file mode 100644 index 00000000..95a06d84 --- /dev/null +++ b/pkg/generators/generators.go @@ -0,0 +1,84 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generators + +import ( + "context" + "math" + "sync" + + "go.uber.org/zap" + + "github.com/scylladb/gemini/pkg/typedef" +) + +type Generators struct { + wg *sync.WaitGroup + cancel context.CancelFunc + Generators []Generator +} + +func New( + ctx context.Context, + schema *typedef.Schema, + distFunc DistributionFunc, + partitionRangeConfig typedef.PartitionRangeConfig, + seed, distributionSize, pkBufferReuseSize uint64, + logger *zap.Logger, +) *Generators { + gs := make([]Generator, 0, len(schema.Tables)) + + cfg := Config{ + PartitionsRangeConfig: partitionRangeConfig, + PartitionsCount: distributionSize, + PartitionsDistributionFunc: distFunc, + Seed: seed, + PkUsedBufferSize: pkBufferReuseSize, + } + + wg := new(sync.WaitGroup) + wg.Add(len(schema.Tables)) + + ctx, cancel := context.WithCancel(ctx) + + for _, table := range schema.Tables { + g := NewGenerator(table, cfg, logger.Named("generators-"+table.Name)) + go func() { + defer wg.Done() + g.Start(ctx) + }() + + if table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) < math.MaxUint32 { + // Low partition key variation can lead to having staled partitions + // Let's detect and mark them before running test + g.FindAndMarkStalePartitions() + } + + gs = append(gs, g) + } + + return &Generators{ + Generators: gs, + wg: wg, + cancel: cancel, + } +} + +func (g *Generators) Close() error { + g.cancel() + g.wg.Wait() + + return nil +} diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index e70d46c2..eb418a72 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -15,7 +15,9 @@ package generators import ( - "sync" + "sync/atomic" + + "go.uber.org/multierr" "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/typedef" @@ -26,18 +28,17 @@ type Partition struct { oldValues chan *typedef.ValueWithToken inFlight inflight.InFlight wakeUpSignal chan<- struct{} // wakes up generator - closed bool - lock sync.RWMutex - isStale bool + closed atomic.Bool + isStale atomic.Bool } -func (s *Partition) MarkStale() { - s.isStale = true - s.Close() +func (s *Partition) MarkStale() error { + s.isStale.Store(true) + return s.Close() } func (s *Partition) Stale() bool { - return s.isStale + return s.isStale.Load() } // get returns a new value and ensures that it's corresponding token @@ -103,39 +104,33 @@ func (s *Partition) pick() *typedef.ValueWithToken { } func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { - s.lock.RLock() - if s.closed { + if s.closed.Load() { // Since only giveOld could have been potentially called after partition is closed // we need to protect it against writing to closed channel return nil } - defer s.lock.RUnlock() + return s.oldValues } -func (s *Partition) Close() { - s.lock.RLock() - if s.closed { - s.lock.RUnlock() - return +func (s *Partition) Close() error { + if s.closed.CompareAndSwap(false, true) { + close(s.values) + close(s.oldValues) } - s.lock.RUnlock() - s.lock.Lock() - if s.closed { - return - } - s.closed = true - close(s.values) - close(s.oldValues) - s.lock.Unlock() + + return nil } type Partitions []*Partition -func (p Partitions) CloseAll() { +func (p Partitions) Close() error { + var err error for _, part := range p { - part.Close() + err = multierr.Append(err, part.Close()) } + + return err } func NewPartitions(count, pkBufferSize int, wakeUpSignal chan struct{}) Partitions { diff --git a/pkg/generators/utils.go b/pkg/generators/utils.go index e5758dc0..cd08ee59 100644 --- a/pkg/generators/utils.go +++ b/pkg/generators/utils.go @@ -14,27 +14,17 @@ package generators -import ( - "golang.org/x/exp/rand" - - "github.com/scylladb/gemini/pkg/typedef" -) - -func CreatePartitionKeyValues(table *typedef.Table, r *rand.Rand, g *typedef.PartitionRangeConfig) []any { - values := make([]any, 0, table.PartitionKeysLenValues()) - for _, pk := range table.PartitionKeys { - values = append(values, pk.Type.GenValue(r, g)...) - } - return values -} +import "github.com/scylladb/gemini/pkg/typedef" func CreatePkColumns(cnt int, prefix string) typedef.Columns { - var cols typedef.Columns + cols := make(typedef.Columns, 0, cnt) + for i := 0; i < cnt; i++ { cols = append(cols, &typedef.ColumnDef{ Name: GenColumnName(prefix, i), Type: typedef.TYPE_INT, }) } + return cols } diff --git a/pkg/jobs/gen_check_stmt.go b/pkg/jobs/gen_check_stmt.go index 88ad77f6..e7b5533e 100644 --- a/pkg/jobs/gen_check_stmt.go +++ b/pkg/jobs/gen_check_stmt.go @@ -28,7 +28,7 @@ import ( func GenCheckStmt( s *typedef.Schema, table *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, rnd *rand.Rand, p *typedef.PartitionRangeConfig, ) *typedef.Stmt { @@ -112,7 +112,7 @@ func GenCheckStmt( func genSinglePartitionQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, ) *typedef.Stmt { t.RLock() defer t.RUnlock() @@ -142,7 +142,7 @@ func genSinglePartitionQuery( func genSinglePartitionQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum int, @@ -181,7 +181,7 @@ func genSinglePartitionQueryMv( func genMultiplePartitionQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, numQueryPKs int, ) *typedef.Stmt { t.RLock() @@ -221,7 +221,7 @@ func genMultiplePartitionQuery( func genMultiplePartitionQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, numQueryPKs int, @@ -272,7 +272,7 @@ func genMultiplePartitionQueryMv( func genClusteringRangeQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, maxClusteringRels int, @@ -319,7 +319,7 @@ func genClusteringRangeQuery( func genClusteringRangeQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, maxClusteringRels int, @@ -372,7 +372,7 @@ func genClusteringRangeQueryMv( func genMultiplePartitionClusteringRangeQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, numQueryPKs, maxClusteringRels int, @@ -433,7 +433,7 @@ func genMultiplePartitionClusteringRangeQuery( func genMultiplePartitionClusteringRangeQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, numQueryPKs, maxClusteringRels int, @@ -514,7 +514,7 @@ func genMultiplePartitionClusteringRangeQueryMv( func genSingleIndexQuery( s *typedef.Schema, t *typedef.Table, - _ generators.GeneratorInterface, + _ generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, idxCount int, diff --git a/pkg/jobs/gen_mutate_stmt.go b/pkg/jobs/gen_mutate_stmt.go index 6f65caab..e3baad69 100644 --- a/pkg/jobs/gen_mutate_stmt.go +++ b/pkg/jobs/gen_mutate_stmt.go @@ -26,7 +26,7 @@ import ( "github.com/scylladb/gemini/pkg/typedef" ) -func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.GeneratorInterface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) { +func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) { t.RLock() defer t.RUnlock() diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 3730b7de..55ddf2cb 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -28,7 +28,6 @@ import ( "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/joberror" "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/store" "github.com/scylladb/gemini/pkg/typedef" ) @@ -53,10 +52,9 @@ var ( ) type List struct { - name string - jobs []job - duration time.Duration - workers uint64 + name string + jobs []job + workers uint64 } type job struct { @@ -72,16 +70,16 @@ type job struct { *generators.Generator, *status.GlobalStatus, *zap.Logger, - *stop.Flag, bool, bool, ) error name string } -func ListFromMode(mode string, duration time.Duration, workers uint64) List { +func ListFromMode(mode string, workers uint64) List { jobs := make([]job, 0, 2) name := "work cycle" + switch mode { case WriteMode: jobs = append(jobs, mutate) @@ -93,11 +91,11 @@ func ListFromMode(mode string, duration time.Duration, workers uint64) List { default: jobs = append(jobs, mutate, validate) } + return List{ - name: name, - jobs: jobs, - duration: duration, - workers: workers, + name: name, + jobs: jobs, + workers: workers, } } @@ -107,36 +105,30 @@ func (l List) Run( schemaConfig typedef.SchemaConfig, s store.Store, pump <-chan time.Duration, - generators []*generators.Generator, + generators *generators.Generators, globalStatus *status.GlobalStatus, logger *zap.Logger, seed uint64, - stopFlag *stop.Flag, failFast, verbose bool, ) error { logger = logger.Named(l.name) - ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop) g, gCtx := errgroup.WithContext(ctx) - time.AfterFunc(l.duration, func() { - logger.Info("jobs time is up, begins jobs completion") - stopFlag.SetSoft(true) - }) partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() logger.Info("start jobs") - for j := range schema.Tables { - gen := generators[j] - table := schema.Tables[j] + for j, table := range schema.Tables { + generator := &generators.Generators[j] for i := 0; i < int(l.workers); i++ { for idx := range l.jobs { jobF := l.jobs[idx].function r := rand.New(rand.NewSource(seed)) g.Go(func() error { - return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, gen, globalStatus, logger, stopFlag, failFast, verbose) + return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, generator, globalStatus, logger, failFast, verbose) }) } } } + return g.Wait() } @@ -154,7 +146,6 @@ func mutationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, verbose bool, ) error { schemaConfig := &schemaCfg @@ -164,11 +155,8 @@ func mutationJob( logger.Info("ending mutation loop") }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): logger.Debug("mutation job terminated") return nil case hb := <-pump: @@ -187,7 +175,6 @@ func mutationJob( } } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -207,7 +194,6 @@ func validationJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -218,11 +204,8 @@ func validationJob( }() for { - if stopFlag.IsHardOrSoft() { - return nil - } select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return nil case hb := <-pump: time.Sleep(hb) @@ -262,7 +245,6 @@ func validationJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } @@ -282,7 +264,6 @@ func warmupJob( g *generators.Generator, globalStatus *status.GlobalStatus, logger *zap.Logger, - stopFlag *stop.Flag, failFast, _ bool, ) error { schemaConfig := &schemaCfg @@ -292,10 +273,13 @@ func warmupJob( logger.Info("ending warmup loop") }() for { - if stopFlag.IsHardOrSoft() { + select { + case <-ctx.Done(): logger.Debug("warmup job terminated") return nil + default: } + // Do we care about errors during warmup? err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger) if err != nil { @@ -303,7 +287,6 @@ func warmupJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) return nil } } diff --git a/pkg/jobs/pump.go b/pkg/jobs/pump.go index c929f8ce..4baf6a98 100644 --- a/pkg/jobs/pump.go +++ b/pkg/jobs/pump.go @@ -15,15 +15,14 @@ package jobs import ( + "context" "time" - "github.com/scylladb/gemini/pkg/stop" - "go.uber.org/zap" "golang.org/x/exp/rand" ) -func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { +func NewPump(ctx context.Context, logger *zap.Logger) <-chan time.Duration { pump := make(chan time.Duration, 10000) logger = logger.Named("Pump") go func() { @@ -32,8 +31,14 @@ func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration { close(pump) logger.Debug("pump channel closed") }() - for !stopFlag.IsHardOrSoft() { - pump <- newHeartBeat() + + for { + select { + case <-ctx.Done(): + return + default: + pump <- newHeartBeat() + } } }() diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go deleted file mode 100644 index 54c6e1f2..00000000 --- a/pkg/stop/flag.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop - -import ( - "context" - "fmt" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - - "go.uber.org/zap" -) - -const ( - SignalNoop uint32 = iota - SignalSoftStop - SignalHardStop -) - -type SignalChannel chan uint32 - -var closedChan = createClosedChan() - -func createClosedChan() SignalChannel { - ch := make(SignalChannel) - close(ch) - return ch -} - -type SyncList[T any] struct { - children []T - childrenLock sync.RWMutex -} - -func (f *SyncList[T]) Append(el T) { - f.childrenLock.Lock() - defer f.childrenLock.Unlock() - f.children = append(f.children, el) -} - -func (f *SyncList[T]) Get() []T { - f.childrenLock.RLock() - defer f.childrenLock.RUnlock() - return f.children -} - -type logger interface { - Debug(msg string, fields ...zap.Field) -} - -type Flag struct { - name string - log logger - ch atomic.Pointer[SignalChannel] - parent *Flag - children SyncList[*Flag] - stopHandlers SyncList[func(signal uint32)] - val atomic.Uint32 -} - -func (s *Flag) Name() string { - return s.name -} - -func (s *Flag) closeChannel() { - ch := s.ch.Swap(&closedChan) - if ch != &closedChan { - close(*ch) - } -} - -func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool { - s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal))) - s.closeChannel() - out := s.val.CompareAndSwap(SignalNoop, signal) - if !out { - return false - } - - for _, handler := range s.stopHandlers.Get() { - handler(signal) - } - - for _, child := range s.children.Get() { - child.sendSignal(signal, sendToParent) - } - if sendToParent && s.parent != nil { - s.parent.sendSignal(signal, sendToParent) - } - return out -} - -func (s *Flag) SetHard(sendToParent bool) bool { - return s.sendSignal(SignalHardStop, sendToParent) -} - -func (s *Flag) SetSoft(sendToParent bool) bool { - return s.sendSignal(SignalSoftStop, sendToParent) -} - -func (s *Flag) CreateChild(name string) *Flag { - child := newFlag(name, s) - s.children.Append(child) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - child.sendSignal(val, false) - } - return child -} - -func (s *Flag) SignalChannel() SignalChannel { - return *s.ch.Load() -} - -func (s *Flag) IsSoft() bool { - return s.val.Load() == SignalSoftStop -} - -func (s *Flag) IsHard() bool { - return s.val.Load() == SignalHardStop -} - -func (s *Flag) IsHardOrSoft() bool { - return s.val.Load() != SignalNoop -} - -func (s *Flag) AddHandler(handler func(signal uint32)) { - s.stopHandlers.Append(handler) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - handler(val) - } -} - -func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) { - s.AddHandler(func(signal uint32) { - switch expectedSignal { - case SignalNoop: - handler() - default: - if signal == expectedSignal { - handler() - } - } - }) -} - -func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context { - ctx, cancel := context.WithCancel(ctx) - s.AddHandler2(cancel, expectedSignal) - return ctx -} - -func (s *Flag) SetLogger(log logger) { - s.log = log -} - -func NewFlag(name string) *Flag { - return newFlag(name, nil) -} - -func newFlag(name string, parent *Flag) *Flag { - out := Flag{ - name: name, - parent: parent, - log: zap.NewNop(), - } - ch := make(SignalChannel) - out.ch.Store(&ch) - return &out -} - -func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { - graceful := make(chan os.Signal, 1) - signal.Notify(graceful, syscall.SIGTERM, syscall.SIGINT) - go func() { - sig := <-graceful - switch sig { - case syscall.SIGINT: - for i := range flags { - flags[i].SetSoft(true) - } - logger.Info("Get SIGINT signal, begin soft stop.") - default: - for i := range flags { - flags[i].SetHard(true) - } - logger.Info("Get SIGTERM signal, begin hard stop.") - } - }() -} - -func GetStateName(state uint32) string { - switch state { - case SignalSoftStop: - return "soft" - case SignalHardStop: - return "hard" - case SignalNoop: - return "no-signal" - default: - panic(fmt.Sprintf("unexpected signal %d", state)) - } -} diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go deleted file mode 100644 index 81a4b344..00000000 --- a/pkg/stop/flag_test.go +++ /dev/null @@ -1,429 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop_test - -import ( - "context" - "errors" - "fmt" - "reflect" - "runtime" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/scylladb/gemini/pkg/stop" -) - -func TestHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHard, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func TestSoftStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } -} - -func TestSoftOrHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } - - workersDone.Store(uint32(0)) - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() != nil { - t.Error("Error:SetHard function apply hardStopHandler after SetSoft") - } - - testFlag, ctx, workersDone = initVars() - workersDone.Store(uint32(0)) - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) { - testFlagOut := stop.NewFlag("main_test") - ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop) - workersDone = &atomic.Uint32{} - return testFlagOut, ctx, workersDone -} - -func testSignals( - t *testing.T, - workersDone *atomic.Uint32, - workers int, - checkFunc func() bool, - setFunc func(propagation bool) bool, -) { - t.Helper() - for i := 0; i != workers; i++ { - go func() { - for { - if checkFunc() { - workersDone.Add(1) - return - } - time.Sleep(10 * time.Millisecond) - } - }() - } - time.Sleep(200 * time.Millisecond) - setFunc(false) - - for i := 0; i != 10; i++ { - time.Sleep(100 * time.Millisecond) - if workersDone.Load() == uint32(workers) { - break - } - } - - setFuncName := runtime.FuncForPC(reflect.ValueOf(setFunc).Pointer()).Name() - setFuncName, _ = strings.CutSuffix(setFuncName, "-fm") - _, setFuncName, _ = strings.Cut(setFuncName, ".(") - setFuncName = strings.ReplaceAll(setFuncName, ").", ".") - checkFuncName := runtime.FuncForPC(reflect.ValueOf(checkFunc).Pointer()).Name() - checkFuncName, _ = strings.CutSuffix(checkFuncName, "-fm") - _, checkFuncName, _ = strings.Cut(checkFuncName, ".(") - checkFuncName = strings.ReplaceAll(checkFuncName, ").", ".") - - if workersDone.Load() != uint32(workers) { - t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc()) - } -} - -func TestSendToParent(t *testing.T) { - t.Parallel() - tcases := []tCase{ - { - testName: "parent-hard-true", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-hard-false", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "parent-soft-false", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalNoop, - }, - { - testName: "child11-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child11-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalNoop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalNoop, - child2Signal: stop.SignalNoop, - }, - } - for id := range tcases { - tcase := tcases[id] - t.Run(tcase.testName, func(t *testing.T) { - t.Parallel() - if err := tcase.runTest(); err != nil { - t.Error(err) - } - }) - } -} - -// nolint: govet -type parentChildInfo struct { - parent *stop.Flag - parentSignal uint32 - child1 *stop.Flag - child1Signal uint32 - child11 *stop.Flag - child11Signal uint32 - child12 *stop.Flag - child12Signal uint32 - child2 *stop.Flag - child2Signal uint32 -} - -func (t *parentChildInfo) getFlag(flagName string) *stop.Flag { - switch flagName { - case "parent": - return t.parent - case "child1": - return t.child1 - case "child2": - return t.child2 - case "child11": - return t.child11 - case "child12": - return t.child12 - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 { - switch flagName { - case "parent": - return t.parentSignal - case "child1": - return t.child1Signal - case "child2": - return t.child2Signal - case "child11": - return t.child11Signal - case "child12": - return t.child12Signal - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error { - var err error - flagName := flag.Name() - state := t.getFlagHandlerState(flagName) - if state != expectedState { - err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState))) - } - flagState := getFlagState(flag) - if stop.GetStateName(expectedState) != flagState { - err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState))) - } - return err -} - -type tCase struct { - testName string - parentSignal uint32 - child1Signal uint32 - child11Signal uint32 - child12Signal uint32 - child2Signal uint32 -} - -func (t *tCase) runTest() error { - chunk := strings.Split(t.testName, "-") - if len(chunk) != 3 { - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - flagName := chunk[0] - signalTypeName := chunk[1] - sendToParentName := chunk[2] - - var sendToParent bool - switch sendToParentName { - case "true": - sendToParent = true - case "false": - sendToParent = false - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - runt := newParentChildInfo() - flag := runt.getFlag(flagName) - switch signalTypeName { - case "soft": - flag.SetSoft(sendToParent) - case "hard": - flag.SetHard(sendToParent) - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - var err error - err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal)) - err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal)) - return err -} - -func newParentChildInfo() *parentChildInfo { - parent := stop.NewFlag("parent") - child1 := parent.CreateChild("child1") - out := parentChildInfo{ - parent: parent, - child1: child1, - child11: child1.CreateChild("child11"), - child12: child1.CreateChild("child12"), - child2: parent.CreateChild("child2"), - } - - out.parent.AddHandler(func(signal uint32) { - out.parentSignal = signal - }) - out.child1.AddHandler(func(signal uint32) { - out.child1Signal = signal - }) - out.child11.AddHandler(func(signal uint32) { - out.child11Signal = signal - }) - out.child12.AddHandler(func(signal uint32) { - out.child12Signal = signal - }) - out.child2.AddHandler(func(signal uint32) { - out.child2Signal = signal - }) - return &out -} - -func getFlagState(flag *stop.Flag) string { - switch { - case flag.IsSoft(): - return "soft" - case flag.IsHard(): - return "hard" - default: - return "no-signal" - } -} - -func TestSignalChannel(t *testing.T) { - t.Parallel() - t.Run("single-no-signal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - select { - case <-flag.SignalChannel(): - t.Error("should not get the signal") - case <-time.Tick(200 * time.Millisecond): - } - }) - - t.Run("single-beforehand", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - flag.SetSoft(true) - <-flag.SignalChannel() - }) - - t.Run("single-normal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - go func() { - time.Sleep(200 * time.Millisecond) - flag.SetSoft(true) - }() - <-flag.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - parent.SetSoft(true) - <-child.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - parent.SetSoft(true) - child := parent.CreateChild("child") - <-child.SignalChannel() - }) - - t.Run("parent-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - parent.SetSoft(true) - }() - <-child.SignalChannel() - }) - - t.Run("child-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - child.SetSoft(true) - <-parent.SignalChannel() - }) - - t.Run("child-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - child.SetSoft(true) - }() - <-parent.SignalChannel() - }) -}