diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..2be6c4c2 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 + +updates: + - package-ecosystem: "github-actions" + directory: / + schedule: + interval: "monthly" + + - package-ecosystem: "gomod" + directory: / + schedule: + interval: "weekly" diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index be86ef16..814bc554 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -29,9 +29,9 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - gemini-features: ["basic", "normal", "all"] + gemini-features: ["basic", "normal"] gemini-concurrency: [4] - duration: ["5m"] + duration: ["1m"] dataset-size: [large, small] oracle-scylla-version: ["6.1"] test-scylla-version: ["6.2"] @@ -54,6 +54,7 @@ jobs: CONCURRENCY=${{ matrix.gemini-concurrency }} \ CQL_FEATURES=${{ matrix.gemini-features }} \ DURATION=${{ matrix.duration }} \ + WARMUP=30s \ DATASET_SIZE=${{ matrix.dataset-size }} \ - name: Shutdown ScyllaDB shell: bash diff --git a/.gitignore b/.gitignore index a8d923ff..3704f90f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ cmd/gemini/dist/ bin/ coverage.txt dist/ +results/*.log diff --git a/.run/Run Gemini Mixed.run.xml b/.run/Run Gemini Mixed.run.xml new file mode 100644 index 00000000..c6d7c6ca --- /dev/null +++ b/.run/Run Gemini Mixed.run.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/.run/Run Gemini Read.run.xml b/.run/Run Gemini Read.run.xml new file mode 100644 index 00000000..2c851167 --- /dev/null +++ b/.run/Run Gemini Read.run.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/.run/Run Gemini Write.run.xml b/.run/Run Gemini Write.run.xml new file mode 100644 index 00000000..a2225e64 --- /dev/null +++ b/.run/Run Gemini Write.run.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 9c11c89e..660ed522 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,4 @@ FROM golang:1.23-bookworm AS build - ENV GO111MODULE=on ENV GOAMD64=v3 ENV GOARM64=v8.3,crypto diff --git a/Makefile b/Makefile index 16778510..3d9da0ae 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ define dl_tgz chmod +x "$(GOBIN)/$(1)"; \ fi endef - + $(GOBIN)/golangci-lint: GOLANGCI_VERSION = 1.62.0 $(GOBIN)/golangci-lint: Makefile $(call dl_tgz,golangci-lint,https://github.com/golangci/golangci-lint/releases/download/v$(GOLANGCI_VERSION)/golangci-lint-$(GOLANGCI_VERSION)-$(GOOS)-amd64.tar.gz) @@ -67,36 +67,33 @@ scylla-shutdown: test: @go test -covermode=atomic -race -coverprofile=coverage.txt -timeout 5m -json -v ./... 
2>&1 | gotestfmt -showteststatus

-CQL_FEATURES ?= all
-CONCURRENCY ?= 50
+CQL_FEATURES ?= normal
+CONCURRENCY ?= 1
 DURATION ?= 10m
-WARMUP ?= 1m
+WARMUP ?= 0
+MODE ?= mixed
 DATASET_SIZE ?= large
 SEED ?= $(shell date +%s)
 GEMINI_BINARY ?= $(PWD)/bin/gemini
 GEMINI_TEST_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Networks.gemini.IPAddress }}' gemini-test)
 GEMINI_ORACLE_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Networks.gemini.IPAddress }}' gemini-oracle)
 GEMINI_DOCKER_NETWORK ?= gemini
 GEMINI_FLAGS = --fail-fast \
 	--level=info \
 	--non-interactive \
-	--materialized-views=false \
 	--consistency=LOCAL_QUORUM \
 	--test-host-selection-policy=token-aware \
 	--oracle-host-selection-policy=token-aware \
-	--mode=mixed \
+	--mode=$(MODE) \
 	--non-interactive \
 	--request-timeout=5s \
 	--connect-timeout=15s \
 	--use-server-timestamps=false \
 	--async-objects-stabilization-attempts=10 \
 	--max-mutation-retries=10 \
-	--async-objects-stabilization-backoff=1000ms \
-	--max-mutation-retries-backoff=1000ms \
 	--replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
 	--oracle-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
 	--concurrency=$(CONCURRENCY) \
-	--use-lwt=true \
 	--dataset-size=$(DATASET_SIZE) \
 	--seed=$(SEED) \
 	--schema-seed=$(SEED) \
@@ -108,19 +105,23 @@ GEMINI_FLAGS = --fail-fast \

 .PHONY: pprof-profile
 pprof-profile:
-	go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/profile
+	go tool pprof -http=:8080 http://localhost:6060/debug/pprof/profile

 .PHONY: pprof-heap
 pprof-heap:
-	go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/heap
+	go tool pprof -http=:8081 http://localhost:6060/debug/pprof/heap

 .PHONY: pprof-goroutine
 pprof-goroutine:
-	go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/goroutine
+	go tool pprof -http=:8082 http://localhost:6060/debug/pprof/goroutine

 .PHONY: pprof-block
 pprof-block:
-	go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/block
+	go tool pprof -http=:8083 http://localhost:6060/debug/pprof/block
+
+.PHONY: pprof-mutex
+pprof-mutex:
+	go tool pprof -http=:8084 http://localhost:6060/debug/pprof/mutex

 .PHONY: docker-integration-test
 docker-integration-test:
@@ -130,6 +131,7 @@ docker-integration-test:
 	docker run \
 	-it \
 	--rm \
+	--memory=4G \
 	-p 6060:6060 \
 	--name gemini \
 	--network $(GEMINI_DOCKER_NETWORK) \
@@ -138,24 +140,16 @@ docker-integration-test:
 	scylladb/gemini:$(DOCKER_VERSION) \
 		--test-cluster=gemini-test \
 		--oracle-cluster=gemini-oracle \
-		--outfile=/results/gemini_result.log \
-		--tracing-outfile=/results/gemini_tracing.log \
-		--test-statement-log-file=/results/gemini_test_statement.log \
-		--oracle-statement-log-file=/results/gemini_oracle_statement.log \
 		$(GEMINI_FLAGS)

 .PHONY: integration-test
 integration-test:
-	@mkdir -p $(PWD)/results
-	@touch $(PWD)/results/gemini_seed
-	@echo $(GEMINI_SEED) > $(PWD)/results/gemini_seed
-	@$(GEMINI_BINARY) \
+	mkdir -p $(PWD)/results
+	touch $(PWD)/results/gemini_seed
+	echo $(GEMINI_SEED) > $(PWD)/results/gemini_seed
+	$(GEMINI_BINARY) \
 		--test-cluster=$(GEMINI_TEST_CLUSTER) \
 		--oracle-cluster=$(GEMINI_ORACLE_CLUSTER) \
-		--outfile=$(PWD)/results/gemini_result.log \
-		--tracing-outfile=$(PWD)/results/gemini_tracing.log \
-		--test-statement-log-file=$(PWD)/results/gemini_test_statement.log \
-
--oracle-statement-log-file=$(PWD)/results/gemini_oracle_statement.log \ $(GEMINI_FLAGS) .PHONY: clean diff --git a/cmd/gemini/flags.go b/cmd/gemini/flags.go new file mode 100644 index 00000000..9b6abfa6 --- /dev/null +++ b/cmd/gemini/flags.go @@ -0,0 +1,162 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "time" + + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/jobs" +) + +var ( + testClusterHost []string + testClusterUsername string + testClusterPassword string + oracleClusterHost []string + oracleClusterUsername string + oracleClusterPassword string + schemaFile string + outFileArg string + concurrency uint64 + seed string + schemaSeed string + dropSchema bool + verbose bool + mode string + failFast bool + nonInteractive bool + duration time.Duration + bind string + warmup time.Duration + replicationStrategy string + tableOptions []string + oracleReplicationStrategy string + consistency string + maxTables int + maxPartitionKeys int + minPartitionKeys int + maxClusteringKeys int + minClusteringKeys int + maxColumns int + minColumns int + datasetSize string + cqlFeatures string + useMaterializedViews bool + level string + maxRetriesMutate int + maxRetriesMutateSleep time.Duration + maxErrorsToStore int + pkBufferReuseSize uint64 + partitionCount uint64 + partitionKeyDistribution string + normalDistMean float64 + normalDistSigma float64 + tracingOutFile string + useCounters bool + asyncObjectStabilizationAttempts int + asyncObjectStabilizationDelay time.Duration + useLWT bool + testClusterHostSelectionPolicy string + oracleClusterHostSelectionPolicy string + useServerSideTimestamps bool + requestTimeout time.Duration + connectTimeout time.Duration + profilingPort int + testStatementLogFile string + oracleStatementLogFile string +) + +func init() { + rootCmd.Version = version + ", commit " + commit + ", date " + date + rootCmd.Flags().StringSliceVarP(&testClusterHost, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test") + _ = rootCmd.MarkFlagRequired("test-cluster") + rootCmd.Flags().StringVarP(&testClusterUsername, "test-username", "", "", "Username for the test cluster") + rootCmd.Flags().StringVarP(&testClusterPassword, "test-password", "", "", "Password for the test cluster") + rootCmd.Flags().StringSliceVarP( + &oracleClusterHost, "oracle-cluster", "o", []string{}, + "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used") + rootCmd.Flags().StringVarP(&oracleClusterUsername, "oracle-username", "", "", "Username for the oracle cluster") + rootCmd.Flags().StringVarP(&oracleClusterPassword, "oracle-password", "", "", "Password for the oracle cluster") + rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file") + rootCmd.Flags().StringVarP(&mode, "mode", "m", jobs.MixedMode, "Query operation mode. 
Mode options: write, read, mixed (default)")
+	rootCmd.Flags().Uint64VarP(&concurrency, "concurrency", "c", 10, "Number of threads per table to run concurrently")
+	rootCmd.Flags().StringVarP(&seed, "seed", "s", "random", "Statement seed value")
+	rootCmd.Flags().StringVarP(&schemaSeed, "schema-seed", "", "random", "Schema seed value")
+	rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", true, "Drop schema before starting the test run")
+	rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output during test run")
+	rootCmd.Flags().BoolVarP(&failFast, "fail-fast", "f", false, "Stop on the first failure")
+	rootCmd.Flags().BoolVarP(&nonInteractive, "non-interactive", "", false, "Run in non-interactive mode (disable progress indicator)")
+	rootCmd.Flags().DurationVarP(&duration, "duration", "", 30*time.Second, "Duration of the test run")
+	rootCmd.Flags().StringVarP(&outFileArg, "outfile", "", "", "Specify the name of the file where the results should go")
+	rootCmd.Flags().StringVarP(&bind, "bind", "b", ":2112", "Specify the interface and port on which to bind Prometheus metrics. Default is ':2112'")
+	rootCmd.Flags().DurationVarP(&warmup, "warmup", "", 30*time.Second, "Specify the warmup period as a duration, for example 30s or 10h")
+	rootCmd.Flags().StringVarP(
+		&replicationStrategy, "replication-strategy", "", "simple",
+		"Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide "+
+			"the entire specification in the form {'class':'....'}")
+	rootCmd.Flags().StringVarP(
+		&oracleReplicationStrategy, "oracle-replication-strategy", "", "simple",
+		"Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each "+
+			"type or provide the entire specification in the form {'class':'....'}")
+	rootCmd.Flags().StringArrayVarP(&tableOptions, "table-options", "", []string{}, "Repeatable argument to set table options to be added to the created tables")
+	rootCmd.Flags().StringVarP(&consistency, "consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE")
+	rootCmd.Flags().IntVarP(&maxTables, "max-tables", "", 1, "Maximum number of generated tables")
+	rootCmd.Flags().IntVarP(&maxPartitionKeys, "max-partition-keys", "", 6, "Maximum number of generated partition keys")
+	rootCmd.Flags().IntVarP(&minPartitionKeys, "min-partition-keys", "", 2, "Minimum number of generated partition keys")
+	rootCmd.Flags().IntVarP(&maxClusteringKeys, "max-clustering-keys", "", 4, "Maximum number of generated clustering keys")
+	rootCmd.Flags().IntVarP(&minClusteringKeys, "min-clustering-keys", "", 2, "Minimum number of generated clustering keys")
+	rootCmd.Flags().IntVarP(&maxColumns, "max-columns", "", 16, "Maximum number of generated columns")
+	rootCmd.Flags().IntVarP(&minColumns, "min-columns", "", 8, "Minimum number of generated columns")
+	rootCmd.Flags().StringVarP(&datasetSize, "dataset-size", "", "large", "Specify the type of dataset size to use, small|large")
+	rootCmd.Flags().StringVarP(&cqlFeatures, "cql-features", "", "normal", "Specify the type of cql features to use, basic|normal|all")
+	rootCmd.Flags().BoolVarP(&useMaterializedViews, "materialized-views", "", false, "Run gemini with materialized views support")
+	rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal")
+	rootCmd.Flags().IntVarP(&maxRetriesMutate, "max-mutation-retries", "", 10, "Maximum number of attempts to apply a mutation")
"max-mutation-retries", "", 10, "Maximum number of attempts to apply a mutation") + rootCmd.Flags().DurationVarP( + &maxRetriesMutateSleep, "max-mutation-retries-backoff", "", 10*time.Millisecond, + "Duration between attempts to apply a mutation for example 10ms or 1s") + rootCmd.Flags().Uint64VarP(&pkBufferReuseSize, "partition-key-buffer-reuse-size", "", 1000, "Number of reused buffered partition keys") + rootCmd.Flags().Uint64VarP(&partitionCount, "token-range-slices", "", 1000, "Number of slices to divide the token space into") + rootCmd.Flags().StringVarP( + &partitionKeyDistribution, "partition-key-distribution", "", "uniform", + "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf") + rootCmd.Flags().Float64VarP(&normalDistMean, "normal-dist-mean", "", generators.StdDistMean, "Mean of the normal distribution") + rootCmd.Flags().Float64VarP(&normalDistSigma, "normal-dist-sigma", "", generators.OneStdDev, "Sigma of the normal distribution, defaults to one standard deviation ~0.341") + rootCmd.Flags().StringVarP( + &tracingOutFile, "tracing-outfile", "", "", + "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.") + rootCmd.Flags().BoolVarP(&useCounters, "use-counters", "", false, "Ensure that at least one table is a counter table") + rootCmd.Flags().IntVarP( + &asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10, + "Maximum number of attempts to validate result sets from MV and SI") + rootCmd.Flags().DurationVarP( + &asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond, + "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s") + rootCmd.Flags().BoolVarP(&useLWT, "use-lwt", "", false, "Emit LWT based updates") + rootCmd.Flags().StringVarP( + &oracleClusterHostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware", + "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware") + rootCmd.Flags().StringVarP( + &testClusterHostSelectionPolicy, "test-host-selection-policy", "", "token-aware", + "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware") + rootCmd.Flags().BoolVarP(&useServerSideTimestamps, "use-server-timestamps", "", false, "Use server-side generated timestamps for writes") + rootCmd.Flags().DurationVarP(&requestTimeout, "request-timeout", "", 30*time.Second, "Duration of waiting request execution") + rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Duration of waiting connection established") + rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'") + rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end") + rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write statements flow to") + rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write statements flow to") +} diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index ed4d4e2c..4c3a7e8c 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -15,31 +15,18 @@ package main import ( - "encoding/json" "fmt" "log" - "math" "net/http" "net/http/pprof" "os" + "os/signal" "strconv" 
"strings" + "syscall" "text/tabwriter" "time" - "github.com/scylladb/gemini/pkg/auth" - "github.com/scylladb/gemini/pkg/builders" - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/jobs" - "github.com/scylladb/gemini/pkg/realrandom" - "github.com/scylladb/gemini/pkg/replication" - "github.com/scylladb/gemini/pkg/store" - "github.com/scylladb/gemini/pkg/typedef" - "github.com/scylladb/gemini/pkg/utils" - - "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" - "github.com/gocql/gocql" "github.com/hailocab/go-hostpool" "github.com/pkg/errors" @@ -49,94 +36,25 @@ import ( "go.uber.org/zap/zapcore" "golang.org/x/exp/rand" "golang.org/x/net/context" - "gonum.org/v1/gonum/stat/distuv" -) -var ( - testClusterHost []string - testClusterUsername string - testClusterPassword string - oracleClusterHost []string - oracleClusterUsername string - oracleClusterPassword string - schemaFile string - outFileArg string - concurrency uint64 - seed string - schemaSeed string - dropSchema bool - verbose bool - mode string - failFast bool - nonInteractive bool - duration time.Duration - bind string - warmup time.Duration - replicationStrategy string - tableOptions []string - oracleReplicationStrategy string - consistency string - maxTables int - maxPartitionKeys int - minPartitionKeys int - maxClusteringKeys int - minClusteringKeys int - maxColumns int - minColumns int - datasetSize string - cqlFeatures string - useMaterializedViews bool - level string - maxRetriesMutate int - maxRetriesMutateSleep time.Duration - maxErrorsToStore int - pkBufferReuseSize uint64 - partitionCount uint64 - partitionKeyDistribution string - normalDistMean float64 - normalDistSigma float64 - tracingOutFile string - useCounters bool - asyncObjectStabilizationAttempts int - asyncObjectStabilizationDelay time.Duration - useLWT bool - testClusterHostSelectionPolicy string - oracleClusterHostSelectionPolicy string - useServerSideTimestamps bool - requestTimeout time.Duration - connectTimeout time.Duration - profilingPort int - testStatementLogFile string - oracleStatementLogFile string + "github.com/scylladb/gemini/pkg/auth" + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/jobs" + "github.com/scylladb/gemini/pkg/realrandom" + "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/typedef" + "github.com/scylladb/gemini/pkg/utils" + + "github.com/scylladb/gemini/pkg/status" ) func interactive() bool { return !nonInteractive } -func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Schema, error) { - byteValue, err := os.ReadFile(confFile) - if err != nil { - return nil, err - } - - var shm typedef.Schema - - err = json.Unmarshal(byteValue, &shm) - if err != nil { - return nil, err - } - - schemaBuilder := builders.NewSchemaBuilder() - schemaBuilder.Keyspace(shm.Keyspace).Config(schemaConfig) - for t, tbl := range shm.Tables { - shm.Tables[t].LinkIndexAndColumns() - schemaBuilder.Table(tbl) - } - return schemaBuilder.Build(), nil -} - func run(_ *cobra.Command, _ []string) error { + start := time.Now() + logger := createLogger(level) globalStatus := status.NewGlobalStatus(1000) defer utils.IgnoreError(logger.Sync) @@ -153,11 +71,7 @@ func run(_ *cobra.Command, _ []string) error { rand.Seed(intSeed) - cons, err := gocql.ParseConsistencyWrapper(consistency) - if err != nil { - logger.Error("Unable parse consistency, error=%s. 
Falling back on Quorum", zap.Error(err)) - cons = gocql.Quorum - } + cons := gocql.ParseConsistency(consistency) testHostSelectionPolicy, err := getHostSelectionPolicy(testClusterHostSelectionPolicy, testClusterHost) if err != nil { @@ -168,15 +82,23 @@ func run(_ *cobra.Command, _ []string) error { return err } + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGABRT, syscall.SIGTERM) + defer cancel() + go func() { - http.Handle("/metrics", promhttp.Handler()) - _ = http.ListenAndServe(bind, nil) + mux := http.NewServeMux() + mux.Handle("GET /metrics", promhttp.Handler()) + log.Fatal(http.ListenAndServe(bind, mux)) }() if profilingPort != 0 { go func() { mux := http.NewServeMux() - mux.HandleFunc("/profile", pprof.Profile) + mux.HandleFunc("GET /debug/pprof/", pprof.Index) + mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile) + mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace) log.Fatal(http.ListenAndServe(":"+strconv.Itoa(profilingPort), mux)) }() } @@ -204,10 +126,7 @@ func run(_ *cobra.Command, _ []string) error { } } - jsonSchema, _ := json.MarshalIndent(schema, "", " ") - - printSetup(intSeed, intSchemaSeed) - fmt.Printf("Schema: %v\n", string(jsonSchema)) + printSetup(intSeed, intSchemaSeed, mode) testCluster, oracleCluster := createClusters(cons, testHostSelectionPolicy, oracleHostSelectionPolicy, logger) storeConfig := store.Config{ @@ -230,7 +149,10 @@ func run(_ *cobra.Command, _ []string) error { return ioErr } tracingFile = tf - defer utils.IgnoreError(tracingFile.Sync) + defer func() { + utils.IgnoreError(tracingFile.Sync) + utils.IgnoreError(tracingFile.Close) + }() } } st, err := store.New(schema, testCluster, oracleCluster, storeConfig, tracingFile, logger) @@ -250,39 +172,44 @@ func run(_ *cobra.Command, _ []string) error { testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema) if err = st.Create( - context.Background(), + ctx, typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType), typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil { return errors.Wrap(err, "unable to create keyspace") } - for _, stmt := range generators.GetCreateSchema(schema) { + for _, stmt := range generators.GetCreateSchema(schema, useMaterializedViews) { logger.Debug(stmt) - if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.CreateSchemaStatementType)); err != nil { + if err = st.Mutate(ctx, typedef.SimpleStmt(stmt, typedef.CreateSchemaStatementType)); err != nil { return errors.Wrap(err, "unable to create schema") } } - ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2) - stopFlag := stop.NewFlag("main") - warmupStopFlag := stop.NewFlag("warmup") - stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag) - pump := jobs.NewPump(stopFlag, logger) + distFunc, err := generators.ParseDistributionDefault(partitionKeyDistribution, partitionCount, intSeed) + if err != nil { + return err + } - gens, err := createGenerators(schema, schemaConfig, intSeed, partitionCount, logger) + gens, err := generators.New(ctx, schema, schemaConfig, intSeed, partitionCount, logger, distFunc, pkBufferReuseSize) if err != nil { return err } - gens.StartAll(stopFlag) + + defer utils.IgnoreError(gens.Close) + + ctx, done := context.WithTimeout(ctx, duration+warmup+10*time.Second) + defer done() if !nonInteractive { sp := 
createSpinner(interactive()) ticker := time.NewTicker(time.Second) + go func() { - defer done() + defer ticker.Stop() + for { select { - case <-stopFlag.SignalChannel(): + case <-ctx.Done(): return case <-ticker.C: sp.Set(" Running Gemini... %v", globalStatus) @@ -291,68 +218,32 @@ func run(_ *cobra.Command, _ []string) error { }() } - if warmup > 0 && !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, warmupStopFlag, failFast, verbose); err != nil { - logger.Error("warmup encountered an error", zap.Error(err)) - stopFlag.SetHard(true) - } + jobsList := jobs.New(mode, duration, concurrency, logger, schema, st, globalStatus, intSeed, gens, warmup, failFast) + if err = jobsList.Run(ctx); err != nil { + logger.Error("error detected", zap.Error(err)) } - if !stopFlag.IsHardOrSoft() { - jobsList := jobs.ListFromMode(mode, duration, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil { - logger.Debug("error detected", zap.Error(err)) - } - } logger.Info("test finished") - globalStatus.PrintResult(outFile, schema, version) + globalStatus.PrintResult(outFile, schema, version, start) + if globalStatus.HasErrors() { return errors.Errorf("gemini encountered errors, exiting with non zero status") } + return nil } -func createFile(fname string, def *os.File) (*os.File, error) { - if fname != "" { - f, err := os.Create(fname) +func createFile(name string, def *os.File) (*os.File, error) { + if name != "" { + f, err := os.Create(name) if err != nil { - return nil, errors.Wrapf(err, "Unable to open output file %s", fname) + return nil, errors.Wrapf(err, "Unable to open output file %s", name) } + return f, nil } - return def, nil -} -const ( - stdDistMean = math.MaxUint64 / 2 - oneStdDev = 0.341 * math.MaxUint64 -) - -func createDistributionFunc(distribution string, size, seed uint64, mu, sigma float64) (generators.DistributionFunc, error) { - switch strings.ToLower(distribution) { - case "zipf": - dist := rand.NewZipf(rand.New(rand.NewSource(seed)), 1.1, 1.1, size) - return func() generators.TokenIndex { - return generators.TokenIndex(dist.Uint64()) - }, nil - case "normal": - dist := distuv.Normal{ - Src: rand.NewSource(seed), - Mu: mu, - Sigma: sigma, - } - return func() generators.TokenIndex { - return generators.TokenIndex(dist.Rand()) - }, nil - case "uniform": - rnd := rand.New(rand.NewSource(seed)) - return func() generators.TokenIndex { - return generators.TokenIndex(rnd.Uint64n(size)) - }, nil - default: - return nil, errors.Errorf("unsupported distribution: %s", distribution) - } + return def, nil } func createLogger(level string) *zap.Logger { @@ -407,35 +298,8 @@ func createClusters( return testCluster, oracleCluster } -func getReplicationStrategy(rs string, fallback *replication.Replication, logger *zap.Logger) *replication.Replication { - switch rs { - case "network": - return replication.NewNetworkTopologyStrategy() - case "simple": - return replication.NewSimpleStrategy() - default: - replicationStrategy := &replication.Replication{} - if err := json.Unmarshal([]byte(strings.ReplaceAll(rs, "'", "\"")), replicationStrategy); err != nil { - logger.Error("unable to parse replication strategy", zap.String("strategy", rs), zap.Error(err)) - return fallback - } - return replicationStrategy - } -} - -func getCQLFeature(feature 
string) typedef.CQLFeature { - switch strings.ToLower(feature) { - case "all": - return typedef.CQL_FEATURE_ALL - case "normal": - return typedef.CQL_FEATURE_NORMAL - default: - return typedef.CQL_FEATURE_BASIC - } -} - func getHostSelectionPolicy(policy string, hosts []string) (gocql.HostSelectionPolicy, error) { - switch policy { + switch strings.ToLower(policy) { case "round-robin": return gocql.RoundRobinHostPolicy(), nil case "host-pool": @@ -454,89 +318,7 @@ var rootCmd = &cobra.Command{ SilenceUsage: true, } -func init() { - rootCmd.Version = version + ", commit " + commit + ", date " + date - rootCmd.Flags().StringSliceVarP(&testClusterHost, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test") - _ = rootCmd.MarkFlagRequired("test-cluster") - rootCmd.Flags().StringVarP(&testClusterUsername, "test-username", "", "", "Username for the test cluster") - rootCmd.Flags().StringVarP(&testClusterPassword, "test-password", "", "", "Password for the test cluster") - rootCmd.Flags().StringSliceVarP( - &oracleClusterHost, "oracle-cluster", "o", []string{}, - "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used") - rootCmd.Flags().StringVarP(&oracleClusterUsername, "oracle-username", "", "", "Username for the oracle cluster") - rootCmd.Flags().StringVarP(&oracleClusterPassword, "oracle-password", "", "", "Password for the oracle cluster") - rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file") - rootCmd.Flags().StringVarP(&mode, "mode", "m", jobs.MixedMode, "Query operation mode. Mode options: write, read, mixed (default)") - rootCmd.Flags().Uint64VarP(&concurrency, "concurrency", "c", 10, "Number of threads per table to run concurrently") - rootCmd.Flags().StringVarP(&seed, "seed", "s", "random", "Statement seed value") - rootCmd.Flags().StringVarP(&schemaSeed, "schema-seed", "", "random", "Schema seed value") - rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", false, "Drop schema before starting tests run") - rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output during test run") - rootCmd.Flags().BoolVarP(&failFast, "fail-fast", "f", false, "Stop on the first failure") - rootCmd.Flags().BoolVarP(&nonInteractive, "non-interactive", "", false, "Run in non-interactive mode (disable progress indicator)") - rootCmd.Flags().DurationVarP(&duration, "duration", "", 30*time.Second, "") - rootCmd.Flags().StringVarP(&outFileArg, "outfile", "", "", "Specify the name of the file where the results should go") - rootCmd.Flags().StringVarP(&bind, "bind", "b", ":2112", "Specify the interface and port which to bind prometheus metrics on. 
Default is ':2112'") - rootCmd.Flags().DurationVarP(&warmup, "warmup", "", 30*time.Second, "Specify the warmup perid as a duration for example 30s or 10h") - rootCmd.Flags().StringVarP( - &replicationStrategy, "replication-strategy", "", "simple", - "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide "+ - "the entire specification in the form {'class':'....'}") - rootCmd.Flags().StringVarP( - &oracleReplicationStrategy, "oracle-replication-strategy", "", "simple", - "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each "+ - "type or provide the entire specification in the form {'class':'....'}") - rootCmd.Flags().StringArrayVarP(&tableOptions, "table-options", "", []string{}, "Repeatable argument to set table options to be added to the created tables") - rootCmd.Flags().StringVarP(&consistency, "consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE") - rootCmd.Flags().IntVarP(&maxTables, "max-tables", "", 1, "Maximum number of generated tables") - rootCmd.Flags().IntVarP(&maxPartitionKeys, "max-partition-keys", "", 6, "Maximum number of generated partition keys") - rootCmd.Flags().IntVarP(&minPartitionKeys, "min-partition-keys", "", 2, "Minimum number of generated partition keys") - rootCmd.Flags().IntVarP(&maxClusteringKeys, "max-clustering-keys", "", 4, "Maximum number of generated clustering keys") - rootCmd.Flags().IntVarP(&minClusteringKeys, "min-clustering-keys", "", 2, "Minimum number of generated clustering keys") - rootCmd.Flags().IntVarP(&maxColumns, "max-columns", "", 16, "Maximum number of generated columns") - rootCmd.Flags().IntVarP(&minColumns, "min-columns", "", 8, "Minimum number of generated columns") - rootCmd.Flags().StringVarP(&datasetSize, "dataset-size", "", "large", "Specify the type of dataset size to use, small|large") - rootCmd.Flags().StringVarP(&cqlFeatures, "cql-features", "", "basic", "Specify the type of cql features to use, basic|normal|all") - rootCmd.Flags().BoolVarP(&useMaterializedViews, "materialized-views", "", false, "Run gemini with materialized views support") - rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal") - rootCmd.Flags().IntVarP(&maxRetriesMutate, "max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation") - rootCmd.Flags().DurationVarP( - &maxRetriesMutateSleep, "max-mutation-retries-backoff", "", 10*time.Millisecond, - "Duration between attempts to apply a mutation for example 10ms or 1s") - rootCmd.Flags().Uint64VarP(&pkBufferReuseSize, "partition-key-buffer-reuse-size", "", 100, "Number of reused buffered partition keys") - rootCmd.Flags().Uint64VarP(&partitionCount, "token-range-slices", "", 10000, "Number of slices to divide the token space into") - rootCmd.Flags().StringVarP( - &partitionKeyDistribution, "partition-key-distribution", "", "uniform", - "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf") - rootCmd.Flags().Float64VarP(&normalDistMean, "normal-dist-mean", "", stdDistMean, "Mean of the normal distribution") - rootCmd.Flags().Float64VarP(&normalDistSigma, "normal-dist-sigma", "", oneStdDev, "Sigma of the normal distribution, defaults to one standard deviation ~0.341") - rootCmd.Flags().StringVarP( - &tracingOutFile, 
"tracing-outfile", "", "", - "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.") - rootCmd.Flags().BoolVarP(&useCounters, "use-counters", "", false, "Ensure that at least one table is a counter table") - rootCmd.Flags().IntVarP( - &asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10, - "Maximum number of attempts to validate result sets from MV and SI") - rootCmd.Flags().DurationVarP( - &asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond, - "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s") - rootCmd.Flags().BoolVarP(&useLWT, "use-lwt", "", false, "Emit LWT based updates") - rootCmd.Flags().StringVarP( - &oracleClusterHostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware", - "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware") - rootCmd.Flags().StringVarP( - &testClusterHostSelectionPolicy, "test-host-selection-policy", "", "token-aware", - "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware") - rootCmd.Flags().BoolVarP(&useServerSideTimestamps, "use-server-timestamps", "", false, "Use server-side generated timestamps for writes") - rootCmd.Flags().DurationVarP(&requestTimeout, "request-timeout", "", 30*time.Second, "Duration of waiting request execution") - rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Duration of waiting connection established") - rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'") - rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end") - rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write statements flow to") - rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write statements flow to") -} - -func printSetup(seed, schemaSeed uint64) { +func printSetup(seed, schemaSeed uint64, mode string) { tw := new(tabwriter.Writer) tw.Init(os.Stdout, 0, 8, 2, '\t', tabwriter.AlignRight) fmt.Fprintf(tw, "Seed:\t%d\n", seed) @@ -546,6 +328,7 @@ func printSetup(seed, schemaSeed uint64) { fmt.Fprintf(tw, "Concurrency:\t%d\n", concurrency) fmt.Fprintf(tw, "Test cluster:\t%s\n", testClusterHost) fmt.Fprintf(tw, "Oracle cluster:\t%s\n", oracleClusterHost) + fmt.Fprintf(tw, "Mode:\t%s\n", mode) if outFileArg == "" { fmt.Fprintf(tw, "Output file:\t%s\n", "") } else { diff --git a/cmd/gemini/schema.go b/cmd/gemini/schema.go index 68884a62..1f8a5e42 100644 --- a/cmd/gemini/schema.go +++ b/cmd/gemini/schema.go @@ -15,13 +15,25 @@ package main import ( + "encoding/json" + "os" "strings" + "go.uber.org/zap" + + "github.com/scylladb/gemini/pkg/builders" "github.com/scylladb/gemini/pkg/replication" "github.com/scylladb/gemini/pkg/tableopts" "github.com/scylladb/gemini/pkg/typedef" +) - "go.uber.org/zap" +const ( + MaxBlobLength = 1e4 + MinBlobLength = 0 + MaxStringLength = 1000 + MinStringLength = 0 + MaxTupleParts = 20 + MaxUDTParts = 20 ) func createSchemaConfig(logger *zap.Logger) typedef.SchemaConfig { @@ -43,11 +55,13 @@ func createSchemaConfig(logger *zap.Logger) typedef.SchemaConfig { MaxTupleParts: 2, MaxBlobLength: 20, MaxStringLength: 20, + MinBlobLength: 0, + 
MinStringLength: 0, UseCounters: defaultConfig.UseCounters, UseLWT: defaultConfig.UseLWT, + UseMaterializedViews: defaultConfig.UseMaterializedViews, CQLFeature: defaultConfig.CQLFeature, AsyncObjectStabilizationAttempts: defaultConfig.AsyncObjectStabilizationAttempts, - UseMaterializedViews: defaultConfig.UseMaterializedViews, AsyncObjectStabilizationDelay: defaultConfig.AsyncObjectStabilizationDelay, } default: @@ -56,19 +70,9 @@ func createSchemaConfig(logger *zap.Logger) typedef.SchemaConfig { } func createDefaultSchemaConfig(logger *zap.Logger) typedef.SchemaConfig { - const ( - MaxBlobLength = 1e4 - MinBlobLength = 0 - MaxStringLength = 1000 - MinStringLength = 0 - MaxTupleParts = 20 - MaxUDTParts = 20 - ) - rs := getReplicationStrategy(replicationStrategy, replication.NewSimpleStrategy(), logger) - ors := getReplicationStrategy(oracleReplicationStrategy, rs, logger) return typedef.SchemaConfig{ - ReplicationStrategy: rs, - OracleReplicationStrategy: ors, + ReplicationStrategy: replication.MustParseReplication(replicationStrategy), + OracleReplicationStrategy: replication.MustParseReplication(oracleReplicationStrategy), TableOptions: tableopts.CreateTableOptions(tableOptions, logger), MaxTables: maxTables, MaxPartitionKeys: maxPartitionKeys, @@ -85,9 +89,31 @@ func createDefaultSchemaConfig(logger *zap.Logger) typedef.SchemaConfig { MinStringLength: MinStringLength, UseCounters: useCounters, UseLWT: useLWT, - CQLFeature: getCQLFeature(cqlFeatures), + CQLFeature: typedef.ParseCQLFeature(cqlFeatures), UseMaterializedViews: useMaterializedViews, AsyncObjectStabilizationAttempts: asyncObjectStabilizationAttempts, AsyncObjectStabilizationDelay: asyncObjectStabilizationDelay, } } + +func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Schema, error) { + byteValue, err := os.ReadFile(confFile) + if err != nil { + return nil, err + } + + var shm typedef.Schema + + err = json.Unmarshal(byteValue, &shm) + if err != nil { + return nil, err + } + + schemaBuilder := builders.NewSchemaBuilder() + schemaBuilder.Keyspace(shm.Keyspace).Config(schemaConfig) + for t, tbl := range shm.Tables { + shm.Tables[t].LinkIndexAndColumns() + schemaBuilder.Table(tbl) + } + return schemaBuilder.Build(), nil +} diff --git a/cmd/gemini/strategies_test.go b/cmd/gemini/strategies_test.go index 6859024e..29d0e7e8 100644 --- a/cmd/gemini/strategies_test.go +++ b/cmd/gemini/strategies_test.go @@ -20,46 +20,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "go.uber.org/zap" - - "github.com/scylladb/gemini/pkg/replication" "github.com/scylladb/gemini/pkg/typedef" ) -func TestGetReplicationStrategy(t *testing.T) { - tests := map[string]struct { - strategy string - expected string - }{ - "simple strategy": { - strategy: "{\"class\": \"SimpleStrategy\", \"replication_factor\": \"1\"}", - expected: "{'class':'SimpleStrategy','replication_factor':'1'}", - }, - "simple strategy single quotes": { - strategy: "{'class': 'SimpleStrategy', 'replication_factor': '1'}", - expected: "{'class':'SimpleStrategy','replication_factor':'1'}", - }, - "network topology strategy": { - strategy: "{\"class\": \"NetworkTopologyStrategy\", \"dc1\": 3, \"dc2\": 3}", - expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}", - }, - "network topology strategy single quotes": { - strategy: "{'class': 'NetworkTopologyStrategy', 'dc1': 3, 'dc2': 3}", - expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}", - }, - } - logger := zap.NewNop() - fallback := 
replication.NewSimpleStrategy() - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - got := getReplicationStrategy(tc.strategy, fallback, logger) - if diff := cmp.Diff(got.ToCQL(), tc.expected); diff != "" { - t.Errorf("expected=%s, got=%s,diff=%s", tc.strategy, got.ToCQL(), diff) - } - }) - } -} - // TestReadExampleSchema main task of this test to be sure that schema example (schema.json) is correct and have correct marshal, unmarshal func TestReadExampleSchema(t *testing.T) { filePath := "schema.json" diff --git a/docker/docker-compose-scylla.yml b/docker/docker-compose-scylla.yml index d9e266f9..521c3161 100644 --- a/docker/docker-compose-scylla.yml +++ b/docker/docker-compose-scylla.yml @@ -21,7 +21,7 @@ services: image: scylladb/scylla:${SCYLLA_TEST_VERSION:-6.2} container_name: gemini-test restart: unless-stopped - command: --smp 2 --memory 1024M --api-address 0.0.0.0 + command: --smp 1 --memory 1024M --api-address 0.0.0.0 networks: gemini: ipv4_address: 192.168.100.3 diff --git a/go.mod b/go.mod index 4b5efcad..c3499609 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f golang.org/x/net v0.31.0 - golang.org/x/sync v0.9.0 gonum.org/v1/gonum v0.15.1 gopkg.in/inf.v0 v0.9.1 ) @@ -42,4 +41,4 @@ require ( google.golang.org/protobuf v1.35.2 // indirect ) -replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.3 +replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.4 diff --git a/go.sum b/go.sum index 8a28967a..b15a1666 100644 --- a/go.sum +++ b/go.sum @@ -65,8 +65,8 @@ github.com/scylladb/go-reflectx v1.0.1 h1:b917wZM7189pZdlND9PbIJ6NQxfDPfBvUaQ7cj github.com/scylladb/go-reflectx v1.0.1/go.mod h1:rWnOfDIRWBGN0miMLIcoPt/Dhi2doCMZqwMCJ3KupFc= github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE= github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs= -github.com/scylladb/gocql v1.14.3 h1:f6ZFxM9plyAk0h7NZcXfZ1aJu3cGk0Mjy/X293gqIFA= -github.com/scylladb/gocql v1.14.3/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0= +github.com/scylladb/gocql v1.14.4 h1:MhevwCfyAraQ6RvZYFO3pF4Lt0YhvQlfg8Eo2HEqVQA= +github.com/scylladb/gocql v1.14.4/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0= github.com/scylladb/gocqlx/v2 v2.8.0 h1:f/oIgoEPjKDKd+RIoeHqexsIQVIbalVmT+axwvUqQUg= github.com/scylladb/gocqlx/v2 v2.8.0/go.mod h1:4/+cga34PVqjhgSoo5Nr2fX1MQIqZB5eCE5DK4xeDig= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= diff --git a/pkg/jobs/pump.go b/pkg/burst/pump.go similarity index 52% rename from pkg/jobs/pump.go rename to pkg/burst/pump.go index c929f8ce..9cfbc61b 100644 --- a/pkg/jobs/pump.go +++ b/pkg/burst/pump.go @@ -12,40 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. 
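+
+// Package burst provides the pump channel that paces job execution by
+// occasionally injecting a short sleep between operations.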
-package jobs
+package burst

 import (
+	"context"
+	"math/rand/v2"
 	"time"
+)

-	"github.com/scylladb/gemini/pkg/stop"
+// ChannelSize is the buffer size of the pump channel.
+const ChannelSize = 10000

-	"go.uber.org/zap"
-	"golang.org/x/exp/rand"
-)
+// work feeds the pump channel until the context is cancelled; roughly one in
+// every 'chance' emitted values equals sleepDuration, the rest are zero.
+func work(ctx context.Context, pump chan<- time.Duration, chance int, sleepDuration time.Duration) {
+	defer close(pump)
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+			sleep := time.Duration(0)

-func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration {
-	pump := make(chan time.Duration, 10000)
-	logger = logger.Named("Pump")
-	go func() {
-		logger.Debug("pump channel opened")
-		defer func() {
-			close(pump)
-			logger.Debug("pump channel closed")
-		}()
-		for !stopFlag.IsHardOrSoft() {
-			pump <- newHeartBeat()
-		}
-	}()
+			if rand.Int()%chance == 0 {
+				sleep = sleepDuration
+			}

-	return pump
+			pump <- sleep
+		}
+	}
 }

-func newHeartBeat() time.Duration {
-	r := rand.Intn(10)
-	switch r {
-	case 0:
-		return 10 * time.Millisecond
-	default:
-		return 0
-	}
+// New returns a buffered channel of sleep durations fed by a background goroutine.
+func New(ctx context.Context, chance int, sleepDuration time.Duration) chan time.Duration {
+	pump := make(chan time.Duration, ChannelSize)
+	go work(ctx, pump, chance, sleepDuration)
+	return pump
 }
diff --git a/pkg/generators/distribution.go b/pkg/generators/distribution.go
new file mode 100644
index 00000000..0d9c8d49
--- /dev/null
+++ b/pkg/generators/distribution.go
@@ -0,0 +1,59 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package generators
+
+import (
+	"math"
+	"strings"
+
+	"github.com/pkg/errors"
+	"golang.org/x/exp/rand"
+	"gonum.org/v1/gonum/stat/distuv"
+)
+
+// StdDistMean and OneStdDev are the default mean and standard deviation
+// for distributions drawn over the uint64 token space.
+const (
+	StdDistMean = math.MaxUint64 / 2
+	OneStdDev   = 0.341 * math.MaxUint64
+)
+
+// ParseDistributionDefault calls ParseDistribution with the default mean and sigma.
+func ParseDistributionDefault(distribution string, size, seed uint64) (DistributionFunc, error) {
+	return ParseDistribution(distribution, size, seed, StdDistMean, OneStdDev)
+}
+
+// ParseDistribution returns a DistributionFunc that draws TokenIndex values
+// from the named distribution: uniform, normal or zipf.
+func ParseDistribution(distribution string, size, seed uint64, mu, sigma float64) (DistributionFunc, error) {
+	switch strings.ToLower(distribution) {
+	case "zipf":
+		dist := rand.NewZipf(rand.New(rand.NewSource(seed)), 1.1, 1.1, size)
+		return func() TokenIndex {
+			return TokenIndex(dist.Uint64())
+		}, nil
+	case "normal":
+		dist := distuv.Normal{
+			Src:   rand.NewSource(seed),
+			Mu:    mu,
+			Sigma: sigma,
+		}
+		return func() TokenIndex {
+			return TokenIndex(dist.Rand())
+		}, nil
+	case "uniform":
+		rnd := rand.New(rand.NewSource(seed))
+		return func() TokenIndex {
+			return TokenIndex(rnd.Uint64n(size))
+		}, nil
+	default:
+		return nil, errors.Errorf("unsupported distribution: %s", distribution)
+	}
+}
diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go
index 500fc759..f3c1e8f5 100644
--- a/pkg/generators/generator.go
+++ b/pkg/generators/generator.go
@@ -15,12 +15,14 @@ package generators

 import (
+	"context"
+	"sync/atomic"
+
 	"github.com/pkg/errors"
 	"go.uber.org/zap"
 	"golang.org/x/exp/rand"

 	"github.com/scylladb/gemini/pkg/routingkey"
-	"github.com/scylladb/gemini/pkg/stop"
 	"github.com/scylladb/gemini/pkg/typedef"
 )
@@ -37,14 +39,15 @@ type TokenIndex uint64

 type DistributionFunc func() TokenIndex

-type GeneratorInterface interface {
+type Interface interface {
 	Get() *typedef.ValueWithToken
 	GetOld() *typedef.ValueWithToken
-	GiveOld(_ *typedef.ValueWithToken)
-	GiveOlds(_ []*typedef.ValueWithToken)
-	ReleaseToken(_ uint64)
+	GiveOld(...*typedef.ValueWithToken)
+	ReleaseToken(uint64)
 }

+var _ Interface = &Generator{}
+
 type Generator struct {
 	logger            *zap.Logger
 	table             *typedef.Table
@@ -56,22 +59,14 @@ type Generator struct {
 	partitionsConfig typedef.PartitionRangeConfig
 	partitionCount   uint64

-	cntCreated uint64
-	cntEmitted uint64
+	cntCreated atomic.Uint64
+	cntEmitted atomic.Uint64
 }

 func (g *Generator) PartitionCount() uint64 {
 	return g.partitionCount
 }

-type Generators []*Generator
-
-func (g Generators) StartAll(stopFlag *stop.Flag) {
-	for _, gen := range g {
-		gen.Start(stopFlag)
-	}
-}
-
 type Config struct {
 	PartitionsDistributionFunc DistributionFunc
 	PartitionsRangeConfig      typedef.PartitionRangeConfig
@@ -80,8 +75,9 @@ type Config struct {
 	PkUsedBufferSize uint64
 }

-func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator {
-	wakeUpSignal := make(chan struct{})
+func NewGenerator(table *typedef.Table, config Config, logger *zap.Logger) *Generator {
+	wakeUpSignal := make(chan struct{}, int(config.PartitionsCount))
+
 	return &Generator{
 		partitions:     NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal),
 		partitionCount: config.PartitionsCount,
@@ -96,12 +92,18 @@ func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Gen
 }

 func (g *Generator) Get() *typedef.ValueWithToken {
-	targetPart := g.GetPartitionForToken(g.idxFunc())
-	for targetPart.Stale() {
-		targetPart = g.GetPartitionForToken(g.idxFunc())
+	var out *typedef.ValueWithToken
+
+	for out == nil {
+		targetPart := g.GetPartitionForToken(g.idxFunc())
+		for targetPart.Stale() {
+			targetPart = g.GetPartitionForToken(g.idxFunc())
+		}
+		out = targetPart.get()
 	}
-	out := targetPart.get()
+
 	return out
 }

 func (g *Generator) GetPartitionForToken(token TokenIndex) *Partition {
@@ -111,22 +113,23 @@ func (g *Generator) GetPartitionForToken(token TokenIndex) *Partition {

 // GetOld returns a previously used value and token or a new if
 // the old queue is empty.
 func (g *Generator) GetOld() *typedef.ValueWithToken {
-	targetPart := g.GetPartitionForToken(g.idxFunc())
-	for targetPart.Stale() {
-		targetPart = g.GetPartitionForToken(g.idxFunc())
+	var out *typedef.ValueWithToken
+
+	for out == nil {
+		targetPart := g.GetPartitionForToken(g.idxFunc())
+		for targetPart.Stale() {
+			targetPart = g.GetPartitionForToken(g.idxFunc())
+		}
+		out = targetPart.getOld()
 	}
-	return targetPart.getOld()
-}

-// GiveOld returns the supplied value for later reuse unless
-func (g *Generator) GiveOld(v *typedef.ValueWithToken) {
-	g.GetPartitionForToken(TokenIndex(v.Token)).giveOld(v)
+	return out
 }

-// GiveOlds returns the supplied values for later reuse unless
-func (g *Generator) GiveOlds(tokens []*typedef.ValueWithToken) {
+// GiveOld returns the supplied values for later reuse.
+func (g *Generator) GiveOld(tokens ...*typedef.ValueWithToken) {
 	for _, token := range tokens {
-		g.GiveOld(token)
+		g.GetPartitionForToken(TokenIndex(token.Token)).giveOld(token)
 	}
 }

@@ -135,22 +138,27 @@ func (g *Generator) ReleaseToken(token uint64) {
 	g.GetPartitionForToken(TokenIndex(token)).releaseToken(token)
 }

-func (g *Generator) Start(stopFlag *stop.Flag) {
-	go func() {
-		g.logger.Info("starting partition key generation loop")
-		defer g.partitions.CloseAll()
-		for {
-			g.fillAllPartitions(stopFlag)
-			select {
-			case <-stopFlag.SignalChannel():
-				g.logger.Debug("stopping partition key generation loop",
-					zap.Uint64("keys_created", g.cntCreated),
-					zap.Uint64("keys_emitted", g.cntEmitted))
-				return
-			case <-g.wakeUpSignal:
-			}
+func (g *Generator) start(ctx context.Context) {
+	g.logger.Info("starting partition key generation loop")
+	defer func() {
+		g.logger.Info("stopping partition key generation loop",
+			zap.Uint64("keys_created", g.cntCreated.Load()),
+			zap.Uint64("keys_emitted", g.cntEmitted.Load()),
+		)
+
+		if err := g.partitions.Close(); err != nil {
+			g.logger.Error("failed to close partitions", zap.Error(err))
 		}
 	}()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-g.wakeUpSignal:
+			g.fillAllPartitions(ctx)
+		}
+	}
 }

 func (g *Generator) FindAndMarkStalePartitions() {
@@ -172,45 +180,41 @@ func (g *Generator) FindAndMarkStalePartitions() {
 	}
 }

+// fillPartition generates a single partition key, hashes it to pick the
+// target partition, and offers the value unless the partition is stale,
+// the token is already in flight, or the partition buffer is full.
+// TODO: be smarter about how partitions are filled.
+func (g *Generator) fillPartition() {
+	values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig)
+	token, err := g.routingKeyCreator.GetHash(g.table, values)
+	if err != nil {
+		g.logger.Panic("failed to get primary key hash", zap.Error(err))
+	}
+	g.cntCreated.Add(1)
+	idx := token % g.partitionCount
+	partition := g.partitions[idx]
+	if partition.Stale() || partition.inFlight.Has(token) {
+		return
+	}
+	select {
+	case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}:
+		g.cntEmitted.Add(1)
+	default:
+	}
+}
+
 // fillAllPartitions guarantees that each partition was tested to be full
 // at least once since the function started and before it ended.
 // In other words no partition will be starved.
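+// NOTE: the context-driven rewrite below no longer provides that guarantee;
+// it simply keeps generating keys until the context is cancelled.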
-func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) { - pFilled := make([]bool, len(g.partitions)) - allFilled := func() bool { - for idx, filled := range pFilled { - if !filled { - if g.partitions[idx].Stale() { - continue - } - return false - } - } - return true - } - for !stopFlag.IsHardOrSoft() { - values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig) - token, err := g.routingKeyCreator.GetHash(g.table, values) - if err != nil { - g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error()) - } - g.cntCreated++ - idx := token % g.partitionCount - partition := g.partitions[idx] - if partition.Stale() || partition.inFlight.Has(token) { - continue - } +func (g *Generator) fillAllPartitions(ctx context.Context) { + for { select { - case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}: - g.cntEmitted++ + case <-ctx.Done(): + return default: - if !pFilled[idx] { - pFilled[idx] = true - if allFilled() { - return - } - } } + + g.fillPartition() } } diff --git a/pkg/generators/generator_test.go b/pkg/generators/generator_test.go index 3a46551c..26240ae9 100644 --- a/pkg/generators/generator_test.go +++ b/pkg/generators/generator_test.go @@ -15,24 +15,28 @@ package generators_test import ( + "context" "sync/atomic" "testing" "go.uber.org/zap" "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" ) func TestGenerator(t *testing.T) { t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + table := &typedef.Table{ Name: "tbl", PartitionKeys: generators.CreatePkColumns(1, "pk"), } var current uint64 - cfg := &generators.Config{ + cfg := generators.Config{ PartitionsRangeConfig: typedef.PartitionRangeConfig{ MaxStringLength: 10, MinStringLength: 0, @@ -47,7 +51,7 @@ func TestGenerator(t *testing.T) { } logger, _ := zap.NewDevelopment() generator := generators.NewGenerator(table, cfg, logger) - generator.Start(stop.NewFlag("main_test")) + generator.start(ctx) for i := uint64(0); i < cfg.PartitionsCount; i++ { atomic.StoreUint64(¤t, i) v := generator.Get() diff --git a/cmd/gemini/generators.go b/pkg/generators/generators.go similarity index 56% rename from cmd/gemini/generators.go rename to pkg/generators/generators.go index 9b4f51de..f6e6323a 100644 --- a/cmd/gemini/generators.go +++ b/pkg/generators/generators.go @@ -11,47 +11,79 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-package main + +package generators import ( - "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/typedef" + "context" + "sync" "go.uber.org/zap" + + "github.com/scylladb/gemini/pkg/typedef" ) -func createGenerators( +type Generators struct { + wg sync.WaitGroup + generators []*Generator + cancel context.CancelFunc + idx int +} + +func (g *Generators) Get() *Generator { + gen := g.generators[g.idx%len(g.generators)] + g.idx++ + return gen +} + +func (g *Generators) Close() error { + g.cancel() + g.wg.Wait() + + return nil +} + +func New( + ctx context.Context, schema *typedef.Schema, schemaConfig typedef.SchemaConfig, - seed, distributionSize uint64, + seed, partitionsCount uint64, logger *zap.Logger, -) (generators.Generators, error) { + distFunc DistributionFunc, + pkBufferReuseSize uint64, +) (*Generators, error) { partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() + ctx, cancel := context.WithCancel(ctx) - var gs []*generators.Generator - for id := range schema.Tables { - table := schema.Tables[id] - pkVariations := table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) + gens := &Generators{ + generators: make([]*Generator, 0, len(schema.Tables)), + cancel: cancel, + } + gens.wg.Add(len(schema.Tables)) - distFunc, err := createDistributionFunc(partitionKeyDistribution, distributionSize, seed, stdDistMean, oneStdDev) - if err != nil { - return nil, err - } + for _, table := range schema.Tables { + pkVariations := table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig) - tablePartConfig := &generators.Config{ + tablePartConfig := Config{ PartitionsRangeConfig: partitionRangeConfig, - PartitionsCount: distributionSize, + PartitionsCount: partitionsCount, PartitionsDistributionFunc: distFunc, Seed: seed, PkUsedBufferSize: pkBufferReuseSize, } - g := generators.NewGenerator(table, tablePartConfig, logger.Named("generators")) + g := NewGenerator(table, tablePartConfig, logger.Named("generators")) + go func() { + defer gens.wg.Done() + g.start(ctx) + }() if pkVariations < 2^32 { // Low partition key variation can lead to having staled partitions // Let's detect and mark them before running test g.FindAndMarkStalePartitions() } - gs = append(gs, g) + + gens.generators = append(gens.generators, g) } - return gs, nil + + return gens, nil } diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index e70d46c2..dd6c3817 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -15,7 +15,7 @@ package generators import ( - "sync" + "sync/atomic" "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/typedef" @@ -26,29 +26,36 @@ type Partition struct { oldValues chan *typedef.ValueWithToken inFlight inflight.InFlight wakeUpSignal chan<- struct{} // wakes up generator - closed bool - lock sync.RWMutex - isStale bool + closed atomic.Bool + isStale atomic.Bool +} + +func NewPartition(wakeUpSignal chan<- struct{}, pkBufferSize int) *Partition { + return &Partition{ + values: make(chan *typedef.ValueWithToken, pkBufferSize), + oldValues: make(chan *typedef.ValueWithToken, pkBufferSize), + inFlight: inflight.New(), + wakeUpSignal: wakeUpSignal, + } } func (s *Partition) MarkStale() { - s.isStale = true + s.isStale.Store(true) s.Close() } func (s *Partition) Stale() bool { - return s.isStale + return s.isStale.Load() } // get returns a new value and ensures that it's corresponding token // is not already in-flight. 
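+// The rewritten version below is non-blocking: it returns nil when the
+// channel is empty or the token is already in flight.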
func (s *Partition) get() *typedef.ValueWithToken { - for { - v := s.pick() - if v == nil || s.inFlight.AddIfNotPresent(v.Token) { - return v - } + if v := s.pick(); v != nil && !s.inFlight.Has(v.Token) { + return v } + + return nil } // getOld returns a previously used value and token or a new if @@ -66,12 +73,12 @@ func (s *Partition) getOld() *typedef.ValueWithToken { // is empty in which case it removes the corresponding token from the // in-flight tracking. func (s *Partition) giveOld(v *typedef.ValueWithToken) { - ch := s.safelyGetOldValuesChannel() - if ch == nil { + if s.closed.Load() { return } + select { - case ch <- v: + case s.oldValues <- v: default: // Old partition buffer is full, just drop the value } @@ -98,55 +105,15 @@ func (s *Partition) pick() *typedef.ValueWithToken { return val default: s.wakeUp() // channel empty, need to wait for new values - return <-s.values - } -} - -func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { - s.lock.RLock() - if s.closed { - // Since only giveOld could have been potentially called after partition is closed - // we need to protect it against writing to closed channel return nil } - defer s.lock.RUnlock() - return s.oldValues } -func (s *Partition) Close() { - s.lock.RLock() - if s.closed { - s.lock.RUnlock() - return - } - s.lock.RUnlock() - s.lock.Lock() - if s.closed { - return - } - s.closed = true - close(s.values) - close(s.oldValues) - s.lock.Unlock() -} - -type Partitions []*Partition - -func (p Partitions) CloseAll() { - for _, part := range p { - part.Close() +func (s *Partition) Close() error { + if !s.closed.Swap(true) { + close(s.values) + close(s.oldValues) } -} -func NewPartitions(count, pkBufferSize int, wakeUpSignal chan struct{}) Partitions { - partitions := make(Partitions, count) - for i := 0; i < len(partitions); i++ { - partitions[i] = &Partition{ - values: make(chan *typedef.ValueWithToken, pkBufferSize), - oldValues: make(chan *typedef.ValueWithToken, pkBufferSize), - inFlight: inflight.New(), - wakeUpSignal: wakeUpSignal, - } - } - return partitions + return nil } diff --git a/pkg/generators/partitions.go b/pkg/generators/partitions.go new file mode 100644 index 00000000..b0eec362 --- /dev/null +++ b/pkg/generators/partitions.go @@ -0,0 +1,25 @@ +package generators + +import "go.uber.org/multierr" + +type Partitions []*Partition + +func (p Partitions) Close() error { + var err error + + for _, part := range p { + err = multierr.Append(err, part.Close()) + } + + return err +} + +func NewPartitions(count, pkBufferSize int, wakeUpSignal chan<- struct{}) Partitions { + partitions := make(Partitions, 0, count) + + for i := 0; i < count; i++ { + partitions = append(partitions, NewPartition(wakeUpSignal, pkBufferSize)) + } + + return partitions +} diff --git a/pkg/generators/statement_generator.go b/pkg/generators/statement_generator.go index 25503a89..902af751 100644 --- a/pkg/generators/statement_generator.go +++ b/pkg/generators/statement_generator.go @@ -87,7 +87,7 @@ func genTable(sc typedef.SchemaConfig, tableName string, r *rand.Rand) *typedef. 
table.Indexes = indexes var mvs []typedef.MaterializedView - if sc.CQLFeature > typedef.CQL_FEATURE_BASIC && sc.UseMaterializedViews && len(clusteringKeys) > 0 && columns.ValidColumnsForPrimaryKey().Len() != 0 { + if sc.UseMaterializedViews && sc.CQLFeature > typedef.CQL_FEATURE_BASIC && len(clusteringKeys) > 0 && columns.ValidColumnsForPrimaryKey().Len() != 0 { mvs = CreateMaterializedViews(columns, table.Name, partitionKeys, clusteringKeys, r) } @@ -101,7 +101,7 @@ func GetCreateKeyspaces(s *typedef.Schema) (string, string) { fmt.Sprintf("CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = %s", s.Keyspace.Name, s.Keyspace.OracleReplication.ToCQL()) } -func GetCreateSchema(s *typedef.Schema) []string { +func GetCreateSchema(s *typedef.Schema, enableMV bool) []string { var stmts []string for _, t := range s.Tables { @@ -124,17 +124,21 @@ func GetCreateSchema(s *typedef.Schema) []string { for _, ck := range mv.ClusteringKeys { mvPrimaryKeysNotNull = append(mvPrimaryKeysNotNull, fmt.Sprintf("%s IS NOT NULL", ck.Name)) } - var createMaterializedView string - if len(mv.PartitionKeys) == 1 { - createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY (%s" - } else { - createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY ((%s)" + + if enableMV { + var createMaterializedView string + if len(mv.PartitionKeys) == 1 { + createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY (%s" + } else { + createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY ((%s)" + } + createMaterializedView += ",%s)" + stmts = append(stmts, fmt.Sprintf(createMaterializedView, + s.Keyspace.Name, mv.Name, s.Keyspace.Name, t.Name, + strings.Join(mvPrimaryKeysNotNull, " AND "), + strings.Join(mvPartitionKeys, ","), strings.Join(t.ClusteringKeys.Names(), ",")), + ) } - createMaterializedView += ",%s)" - stmts = append(stmts, fmt.Sprintf(createMaterializedView, - s.Keyspace.Name, mv.Name, s.Keyspace.Name, t.Name, - strings.Join(mvPrimaryKeysNotNull, " AND "), - strings.Join(mvPartitionKeys, ","), strings.Join(t.ClusteringKeys.Names(), ","))) } } return stmts diff --git a/pkg/jobs/gen_check_stmt.go b/pkg/generators/statements/gen_check_stmt.go similarity index 77% rename from pkg/jobs/gen_check_stmt.go rename to pkg/generators/statements/gen_check_stmt.go index 3822b409..9dea12a3 100644 --- a/pkg/jobs/gen_check_stmt.go +++ b/pkg/generators/statements/gen_check_stmt.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
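The enableMV flag threaded into GetCreateSchema above only gates whether the CREATE MATERIALIZED VIEW statements are emitted; the two format-string branches themselves are unchanged. A small sketch of the CQL the single-partition-key branch produces, using hypothetical keyspace, table, and key names:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical names; the real values come from the generated schema.
	ks, mv, tbl := "ks1", "table1_mv_0", "table1"
	notNull := []string{"pk0 IS NOT NULL", "ck0 IS NOT NULL"}
	mvPartitionKeys := []string{"pk0"}
	clusteringKeys := []string{"ck0"}

	// Single-partition-key branch, mirroring the format string above.
	createMaterializedView := "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY (%s"
	createMaterializedView += ",%s)"

	fmt.Printf(createMaterializedView+"\n",
		ks, mv, ks, tbl,
		strings.Join(notNull, " AND "),
		strings.Join(mvPartitionKeys, ","),
		strings.Join(clusteringKeys, ","))
	// Prints:
	// CREATE MATERIALIZED VIEW IF NOT EXISTS ks1.table1_mv_0 AS SELECT * FROM ks1.table1 WHERE pk0 IS NOT NULL AND ck0 IS NOT NULL PRIMARY KEY (pk0,ck0)
}

The multi-key branch differs only in wrapping the partition keys in an extra pair of parentheses, producing PRIMARY KEY ((pk0,pk1),ck0).
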
-package jobs
+package statements

 import (
 	"math"
@@ -28,94 +28,105 @@ import (
 func GenCheckStmt(
 	s *typedef.Schema,
 	table *typedef.Table,
-	g generators.GeneratorInterface,
+	g generators.Interface,
 	rnd *rand.Rand,
 	p *typedef.PartitionRangeConfig,
-) *typedef.Stmt {
-	n := 0
-	mvNum := -1
-	maxClusteringRels := 0
-	numQueryPKs := 0
-	if len(table.MaterializedViews) > 0 && rnd.Int()%2 == 0 {
-		mvNum = utils.RandInt2(rnd, 0, len(table.MaterializedViews))
-	}
-
-	switch mvNum {
-	case -1:
-		if len(table.Indexes) > 0 {
-			n = rnd.Intn(5)
-		} else {
-			n = rnd.Intn(4)
-		}
-		switch n {
-		case 0:
-			return genSinglePartitionQuery(s, table, g)
-		case 1:
-			numQueryPKs = utils.RandInt2(rnd, 1, table.PartitionKeys.Len())
-			multiplier := int(math.Pow(float64(numQueryPKs), float64(table.PartitionKeys.Len())))
-			if multiplier > 100 {
-				numQueryPKs = 1
-			}
-			return genMultiplePartitionQuery(s, table, g, numQueryPKs)
-		case 2:
-			maxClusteringRels = utils.RandInt2(rnd, 0, table.ClusteringKeys.Len())
-			return genClusteringRangeQuery(s, table, g, rnd, p, maxClusteringRels)
-		case 3:
-			numQueryPKs = utils.RandInt2(rnd, 1, table.PartitionKeys.Len())
-			multiplier := int(math.Pow(float64(numQueryPKs), float64(table.PartitionKeys.Len())))
-			if multiplier > 100 {
-				numQueryPKs = 1
-			}
-			maxClusteringRels = utils.RandInt2(rnd, 0, table.ClusteringKeys.Len())
-			return genMultiplePartitionClusteringRangeQuery(s, table, g, rnd, p, numQueryPKs, maxClusteringRels)
-		case 4:
-			// Reducing the probability to hit these since they often take a long time to run
-			switch rnd.Intn(5) {
-			case 0:
-				idxCount := utils.RandInt2(rnd, 1, len(table.Indexes))
-				return genSingleIndexQuery(s, table, g, rnd, p, idxCount)
-			default:
-				return genSinglePartitionQuery(s, table, g)
-			}
-		}
-	default:
-		n = rnd.Intn(4)
-		switch n {
-		case 0:
-			return genSinglePartitionQueryMv(s, table, g, rnd, p, mvNum)
-		case 1:
-			lenPartitionKeys := table.MaterializedViews[mvNum].PartitionKeys.Len()
-			numQueryPKs = utils.RandInt2(rnd, 1, lenPartitionKeys)
-			multiplier := int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys)))
-			if multiplier > 100 {
-				numQueryPKs = 1
-			}
-			return genMultiplePartitionQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs)
-		case 2:
-			lenClusteringKeys := table.MaterializedViews[mvNum].ClusteringKeys.Len()
-			maxClusteringRels = utils.RandInt2(rnd, 0, lenClusteringKeys)
-			return genClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, maxClusteringRels)
-		case 3:
-			lenPartitionKeys := table.MaterializedViews[mvNum].PartitionKeys.Len()
-			numQueryPKs = utils.RandInt2(rnd, 1, lenPartitionKeys)
-			multiplier := int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys)))
-			if multiplier > 100 {
-				numQueryPKs = 1
+) (*typedef.Stmt, func()) {
+	var stmt *typedef.Stmt
+
+	if shouldGenerateCheckStatementForMV(table, rnd) {
+		stmt = genCheckStmtMV(s, table, g, rnd, p)
+	} else {
+		stmt = genCheckTableStmt(s, table, g, rnd, p)
+	}
+
+	return stmt, func() {
+		if stmt != nil && stmt.ValuesWithToken != nil {
+			for _, v := range stmt.ValuesWithToken {
+				g.ReleaseToken(v.Token)
 			}
-			lenClusteringKeys := table.MaterializedViews[mvNum].ClusteringKeys.Len()
-			maxClusteringRels = utils.RandInt2(rnd, 0, lenClusteringKeys)
-			return genMultiplePartitionClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs, maxClusteringRels)
 		}
 	}
+}
+
+// shouldGenerateCheckStatementForMV returns true when the table has materialized
+// views and the random number is even, i.e. a 50% chance of
+// checking materialized views.
+func shouldGenerateCheckStatementForMV(table *typedef.Table, rnd *rand.Rand) bool {
+	return len(table.MaterializedViews) > 0 && rnd.Int()%2 == 0
+}

-	return nil
+func genCheckStmtMV(s *typedef.Schema, table *typedef.Table, g generators.Interface, rnd *rand.Rand, p *typedef.PartitionRangeConfig) *typedef.Stmt {
+	mvNum := utils.RandInt2(rnd, 0, len(table.MaterializedViews))
+	lenClusteringKeys := table.MaterializedViews[mvNum].ClusteringKeys.Len()
+	lenPartitionKeys := table.MaterializedViews[mvNum].PartitionKeys.Len()
+
+	maxClusteringRels := utils.RandInt2(rnd, 0, lenClusteringKeys)
+	numQueryPKs := utils.RandInt2(rnd, 1, lenPartitionKeys)
+	if int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys))) > 100 {
+		numQueryPKs = 1
+	}
+
+	switch rnd.Intn(4) {
+	case 0:
+		return genSinglePartitionQueryMv(s, table, g, rnd, p, mvNum)
+	case 1:
+		return genMultiplePartitionQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs)
+	case 2:
+		return genClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, maxClusteringRels)
+	case 3:
+		return genMultiplePartitionClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs, maxClusteringRels)
+	default:
+		panic("random number generator does not work correctly, unreachable statement")
+	}
 }

-func genSinglePartitionQuery(
+func genCheckTableStmt(
 	s *typedef.Schema,
-	t *typedef.Table,
-	g generators.GeneratorInterface,
+	table *typedef.Table,
+	g generators.Interface,
+	rnd *rand.Rand,
+	p *typedef.PartitionRangeConfig,
 ) *typedef.Stmt {
+	var n int
+
+	if len(table.Indexes) > 0 {
+		n = rnd.Intn(5)
+	} else {
+		n = rnd.Intn(4)
+	}
+
+	maxClusteringRels := utils.RandInt2(rnd, 0, table.ClusteringKeys.Len())
+	numQueryPKs := utils.RandInt2(rnd, 1, table.PartitionKeys.Len())
+	multiplier := int(math.Pow(float64(numQueryPKs), float64(table.PartitionKeys.Len())))
+	if multiplier > 100 {
+		numQueryPKs = 1
+	}
+
+	switch n {
+	case 0:
+		return genSinglePartitionQuery(s, table, g)
+	case 1:
+		return genMultiplePartitionQuery(s, table, g, numQueryPKs)
+	case 2:
+		return genClusteringRangeQuery(s, table, g, rnd, p, maxClusteringRels)
+	case 3:
+		return genMultiplePartitionClusteringRangeQuery(s, table, g, rnd, p, numQueryPKs, maxClusteringRels)
+	case 4:
+		// Reduce the probability of hitting index queries, since they often
+		// take a long time to run: only a one-in-five chance lands here
+		if rnd.Intn(5) == 0 {
+			idxCount := utils.RandInt2(rnd, 1, len(table.Indexes))
+			return genSingleIndexQuery(s, table, g, rnd, p, idxCount)
+		}

+		return genSinglePartitionQuery(s, table, g)
+	default:
+		panic("random number generator does not work correctly, unreachable statement")
+	}
+}
+
+func genSinglePartitionQuery(s *typedef.Schema, t *typedef.Table, g generators.Interface) *typedef.Stmt {
 	t.RLock()
 	defer t.RUnlock()
 	valuesWithToken := g.GetOld()
@@ -124,7 +135,8 @@ func genSinglePartitionQuery(
 	}
 	values := valuesWithToken.Value.Copy()
 	builder := qb.Select(s.Keyspace.Name + "."
+ t.Name) - typs := make([]typedef.Type, 0, 10) + typs := make([]typedef.Type, 0, len(t.PartitionKeys)) + for _, pk := range t.PartitionKeys { builder = builder.Where(qb.Eq(pk.Name)) typs = append(typs, pk.Type) @@ -144,7 +156,7 @@ func genSinglePartitionQuery( func genSinglePartitionQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum int, @@ -183,7 +195,7 @@ func genSinglePartitionQueryMv( func genMultiplePartitionQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, numQueryPKs int, ) *typedef.Stmt { t.RLock() @@ -197,7 +209,7 @@ func genMultiplePartitionQuery( for j := 0; j < numQueryPKs; j++ { vs := g.GetOld() if vs == nil { - g.GiveOlds(tokens) + g.GiveOld(tokens...) return nil } tokens = append(tokens, vs) @@ -223,7 +235,7 @@ func genMultiplePartitionQuery( func genMultiplePartitionQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, numQueryPKs int, @@ -241,7 +253,7 @@ func genMultiplePartitionQueryMv( for j := 0; j < numQueryPKs; j++ { vs := g.GetOld() if vs == nil { - g.GiveOlds(tokens) + g.GiveOld(tokens...) return nil } tokens = append(tokens, vs) @@ -274,7 +286,7 @@ func genMultiplePartitionQueryMv( func genClusteringRangeQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, maxClusteringRels int, @@ -321,7 +333,7 @@ func genClusteringRangeQuery( func genClusteringRangeQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, maxClusteringRels int, @@ -374,7 +386,7 @@ func genClusteringRangeQueryMv( func genMultiplePartitionClusteringRangeQuery( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, numQueryPKs, maxClusteringRels int, @@ -397,7 +409,7 @@ func genMultiplePartitionClusteringRangeQuery( for j := 0; j < numQueryPKs; j++ { vs := g.GetOld() if vs == nil { - g.GiveOlds(tokens) + g.GiveOld(tokens...) return nil } tokens = append(tokens, vs) @@ -435,7 +447,7 @@ func genMultiplePartitionClusteringRangeQuery( func genMultiplePartitionClusteringRangeQueryMv( s *typedef.Schema, t *typedef.Table, - g generators.GeneratorInterface, + g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, mvNum, numQueryPKs, maxClusteringRels int, @@ -478,7 +490,7 @@ func genMultiplePartitionClusteringRangeQueryMv( for j := 0; j < numQueryPKs; j++ { vs := g.GetOld() if vs == nil { - g.GiveOlds(tokens) + g.GiveOld(tokens...) return nil } tokens = append(tokens, vs) @@ -516,7 +528,7 @@ func genMultiplePartitionClusteringRangeQueryMv( func genSingleIndexQuery( s *typedef.Schema, t *typedef.Table, - _ generators.GeneratorInterface, + _ generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, idxCount int, diff --git a/pkg/generators/statements/gen_check_stmt_test.go b/pkg/generators/statements/gen_check_stmt_test.go new file mode 100644 index 00000000..55b28468 --- /dev/null +++ b/pkg/generators/statements/gen_check_stmt_test.go @@ -0,0 +1,332 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:thelper +package statements + +import ( + "path" + "testing" + + "github.com/scylladb/gemini/pkg/testutils" +) + +const checkDataPath = "./test_expected_data/check/" + +func TestGenSinglePartitionQuery(t *testing.T) { + RunStmtTest[results](t, path.Join(checkDataPath, "single_partition.json"), genSinglePartitionQueryCases, + func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { + schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) + stmt := genSinglePartitionQuery(schema, schema.Tables[0], gen) + validateStmt(subT, stmt, nil) + expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) + }) +} + +// func TestGenSinglePartitionQueryMv(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "single_partition_mv.json"), genSinglePartitionQueryMvCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// stmt := genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenMultiplePartitionQuery(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition.json"), genMultiplePartitionQueryCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) +// options := testutils.GetOptionsFromCaseName(caseName) +// stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenMultiplePartitionQueryMv(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_mv.json"), genMultiplePartitionQueryMvCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) +// stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenClusteringRangeQuery(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range.json"), genClusteringRangeQueryCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// options := testutils.GetOptionsFromCaseName(caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// stmt := genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, 
len(schema.Tables[0].ClusteringKeys)-1)) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenClusteringRangeQueryMv(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range_mv.json"), genClusteringRangeQueryMvCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// options := testutils.GetOptionsFromCaseName(caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// stmt := genClusteringRangeQueryMv( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// len(schema.Tables[0].MaterializedViews)-1, +// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenMultiplePartitionClusteringRangeQuery(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range.json"), genMultiplePartitionClusteringRangeQueryCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// options := testutils.GetOptionsFromCaseName(caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// stmt := genMultiplePartitionClusteringRangeQuery( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), +// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenMultiplePartitionClusteringRangeQueryMv(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range_mv.json"), genMultiplePartitionClusteringRangeQueryMvCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// stmt := genMultiplePartitionClusteringRangeQueryMv( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// len(schema.Tables[0].MaterializedViews)-1, +// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), +// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func TestGenSingleIndexQuery(t *testing.T) { +// RunStmtTest[results](t, path.Join(checkDataPath, "single_index.json"), genSingleIndexQueryCases, +// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// stmt := genSingleIndexQuery(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].Indexes)) +// validateStmt(subT, stmt, nil) +// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) +// }) +// } + +// func BenchmarkGenSinglePartitionQuery(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genSinglePartitionQueryCases { +// caseName := genSinglePartitionQueryCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { 
+// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genSinglePartitionQuery(schema, schema.Tables[0], gen) +// } +// }) +// } +// } + +// func BenchmarkGenSinglePartitionQueryMv(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genSinglePartitionQueryMvCases { +// caseName := genSinglePartitionQueryMvCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1) +// } +// }) +// } +// } + +// func BenchmarkGenMultiplePartitionQuery(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genMultiplePartitionQueryCases { +// caseName := genMultiplePartitionQueryCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) +// } +// }) +// } +// } + +// func BenchmarkGenMultiplePartitionQueryMv(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genMultiplePartitionQueryMvCases { +// caseName := genMultiplePartitionQueryMvCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genMultiplePartitionQueryMv( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// len(schema.Tables[0].MaterializedViews)-1, +// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) +// } +// }) +// } +// } + +// func BenchmarkGenClusteringRangeQuery(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genClusteringRangeQueryCases { +// caseName := genClusteringRangeQueryCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// } +// }) +// } +// } + +// func BenchmarkGenClusteringRangeQueryMv(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genClusteringRangeQueryMvCases { +// caseName := genClusteringRangeQueryMvCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genClusteringRangeQueryMv( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// len(schema.Tables[0].MaterializedViews)-1, +// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// } +// }) +// } +// } + +// func BenchmarkGenMultiplePartitionClusteringRangeQuery(t *testing.B) { +// utils.SetUnderTest() +// for idx := range 
genMultiplePartitionClusteringRangeQueryCases { +// caseName := genMultiplePartitionClusteringRangeQueryCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genMultiplePartitionClusteringRangeQuery( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), +// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// } +// }) +// } +// } + +// func BenchmarkGenMultiplePartitionClusteringRangeQueryMv(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genMultiplePartitionClusteringRangeQueryMvCases { +// caseName := genMultiplePartitionClusteringRangeQueryMvCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// options := testutils.GetOptionsFromCaseName(caseName) +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genMultiplePartitionClusteringRangeQueryMv( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// len(schema.Tables[0].MaterializedViews)-1, +// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), +// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) +// } +// }) +// } +// } + +// func BenchmarkGenSingleIndexQuery(t *testing.B) { +// utils.SetUnderTest() +// for idx := range genSingleIndexQueryCases { +// caseName := genSingleIndexQueryCases[idx] +// t.Run(caseName, +// func(subT *testing.B) { +// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) +// prc := schema.Config.GetPartitionRangeConfig() +// subT.ResetTimer() +// for x := 0; x < subT.N; x++ { +// _ = genSingleIndexQuery( +// schema, +// schema.Tables[0], +// gen, +// rnd, +// &prc, +// len(schema.Tables[0].Indexes)) +// } +// }) +// } +// } diff --git a/pkg/jobs/gen_const_test.go b/pkg/generators/statements/gen_const_test.go similarity index 99% rename from pkg/jobs/gen_const_test.go rename to pkg/generators/statements/gen_const_test.go index 1cf5922c..f602f3ef 100644 --- a/pkg/jobs/gen_const_test.go +++ b/pkg/generators/statements/gen_const_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package jobs +package statements var ( genInsertStmtCases = []string{ diff --git a/pkg/jobs/gen_ddl_stmt.go b/pkg/generators/statements/gen_ddl_stmt.go similarity index 96% rename from pkg/jobs/gen_ddl_stmt.go rename to pkg/generators/statements/gen_ddl_stmt.go index a8866fd4..9ec0ec08 100644 --- a/pkg/jobs/gen_ddl_stmt.go +++ b/pkg/generators/statements/gen_ddl_stmt.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
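One guard from the relocated check-statement generators deserves spelling out before the DDL hunk below: numQueryPKs is reset to 1 whenever numQueryPKs raised to the number of partition keys would exceed 100 key combinations, which bounds the cost of multi-partition queries. A minimal sketch of that guard in isolation, with two worked values:

package main

import (
	"fmt"
	"math"
)

// capQueryPKs mirrors the multiplier guard used by genCheckStmtMV and
// genCheckTableStmt: when numQueryPKs**lenPartitionKeys would push the
// number of key combinations past 100, fall back to a single partition key.
func capQueryPKs(numQueryPKs, lenPartitionKeys int) int {
	if int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys))) > 100 {
		return 1
	}
	return numQueryPKs
}

func main() {
	fmt.Println(capQueryPKs(3, 3)) // 27 combinations  -> keeps 3
	fmt.Println(capQueryPKs(4, 4)) // 256 combinations -> capped to 1
}
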
-package jobs +package statements import ( "fmt" @@ -25,7 +25,7 @@ import ( "github.com/scylladb/gemini/pkg/typedef" ) -func GenDDLStmt(s *typedef.Schema, t *typedef.Table, r *rand.Rand, _ *typedef.PartitionRangeConfig, sc *typedef.SchemaConfig) (*typedef.Stmts, error) { +func GenDDLStmt(s *typedef.Schema, t *typedef.Table, r *rand.Rand, sc *typedef.SchemaConfig) (*typedef.Stmts, error) { maxVariant := 1 validCols := t.ValidColumnsForDelete() if validCols.Len() > 0 { diff --git a/pkg/jobs/gen_ddl_stmt_test.go b/pkg/generators/statements/gen_ddl_stmt_test.go similarity index 99% rename from pkg/jobs/gen_ddl_stmt_test.go rename to pkg/generators/statements/gen_ddl_stmt_test.go index e5f81a73..6643a49c 100644 --- a/pkg/jobs/gen_ddl_stmt_test.go +++ b/pkg/generators/statements/gen_ddl_stmt_test.go @@ -13,7 +13,7 @@ // limitations under the License. //nolint:thelper -package jobs +package statements import ( "fmt" diff --git a/pkg/jobs/gen_mutate_stmt.go b/pkg/generators/statements/gen_mutate_stmt.go similarity index 97% rename from pkg/jobs/gen_mutate_stmt.go rename to pkg/generators/statements/gen_mutate_stmt.go index 6f65caab..f727da77 100644 --- a/pkg/jobs/gen_mutate_stmt.go +++ b/pkg/generators/statements/gen_mutate_stmt.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package jobs +package statements import ( "encoding/json" @@ -26,7 +26,7 @@ import ( "github.com/scylladb/gemini/pkg/typedef" ) -func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.GeneratorInterface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) { +func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) { t.RLock() defer t.RUnlock() diff --git a/pkg/jobs/gen_mutate_stmt_test.go b/pkg/generators/statements/gen_mutate_stmt_test.go similarity index 99% rename from pkg/jobs/gen_mutate_stmt_test.go rename to pkg/generators/statements/gen_mutate_stmt_test.go index 006f5329..1fc1e9e2 100644 --- a/pkg/jobs/gen_mutate_stmt_test.go +++ b/pkg/generators/statements/gen_mutate_stmt_test.go @@ -13,7 +13,7 @@ // limitations under the License. //nolint:thelper -package jobs +package statements import ( "path" diff --git a/pkg/generators/statements/gen_utils_test.go b/pkg/generators/statements/gen_utils_test.go new file mode 100644 index 00000000..bbce4bbb --- /dev/null +++ b/pkg/generators/statements/gen_utils_test.go @@ -0,0 +1,267 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
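The helpers in the new gen_utils_test.go below drive the statement generators through generators.Interface. Reconstructed from the call sites visible in this diff (Get, GetOld, the now-variadic GiveOld, and the ReleaseToken used by GenCheckStmt's cleanup closure), the contract looks roughly like this; the authoritative definition lives in pkg/generators and may differ in detail:

package sketch

import "github.com/scylladb/gemini/pkg/typedef"

// Interface as it appears from the call sites in this diff; illustrative only.
type Interface interface {
	// Get hands out a fresh partition key together with its token.
	Get() *typedef.ValueWithToken
	// GetOld returns a previously used key, or nil when none is available yet.
	GetOld() *typedef.ValueWithToken
	// GiveOld returns keys that ended up unused; note the variadic
	// signature replacing the old GiveOlds(slice) helper.
	GiveOld(...*typedef.ValueWithToken)
	// ReleaseToken drops a token from in-flight tracking, as done by the
	// cleanup closure returned from GenCheckStmt.
	ReleaseToken(uint64)
}
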
+
+package statements
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/scylladb/gemini/pkg/testutils"
+	"github.com/scylladb/gemini/pkg/typedef"
+	"github.com/scylladb/gemini/pkg/utils"
+)
+
+type resultToken struct {
+	Token       string
+	TokenValues string
+}
+
+func (r resultToken) Equal(received resultToken) bool {
+	return r.Token == received.Token && r.TokenValues == received.TokenValues
+}
+
+type resultTokens []resultToken
+
+func (r resultTokens) Equal(received resultTokens) bool {
+	if len(r) != len(received) {
+		return false
+	}
+	for id, expectedToken := range r {
+		if !expectedToken.Equal(received[id]) {
+			return false
+		}
+	}
+	return true
+}
+
+func (r resultTokens) Diff(received resultTokens) string {
+	var out []string
+	maxIdx := len(r)
+	if maxIdx < len(received) {
+		maxIdx = len(received)
+	}
+	var expected, found *resultToken
+	for idx := 0; idx < maxIdx; idx++ {
+		if idx < len(r) {
+			expected = &r[idx]
+		} else {
+			expected = &resultToken{}
+		}
+
+		if idx < len(received) {
+			found = &received[idx]
+		} else {
+			found = &resultToken{}
+		}
+
+		out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+			expected.Token, found.Token, " error: value stmt.ValuesWithToken.Token expected and received are different:"))
+		out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+			expected.TokenValues, found.TokenValues, " error: value stmt.ValuesWithToken.Value expected and received are different:"))
+	}
+	return strings.Join(out, "\n")
+}
+
+type result struct {
+	Query       string
+	Names       string
+	Values      string
+	Types       string
+	QueryType   string
+	TokenValues resultTokens
+}
+
+func (r *result) Equal(t *result) bool {
+	var expected result
+	if r != nil {
+		expected = *r
+	}
+
+	var provided result
+	if t != nil {
+		provided = *t
+	}
+	return expected.Query == provided.Query &&
+		expected.Names == provided.Names &&
+		expected.Values == provided.Values &&
+		expected.Types == provided.Types &&
+		expected.QueryType == provided.QueryType &&
+		expected.TokenValues.Equal(provided.TokenValues)
+}
+
+func (r *result) Diff(received *result) string {
+	var out []string
+	out = testutils.AppendIfNotEmpty(out, r.TokenValues.Diff(received.TokenValues))
+	out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+		r.Query, received.Query, " error: value stmt.Query.ToCql().stmt expected and received are different:"))
+	out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+		r.Names, received.Names, " error: value stmt.Query.ToCql().Names expected and received are different:"))
+	out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+		r.Values, received.Values, " error: value stmt.Values expected and received are different:"))
+	out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+		r.Types, received.Types, " error: value stmt.Types expected and received are different:"))
+	out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+		r.QueryType, received.QueryType, " error: value stmt.QueryType expected and received are different:"))
+	return strings.Join(out, "\n")
+}
+
+type results []*result
+
+func (r results) Equal(t results) bool {
+	return r.Diff(t) == ""
+}
+
+func (r results) Diff(t results) string {
+	var out []string
+	maxIdx := len(r)
+	if maxIdx < len(t) {
+		maxIdx = len(t)
+	}
+	var
expected, found *result
+	for idx := 0; idx < maxIdx; idx++ {
+		if idx < len(r) {
+			expected = r[idx]
+		} else {
+			expected = &result{}
+		}
+
+		if idx < len(t) {
+			found = t[idx]
+		} else {
+			found = &result{}
+		}
+
+		out = testutils.AppendIfNotEmpty(out, expected.Diff(found))
+	}
+	return strings.Join(out, "\n")
+}
+
+func convertStmtsToResults(stmt any) results {
+	var out results
+	switch stmts := stmt.(type) {
+	case *typedef.Stmts:
+		for idx := range stmts.List {
+			out = append(out, convertStmtToResults(stmts.List[idx]))
+		}
+	case *typedef.Stmt:
+		out = append(out, convertStmtToResults(stmts))
+	}
+	return out
+}
+
+func convertStmtToResults(stmt *typedef.Stmt) *result {
+	types := ""
+	for idx := range stmt.Types {
+		types = fmt.Sprintf("%s %s", types, stmt.Types[idx].Name())
+	}
+	query, names := stmt.Query.ToCql()
+	var tokens []resultToken
+	for _, valueToken := range stmt.ValuesWithToken {
+		tokens = append(tokens, resultToken{
+			Token:       fmt.Sprintf("%v", valueToken.Token),
+			TokenValues: strings.TrimSpace(fmt.Sprintf("%v", valueToken.Value)),
+		})
+	}
+
+	return &result{
+		TokenValues: tokens,
+		Query:       strings.TrimSpace(query),
+		Names:       strings.TrimSpace(fmt.Sprintf("%s", names)),
+		Values:      strings.TrimSpace(fmt.Sprintf("%v", stmt.Values)),
+		Types:       types,
+		QueryType:   fmt.Sprintf("%v", stmt.QueryType),
+	}
+}
+
+func RunStmtTest[T testutils.ExpectedEntry[T]](
+	t *testing.T,
+	filePath string,
+	cases []string,
+	testBody func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[T]),
+) {
+	t.Helper()
+	utils.SetUnderTest()
+	t.Parallel()
+	expected := testutils.LoadExpectedFromFile[T](t, filePath, cases, *testutils.UpdateExpectedFlag)
+	if *testutils.UpdateExpectedFlag {
+		t.Cleanup(func() {
+			expected.UpdateExpected(t)
+		})
+	}
+	for idx := range cases {
+		caseName := cases[idx]
+		t.Run(caseName,
+			func(subT *testing.T) {
+				subT.Parallel()
+				testBody(subT, caseName, expected)
+			})
+	}
+}
+
+func GetPkCountFromOptions(options testutils.TestCaseOptions, allValue int) int {
+	pkCount := 0
+	options.HandleOption("cpk", func(option string) {
+		switch option {
+		case "cpkAll":
+			pkCount = allValue
+		case "cpk1":
+			pkCount = 1
+		}
+	})
+	return pkCount
+}
+
+func GetCkCountFromOptions(options testutils.TestCaseOptions, allValue int) int {
+	ckCount := -1
+	options.HandleOption("cck", func(option string) {
+		switch option {
+		case "cckAll":
+			ckCount = allValue
+		case "cck1":
+			ckCount = 0
+		}
+	})
+	return ckCount
+}
+
+func validateStmt(t *testing.T, stmt any, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatalf("error: got an error while creating test inputs: %v", err)
+	}
+	if stmt == nil {
+		t.Fatalf("error: stmt is nil")
+	}
+	switch stmts := stmt.(type) {
+	case *typedef.Stmts:
+		if stmts == nil || stmts.List == nil || len(stmts.List) == 0 {
+			t.Fatalf("error: stmts is empty")
+		}
+		for i := range stmts.List {
+			if stmts.List[i] == nil || stmts.List[i].Query == nil {
+				t.Fatalf("error: stmts has nil stmt #%d", i)
+			}
+		}
+	case *typedef.Stmt:
+		if stmts == nil || stmts.Query == nil {
+			t.Fatalf("error: stmt is empty")
+		}
+	default:
+		t.Fatalf("error: unknown type of stmt")
+	}
+}
diff --git a/pkg/jobs/test_expected_data/check/clustering_range.json b/pkg/generators/statements/test_expected_data/check/clustering_range.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/clustering_range.json
rename to pkg/generators/statements/test_expected_data/check/clustering_range.json
diff --git a/pkg/jobs/test_expected_data/check/clustering_range_mv.json
b/pkg/generators/statements/test_expected_data/check/clustering_range_mv.json similarity index 100% rename from pkg/jobs/test_expected_data/check/clustering_range_mv.json rename to pkg/generators/statements/test_expected_data/check/clustering_range_mv.json diff --git a/pkg/jobs/test_expected_data/check/multiple_partition.json b/pkg/generators/statements/test_expected_data/check/multiple_partition.json similarity index 100% rename from pkg/jobs/test_expected_data/check/multiple_partition.json rename to pkg/generators/statements/test_expected_data/check/multiple_partition.json diff --git a/pkg/jobs/test_expected_data/check/multiple_partition_clustering_range.json b/pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range.json similarity index 100% rename from pkg/jobs/test_expected_data/check/multiple_partition_clustering_range.json rename to pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range.json diff --git a/pkg/jobs/test_expected_data/check/multiple_partition_clustering_range_mv.json b/pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range_mv.json similarity index 100% rename from pkg/jobs/test_expected_data/check/multiple_partition_clustering_range_mv.json rename to pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range_mv.json diff --git a/pkg/jobs/test_expected_data/check/multiple_partition_mv.json b/pkg/generators/statements/test_expected_data/check/multiple_partition_mv.json similarity index 100% rename from pkg/jobs/test_expected_data/check/multiple_partition_mv.json rename to pkg/generators/statements/test_expected_data/check/multiple_partition_mv.json diff --git a/pkg/jobs/test_expected_data/check/single_index.json b/pkg/generators/statements/test_expected_data/check/single_index.json similarity index 100% rename from pkg/jobs/test_expected_data/check/single_index.json rename to pkg/generators/statements/test_expected_data/check/single_index.json diff --git a/pkg/jobs/test_expected_data/check/single_partition.json b/pkg/generators/statements/test_expected_data/check/single_partition.json similarity index 100% rename from pkg/jobs/test_expected_data/check/single_partition.json rename to pkg/generators/statements/test_expected_data/check/single_partition.json diff --git a/pkg/jobs/test_expected_data/check/single_partition_mv.json b/pkg/generators/statements/test_expected_data/check/single_partition_mv.json similarity index 100% rename from pkg/jobs/test_expected_data/check/single_partition_mv.json rename to pkg/generators/statements/test_expected_data/check/single_partition_mv.json diff --git a/pkg/jobs/test_expected_data/ddl/add_column.json b/pkg/generators/statements/test_expected_data/ddl/add_column.json similarity index 100% rename from pkg/jobs/test_expected_data/ddl/add_column.json rename to pkg/generators/statements/test_expected_data/ddl/add_column.json diff --git a/pkg/jobs/test_expected_data/ddl/drop_column.json b/pkg/generators/statements/test_expected_data/ddl/drop_column.json similarity index 100% rename from pkg/jobs/test_expected_data/ddl/drop_column.json rename to pkg/generators/statements/test_expected_data/ddl/drop_column.json diff --git a/pkg/jobs/test_expected_data/mutate/delete.json b/pkg/generators/statements/test_expected_data/mutate/delete.json similarity index 100% rename from pkg/jobs/test_expected_data/mutate/delete.json rename to pkg/generators/statements/test_expected_data/mutate/delete.json diff --git 
a/pkg/jobs/test_expected_data/mutate/insert.json b/pkg/generators/statements/test_expected_data/mutate/insert.json similarity index 100% rename from pkg/jobs/test_expected_data/mutate/insert.json rename to pkg/generators/statements/test_expected_data/mutate/insert.json diff --git a/pkg/jobs/test_expected_data/mutate/insert_j.json b/pkg/generators/statements/test_expected_data/mutate/insert_j.json similarity index 100% rename from pkg/jobs/test_expected_data/mutate/insert_j.json rename to pkg/generators/statements/test_expected_data/mutate/insert_j.json diff --git a/pkg/jobs/test_expected_data/mutate/update.json b/pkg/generators/statements/test_expected_data/mutate/update.json similarity index 100% rename from pkg/jobs/test_expected_data/mutate/update.json rename to pkg/generators/statements/test_expected_data/mutate/update.json diff --git a/pkg/jobs/ddl.go b/pkg/jobs/ddl.go new file mode 100644 index 00000000..82166baa --- /dev/null +++ b/pkg/jobs/ddl.go @@ -0,0 +1,83 @@ +package jobs + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/pkg/errors" + "go.uber.org/zap" + + "github.com/scylladb/gemini/pkg/generators/statements" + "github.com/scylladb/gemini/pkg/joberror" + "github.com/scylladb/gemini/pkg/typedef" +) + +func (m *mutation) DDL(ctx context.Context, table *typedef.Table) error { + table.RLock() + // Scylla does not allow changing the DDL of a table with materialized views. + if len(table.MaterializedViews) > 0 { + table.RUnlock() + return nil + } + table.RUnlock() + + table.Lock() + defer table.Unlock() + ddlStmts, err := statements.GenDDLStmt(m.schema, table, m.random, &m.schema.Config) + if err != nil { + m.logger.Error("Failed! DDL Mutation statement generation failed", zap.Error(err)) + m.globalStatus.WriteErrors.Add(1) + return err + } + if ddlStmts == nil { + if w := m.logger.Check(zap.DebugLevel, "no statement generated"); w != nil { + w.Write(zap.String("job", "ddl")) + } + return nil + } + for _, ddlStmt := range ddlStmts.List { + if w := m.logger.Check(zap.DebugLevel, "ddl statement"); w != nil { + prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL() + if prettyCQLErr != nil { + return PrettyCQLError{ + PrettyCQL: prettyCQLErr, + Stmt: ddlStmt, + } + } + + w.Write(zap.String("pretty_cql", prettyCQL)) + } + if err = m.store.Mutate(ctx, ddlStmt); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + + prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL() + if prettyCQLErr != nil { + return PrettyCQLError{ + PrettyCQL: prettyCQLErr, + Stmt: ddlStmt, + Err: err, + } + } + + m.globalStatus.AddWriteError(&joberror.JobError{ + Timestamp: time.Now(), + StmtType: ddlStmts.QueryType.String(), + Message: "DDL failed: " + err.Error(), + Query: prettyCQL, + }) + + return err + } + m.globalStatus.WriteOps.Add(1) + } + ddlStmts.PostStmtHook() + + jsonSchema, _ := json.MarshalIndent(m.schema, "", " ") + fmt.Printf("New schema: %v\n", string(jsonSchema)) + + return nil +} diff --git a/pkg/jobs/gen_check_stmt_test.go b/pkg/jobs/gen_check_stmt_test.go deleted file mode 100644 index 912c9cf2..00000000 --- a/pkg/jobs/gen_check_stmt_test.go +++ /dev/null @@ -1,333 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//nolint:thelper -package jobs - -import ( - "path" - "testing" - - "github.com/scylladb/gemini/pkg/testutils" - "github.com/scylladb/gemini/pkg/utils" -) - -var checkDataPath = "./test_expected_data/check/" - -func TestGenSinglePartitionQuery(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "single_partition.json"), genSinglePartitionQueryCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) - stmt := genSinglePartitionQuery(schema, schema.Tables[0], gen) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenSinglePartitionQueryMv(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "single_partition_mv.json"), genSinglePartitionQueryMvCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - stmt := genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenMultiplePartitionQuery(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition.json"), genMultiplePartitionQueryCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) - options := testutils.GetOptionsFromCaseName(caseName) - stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenMultiplePartitionQueryMv(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_mv.json"), genMultiplePartitionQueryMvCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) - stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenClusteringRangeQuery(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range.json"), genClusteringRangeQueryCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - options := testutils.GetOptionsFromCaseName(caseName) - prc := schema.Config.GetPartitionRangeConfig() - stmt := genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - validateStmt(subT, 
stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenClusteringRangeQueryMv(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range_mv.json"), genClusteringRangeQueryMvCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - options := testutils.GetOptionsFromCaseName(caseName) - prc := schema.Config.GetPartitionRangeConfig() - stmt := genClusteringRangeQueryMv( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - len(schema.Tables[0].MaterializedViews)-1, - GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenMultiplePartitionClusteringRangeQuery(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range.json"), genMultiplePartitionClusteringRangeQueryCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - options := testutils.GetOptionsFromCaseName(caseName) - prc := schema.Config.GetPartitionRangeConfig() - stmt := genMultiplePartitionClusteringRangeQuery( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), - GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenMultiplePartitionClusteringRangeQueryMv(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range_mv.json"), genMultiplePartitionClusteringRangeQueryMvCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - stmt := genMultiplePartitionClusteringRangeQueryMv( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - len(schema.Tables[0].MaterializedViews)-1, - GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), - GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func TestGenSingleIndexQuery(t *testing.T) { - RunStmtTest[results](t, path.Join(checkDataPath, "single_index.json"), genSingleIndexQueryCases, - func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - stmt := genSingleIndexQuery(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].Indexes)) - validateStmt(subT, stmt, nil) - expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt)) - }) -} - -func BenchmarkGenSinglePartitionQuery(t *testing.B) { - utils.SetUnderTest() - for idx := range genSinglePartitionQueryCases { - caseName := genSinglePartitionQueryCases[idx] - t.Run(caseName, - func(subT *testing.B) { - schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genSinglePartitionQuery(schema, schema.Tables[0], gen) - } - }) - } -} - -func 
BenchmarkGenSinglePartitionQueryMv(t *testing.B) { - utils.SetUnderTest() - for idx := range genSinglePartitionQueryMvCases { - caseName := genSinglePartitionQueryMvCases[idx] - t.Run(caseName, - func(subT *testing.B) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1) - } - }) - } -} - -func BenchmarkGenMultiplePartitionQuery(t *testing.B) { - utils.SetUnderTest() - for idx := range genMultiplePartitionQueryCases { - caseName := genMultiplePartitionQueryCases[idx] - t.Run(caseName, - func(subT *testing.B) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName) - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) - } - }) - } -} - -func BenchmarkGenMultiplePartitionQueryMv(t *testing.B) { - utils.SetUnderTest() - for idx := range genMultiplePartitionQueryMvCases { - caseName := genMultiplePartitionQueryMvCases[idx] - t.Run(caseName, - func(subT *testing.B) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genMultiplePartitionQueryMv( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - len(schema.Tables[0].MaterializedViews)-1, - GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys))) - } - }) - } -} - -func BenchmarkGenClusteringRangeQuery(t *testing.B) { - utils.SetUnderTest() - for idx := range genClusteringRangeQueryCases { - caseName := genClusteringRangeQueryCases[idx] - t.Run(caseName, - func(subT *testing.B) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - } - }) - } -} - -func BenchmarkGenClusteringRangeQueryMv(t *testing.B) { - utils.SetUnderTest() - for idx := range genClusteringRangeQueryMvCases { - caseName := genClusteringRangeQueryMvCases[idx] - t.Run(caseName, - func(subT *testing.B) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genClusteringRangeQueryMv( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - len(schema.Tables[0].MaterializedViews)-1, - GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - } - }) - } -} - -func BenchmarkGenMultiplePartitionClusteringRangeQuery(t *testing.B) { - utils.SetUnderTest() - for idx := range genMultiplePartitionClusteringRangeQueryCases { - caseName := genMultiplePartitionClusteringRangeQueryCases[idx] - t.Run(caseName, - func(subT *testing.B) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = 
genMultiplePartitionClusteringRangeQuery( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), - GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - } - }) - } -} - -func BenchmarkGenMultiplePartitionClusteringRangeQueryMv(t *testing.B) { - utils.SetUnderTest() - for idx := range genMultiplePartitionClusteringRangeQueryMvCases { - caseName := genMultiplePartitionClusteringRangeQueryMvCases[idx] - t.Run(caseName, - func(subT *testing.B) { - options := testutils.GetOptionsFromCaseName(caseName) - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genMultiplePartitionClusteringRangeQueryMv( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - len(schema.Tables[0].MaterializedViews)-1, - GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)), - GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1)) - } - }) - } -} - -func BenchmarkGenSingleIndexQuery(t *testing.B) { - utils.SetUnderTest() - for idx := range genSingleIndexQueryCases { - caseName := genSingleIndexQueryCases[idx] - t.Run(caseName, - func(subT *testing.B) { - schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName) - prc := schema.Config.GetPartitionRangeConfig() - subT.ResetTimer() - for x := 0; x < subT.N; x++ { - _ = genSingleIndexQuery( - schema, - schema.Tables[0], - gen, - rnd, - &prc, - len(schema.Tables[0].Indexes)) - } - }) - } -} diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index 3730b7de..3f38c6c5 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -16,528 +16,166 @@ package jobs import ( "context" - "encoding/json" - "fmt" + "errors" + "sync" "time" - "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/exp/rand" - "golang.org/x/sync/errgroup" "github.com/scylladb/gemini/pkg/generators" - "github.com/scylladb/gemini/pkg/joberror" "github.com/scylladb/gemini/pkg/status" - "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/store" "github.com/scylladb/gemini/pkg/typedef" ) -const ( - WriteMode = "write" - ReadMode = "read" - MixedMode = "mixed" - WarmupMode = "warmup" -) - -const ( - warmupName = "Warmup" - validateName = "Validation" - mutateName = "Mutation" -) - -var ( - warmup = job{name: warmupName, function: warmupJob} - validate = job{name: validateName, function: validationJob} - mutate = job{name: mutateName, function: mutationJob} -) - -type List struct { - name string - jobs []job - duration time.Duration - workers uint64 -} - -type job struct { - function func( - context.Context, - <-chan time.Duration, - *typedef.Schema, - typedef.SchemaConfig, - *typedef.Table, - store.Store, - *rand.Rand, - *typedef.PartitionRangeConfig, - *generators.Generator, - *status.GlobalStatus, - *zap.Logger, - *stop.Flag, - bool, - bool, - ) error - name string -} - -func ListFromMode(mode string, duration time.Duration, workers uint64) List { - jobs := make([]job, 0, 2) - name := "work cycle" - switch mode { - case WriteMode: - jobs = append(jobs, mutate) - case ReadMode: - jobs = append(jobs, validate) - case WarmupMode: - jobs = append(jobs, warmup) - name = "warmup cycle" - default: - jobs = append(jobs, mutate, validate) +type ( + Runner struct { + duration time.Duration + logger *zap.Logger + random *rand.Rand + workers uint64 + generators *generators.Generators + schema *typedef.Schema + warmup time.Duration + globalStatus *status.GlobalStatus + 
store store.Store + mode Mode + failFast bool + } + Job interface { + Name() string + Do(context.Context, generators.Interface, *typedef.Table) error } - return List{ - name: name, - jobs: jobs, - duration: duration, - workers: workers, - } -} +) -func (l List) Run( - ctx context.Context, - schema *typedef.Schema, - schemaConfig typedef.SchemaConfig, - s store.Store, - pump <-chan time.Duration, - generators []*generators.Generator, - globalStatus *status.GlobalStatus, +func New( + mode string, + duration time.Duration, + workers uint64, logger *zap.Logger, - seed uint64, - stopFlag *stop.Flag, - failFast, verbose bool, -) error { - logger = logger.Named(l.name) - ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop) - g, gCtx := errgroup.WithContext(ctx) - time.AfterFunc(l.duration, func() { - logger.Info("jobs time is up, begins jobs completion") - stopFlag.SetSoft(true) - }) - - partitionRangeConfig := schemaConfig.GetPartitionRangeConfig() - logger.Info("start jobs") - for j := range schema.Tables { - gen := generators[j] - table := schema.Tables[j] - for i := 0; i < int(l.workers); i++ { - for idx := range l.jobs { - jobF := l.jobs[idx].function - r := rand.New(rand.NewSource(seed)) - g.Go(func() error { - return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, gen, globalStatus, logger, stopFlag, failFast, verbose) - }) - } - } - } - return g.Wait() -} - -// mutationJob continuously applies mutations against the database -// for as long as the pump is active. -func mutationJob( - ctx context.Context, - pump <-chan time.Duration, schema *typedef.Schema, - schemaCfg typedef.SchemaConfig, - table *typedef.Table, - s store.Store, - r *rand.Rand, - p *typedef.PartitionRangeConfig, - g *generators.Generator, + store store.Store, globalStatus *status.GlobalStatus, - logger *zap.Logger, - stopFlag *stop.Flag, - failFast, verbose bool, -) error { - schemaConfig := &schemaCfg - logger = logger.Named("mutation_job") - logger.Info("starting mutation loop") - defer func() { - logger.Info("ending mutation loop") - }() - for { - if stopFlag.IsHardOrSoft() { - return nil - } - select { - case <-stopFlag.SignalChannel(): - logger.Debug("mutation job terminated") - return nil - case hb := <-pump: - time.Sleep(hb) - } - ind := r.Intn(1000000) - if ind%100000 == 0 { - err := ddl(ctx, schema, schemaConfig, table, s, r, p, globalStatus, logger, verbose) - if err != nil { - return err - } - } else { - err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, true, logger) - if err != nil { - return err - } - } - if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) - return nil - } + seed uint64, + gens *generators.Generators, + warmup time.Duration, + failFast bool, +) *Runner { + return &Runner{ + warmup: warmup, + globalStatus: globalStatus, + store: store, + mode: ModeFromString(mode), + logger: logger, + duration: duration, + workers: workers, + failFast: failFast, + random: rand.New(rand.NewSource(seed)), + generators: gens, + schema: schema, } } -// validationJob continuously applies validations against the database -// for as long as the pump is active. 
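The `New` constructor above takes every collaborator explicitly, which makes the call shape easy to read from the signature alone. For orientation, a minimal usage sketch; the variables `schema`, `st`, `gs`, `gens`, and `seed` are hypothetical placeholders for objects the CLI builds elsewhere, not part of this diff:

```go
// A sketch of wiring up the new Runner, assuming the caller has already
// built the schema, store, status tracker, and generators.
runner := jobs.New(
	jobs.MixedMode, // runs mutation and validation jobs concurrently
	10*time.Minute, // total run duration
	16,             // workers per table and per job kind
	logger,         // *zap.Logger
	schema,         // *typedef.Schema
	st,             // store.Store
	gs,             // *status.GlobalStatus
	seed,           // uint64; seeds all randomness for reproducibility
	gens,           // *generators.Generators
	30*time.Second, // warmup phase; 0 skips it entirely
	true,           // failFast
)
if err := runner.Run(ctx); err != nil {
	logger.Fatal("run failed", zap.Error(err))
}
```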
-func validationJob( - ctx context.Context, - pump <-chan time.Duration, - schema *typedef.Schema, - schemaCfg typedef.SchemaConfig, - table *typedef.Table, - s store.Store, - r *rand.Rand, - p *typedef.PartitionRangeConfig, - g *generators.Generator, - globalStatus *status.GlobalStatus, - logger *zap.Logger, - stopFlag *stop.Flag, - failFast, _ bool, -) error { - schemaConfig := &schemaCfg - logger = logger.Named("validation_job") - logger.Info("starting validation loop") - defer func() { - logger.Info("ending validation loop") - }() +func (l *Runner) Run(ctx context.Context) error { + l.logger.Info("start jobs") + var wg sync.WaitGroup - for { - if stopFlag.IsHardOrSoft() { - return nil - } - select { - case <-stopFlag.SignalChannel(): - return nil - case hb := <-pump: - time.Sleep(hb) - } - stmt := GenCheckStmt(schema, table, g, r, p) - if stmt == nil { - logger.Info("Validation. No statement generated from GenCheckStmt.") - continue - } - err := validation(ctx, schemaConfig, table, s, stmt, logger) - if stmt.ValuesWithToken != nil { - for _, token := range stmt.ValuesWithToken { - g.ReleaseToken(token.Token) - } - } - switch { - case err == nil: - globalStatus.ReadOps.Add(1) - case errors.Is(err, context.Canceled): - return nil - default: - query, prettyErr := stmt.PrettyCQL() - if prettyErr != nil { - return PrettyCQLError{ - PrettyCQL: prettyErr, - Stmt: stmt, - Err: err, - } - } + if l.warmup > 0 { + l.logger.Info("Warmup Job Started", + zap.Int("duration", int(l.warmup.Seconds())), + zap.Int("workers", int(l.workers)), + ) - globalStatus.AddReadError(&joberror.JobError{ - Timestamp: time.Now(), - StmtType: stmt.QueryType.String(), - Message: "Validation failed: " + err.Error(), - Query: query, - }) - } - - if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) - return nil - } + warmupCtx, cancel := context.WithTimeout(ctx, l.warmup) + defer cancel() + l.startMutation(warmupCtx, cancel, &wg, l.random, "Warmup", false, false) + wg.Wait() } -} -// warmupJob continuously applies mutations against the database -// for as long as the pump is active or the supplied duration expires. -func warmupJob( - ctx context.Context, - _ <-chan time.Duration, - schema *typedef.Schema, - schemaCfg typedef.SchemaConfig, - table *typedef.Table, - s store.Store, - r *rand.Rand, - p *typedef.PartitionRangeConfig, - g *generators.Generator, - globalStatus *status.GlobalStatus, - logger *zap.Logger, - stopFlag *stop.Flag, - failFast, _ bool, -) error { - schemaConfig := &schemaCfg - logger = logger.Named("warmup") - logger.Info("starting warmup loop") - defer func() { - logger.Info("ending warmup loop") - }() - for { - if stopFlag.IsHardOrSoft() { - logger.Debug("warmup job terminated") - return nil - } - // Do we care about errors during warmup? 
- err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger) - if err != nil { - return err - } + ctx, cancel := context.WithTimeout(ctx, l.duration+1*time.Second) + defer cancel() - if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft(true) - return nil - } - } -} + src := rand.NewSource(l.random.Uint64()) -func ddl( - ctx context.Context, - schema *typedef.Schema, - sc *typedef.SchemaConfig, - table *typedef.Table, - s store.Store, - r *rand.Rand, - p *typedef.PartitionRangeConfig, - globalStatus *status.GlobalStatus, - logger *zap.Logger, - verbose bool, -) error { - if sc.CQLFeature != typedef.CQL_FEATURE_ALL { - logger.Debug("ddl statements disabled") - return nil - } - if len(table.MaterializedViews) > 0 { - // Scylla does not allow changing the DDL of a table with materialized views. - return nil - } - table.Lock() - defer table.Unlock() - ddlStmts, err := GenDDLStmt(schema, table, r, p, sc) - if err != nil { - logger.Error("Failed! DDL Mutation statement generation failed", zap.Error(err)) - globalStatus.WriteErrors.Add(1) - return err + if l.mode.IsRead() { + l.startValidation(ctx, &wg, cancel, src) } - if ddlStmts == nil { - if w := logger.Check(zap.DebugLevel, "no statement generated"); w != nil { - w.Write(zap.String("job", "ddl")) - } - return nil + if l.mode.IsWrite() { + l.startMutation(ctx, cancel, &wg, src, "Mutation", true, true) } - for _, ddlStmt := range ddlStmts.List { - if w := logger.Check(zap.DebugLevel, "ddl statement"); w != nil { - prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL() - if prettyCQLErr != nil { - return PrettyCQLError{ - PrettyCQL: prettyCQLErr, - Stmt: ddlStmt, - } - } - - w.Write(zap.String("pretty_cql", prettyCQL)) - } - - if err = s.Mutate(ctx, ddlStmt); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } + wg.Wait() - prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL() - if prettyCQLErr != nil { - return PrettyCQLError{ - PrettyCQL: prettyCQLErr, - Stmt: ddlStmt, - Err: err, - } - } - - globalStatus.AddWriteError(&joberror.JobError{ - Timestamp: time.Now(), - StmtType: ddlStmts.QueryType.String(), - Message: "DDL failed: " + err.Error(), - Query: prettyCQL, - }) - - return err - } - globalStatus.WriteOps.Add(1) - } - ddlStmts.PostStmtHook() - if verbose { - jsonSchema, _ := json.MarshalIndent(schema, "", " ") - fmt.Printf("New schema: %v\n", string(jsonSchema)) - } return nil } -func mutation( - ctx context.Context, - schema *typedef.Schema, - _ *typedef.SchemaConfig, - table *typedef.Table, - s store.Store, - r *rand.Rand, - p *typedef.PartitionRangeConfig, - g *generators.Generator, - globalStatus *status.GlobalStatus, - deletes bool, - logger *zap.Logger, -) error { - mutateStmt, err := GenMutateStmt(schema, table, g, r, p, deletes) +func (l *Runner) startMutation(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, src rand.Source, name string, deletes, ddl bool) { + logger := l.logger.Named(name) + + err := l.start(ctx, cancel, wg, rand.New(src), func(rnd *rand.Rand) Job { + return NewMutation( + logger, + l.schema, + l.store, + l.globalStatus, + rnd, + l.failFast, + deletes, + ddl, + ) + }) if err != nil { - logger.Error("Failed! 
Mutation statement generation failed", zap.Error(err)) - globalStatus.WriteErrors.Add(1) - return err - } - if mutateStmt == nil { - if w := logger.Check(zap.DebugLevel, "no statement generated"); w != nil { - w.Write(zap.String("job", "mutation")) - } - return err - } - - if w := logger.Check(zap.DebugLevel, "mutation statement"); w != nil { - prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL() - if prettyCQLErr != nil { - return PrettyCQLError{ - PrettyCQL: prettyCQLErr, - Stmt: mutateStmt, - Err: err, - } - } - - w.Write(zap.String("pretty_cql", prettyCQL)) - } - if err = s.Mutate(ctx, mutateStmt); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - - prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL() - if prettyCQLErr != nil { - return PrettyCQLError{ - PrettyCQL: prettyCQLErr, - Stmt: mutateStmt, - Err: err, - } + logger.Error("Mutation job failed", zap.Error(err)) + if l.failFast { + cancel() } - - globalStatus.AddWriteError(&joberror.JobError{ - Timestamp: time.Now(), - StmtType: mutateStmt.QueryType.String(), - Message: "Mutation failed: " + err.Error(), - Query: prettyCQL, - }) - - return err } - - globalStatus.WriteOps.Add(1) - g.GiveOlds(mutateStmt.ValuesWithToken) - - return nil } -func validation( - ctx context.Context, - sc *typedef.SchemaConfig, - table *typedef.Table, - s store.Store, - stmt *typedef.Stmt, - logger *zap.Logger, -) error { - if w := logger.Check(zap.DebugLevel, "validation statement"); w != nil { - prettyCQL, prettyCQLErr := stmt.PrettyCQL() - if prettyCQLErr != nil { - return PrettyCQLError{ - PrettyCQL: prettyCQLErr, - Stmt: stmt, - } - } - - w.Write(zap.String("pretty_cql", prettyCQL)) - } - - maxAttempts := 1 - delay := 10 * time.Millisecond - if stmt.QueryType.PossibleAsyncOperation() { - maxAttempts = sc.AsyncObjectStabilizationAttempts - if maxAttempts < 1 { - maxAttempts = 1 +func (l *Runner) startValidation(ctx context.Context, wg *sync.WaitGroup, cancel context.CancelFunc, src rand.Source) { + err := l.start(ctx, cancel, wg, rand.New(src), func(rnd *rand.Rand) Job { + return NewValidation( + l.logger, + l.schema, + l.store, + rnd, + l.globalStatus, + l.failFast, + ) + }) + if err != nil { + l.logger.Error("Validation job failed", zap.Error(err)) + if l.failFast { + cancel() } - delay = sc.AsyncObjectStabilizationDelay } +} - var lastErr, err error - attempt := 1 - for { - lastErr = err - err = s.Check(ctx, table, stmt, attempt == maxAttempts) - - if err == nil { - if attempt > 1 { - logger.Info(fmt.Sprintf("Validation successfully completed on %d attempt.", attempt)) - } - return nil - } - if errors.Is(err, context.Canceled) { - // When context is canceled it means that test was commanded to stop - // to skip logging part it is returned here - return err - } - if attempt == maxAttempts { - break - } - if errors.Is(err, unWrapErr(lastErr)) { - logger.Info(fmt.Sprintf("Retring failed validation. %d attempt from %d attempts. Error same as at attempt before. ", attempt, maxAttempts)) - } else { - logger.Info(fmt.Sprintf("Retring failed validation. %d attempt from %d attempts. Error: %s", attempt, maxAttempts, err)) - } - - select { - case <-time.After(delay): - case <-ctx.Done(): - logger.Info(fmt.Sprintf("Retring failed validation stoped by done context. %d attempt from %d attempts. 
Error: %s", attempt, maxAttempts, err)) - return nil +func (l *Runner) start(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, rnd *rand.Rand, job func(*rand.Rand) Job) error { + wg.Add(int(l.workers)) + + for _, table := range l.schema.Tables { + gen := l.generators.Get() + for range l.workers { + j := job(rand.New(rand.NewSource(rnd.Uint64()))) + go func(j Job) { + defer wg.Done() + if err := j.Do(ctx, gen, table); err != nil { + if errors.Is(err, context.Canceled) { + return + } + + l.logger.Error("job failed", zap.String("table", table.Name), zap.Error(err)) + + if l.failFast { + cancel() + } + } + }(j) } - attempt++ } - if attempt > 1 { - logger.Info(fmt.Sprintf("Retring failed validation stoped by reach of max attempts %d. Error: %s", maxAttempts, err)) - } else { - logger.Info(fmt.Sprintf("Validation failed. Error: %s", err)) - } - - return err -} - -func unWrapErr(err error) error { - nextErr := err - for nextErr != nil { - err = nextErr - nextErr = errors.Unwrap(err) - } - return err + return nil } diff --git a/pkg/jobs/mode.go b/pkg/jobs/mode.go new file mode 100644 index 00000000..cc2cd109 --- /dev/null +++ b/pkg/jobs/mode.go @@ -0,0 +1,36 @@ +package jobs + +import "strings" + +type Mode []string + +const ( + WriteMode = "write" + ReadMode = "read" + MixedMode = "mixed" +) + +func (m Mode) IsWrite() bool { + return m[0] == WriteMode +} + +func (m Mode) IsRead() bool { + if len(m) == 1 { + return m[0] == ReadMode + } + + return m[0] == ReadMode || m[1] == ReadMode +} + +func ModeFromString(m string) Mode { + switch strings.ToLower(m) { + case WriteMode: + return Mode{WriteMode} + case ReadMode: + return Mode{ReadMode} + case MixedMode: + return Mode{WriteMode, ReadMode} + default: + panic("unknown mode " + m) + } +} diff --git a/pkg/jobs/mutation.go b/pkg/jobs/mutation.go new file mode 100644 index 00000000..6b2068d5 --- /dev/null +++ b/pkg/jobs/mutation.go @@ -0,0 +1,155 @@ +package jobs + +import ( + "context" + "time" + + "github.com/pkg/errors" + "go.uber.org/zap" + "golang.org/x/exp/rand" + + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/generators/statements" + "github.com/scylladb/gemini/pkg/joberror" + "github.com/scylladb/gemini/pkg/status" + "github.com/scylladb/gemini/pkg/store" + "github.com/scylladb/gemini/pkg/typedef" +) + +type ( + Mutation struct { + logger *zap.Logger + mutation mutation + failFast bool + } + + mutation struct { + logger *zap.Logger + schema *typedef.Schema + store store.Store + globalStatus *status.GlobalStatus + random *rand.Rand + deletes bool + ddl bool + } +) + +func NewMutation( + logger *zap.Logger, + schema *typedef.Schema, + store store.Store, + globalStatus *status.GlobalStatus, + rnd *rand.Rand, + failFast bool, + deletes bool, + ddl bool, +) *Mutation { + return &Mutation{ + logger: logger, + mutation: mutation{ + logger: logger.Named("mutation-with-deletes"), + schema: schema, + store: store, + globalStatus: globalStatus, + deletes: deletes, + random: rnd, + ddl: ddl, + }, + failFast: failFast, + } +} + +func (m *Mutation) Name() string { + return "Mutation" +} + +func (m *Mutation) Do(ctx context.Context, generator generators.Interface, table *typedef.Table) error { + m.logger.Info("starting mutation loop") + defer m.logger.Info("ending mutation loop") + + for { + select { + case <-ctx.Done(): + m.logger.Debug("mutation job terminated") + return context.Canceled + default: + } + + var err error + + if m.mutation.ShouldDoDDL() { + err = m.mutation.DDL(ctx, table) + } else { + err = 
+		}
+
+		if err != nil {
+			return errors.WithStack(err)
+		}
+	}
+}
+
+func (m *mutation) Statement(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+	partitionRangeConfig := m.schema.Config.GetPartitionRangeConfig()
+	mutateStmt, err := statements.GenMutateStmt(m.schema, table, generator, m.random, &partitionRangeConfig, m.deletes)
+	if err != nil {
+		m.logger.Error("Failed! Mutation statement generation failed", zap.Error(err))
+		m.globalStatus.WriteErrors.Add(1)
+		return errors.WithStack(err)
+	}
+
+	if w := m.logger.Check(zap.DebugLevel, "mutation statement"); w != nil {
+		prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL()
+		if prettyCQLErr != nil {
+			return errors.WithStack(PrettyCQLError{
+				PrettyCQL: prettyCQLErr,
+				Stmt:      mutateStmt,
+				Err:       err,
+			})
+		}
+
+		w.Write(zap.String("pretty_cql", prettyCQL))
+	}
+
+	if err = m.store.Mutate(ctx, mutateStmt); err != nil {
+		if errors.Is(err, context.Canceled) {
+			return nil
+		}
+
+		prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL()
+		if prettyCQLErr != nil {
+			return errors.WithStack(PrettyCQLError{
+				PrettyCQL: prettyCQLErr,
+				Stmt:      mutateStmt,
+				Err:       err,
+			})
+		}
+
+		m.globalStatus.AddWriteError(&joberror.JobError{
+			Timestamp: time.Now(),
+			StmtType:  mutateStmt.QueryType.String(),
+			Message:   "Mutation failed: " + err.Error(),
+			Query:     prettyCQL,
+		})
+
+		return err
+	}
+
+	m.globalStatus.WriteOps.Add(1)
+	generator.GiveOld(mutateStmt.ValuesWithToken...)
+
+	return nil
+}
+
+func (m *mutation) HasErrors() bool {
+	return m.globalStatus.HasErrors()
+}
+
+func (m *mutation) ShouldDoDDL() bool {
+	if m.ddl && m.schema.Config.CQLFeature == typedef.CQL_FEATURE_ALL {
+		ind := m.random.Intn(100000)
+		return ind < 100 // roughly 0.1% of mutations become DDL
+	}
+
+	return false
+}
diff --git a/pkg/jobs/validation.go b/pkg/jobs/validation.go
new file mode 100644
index 00000000..a8e0fb20
--- /dev/null
+++ b/pkg/jobs/validation.go
@@ -0,0 +1,158 @@
+package jobs
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/pkg/errors"
+	"go.uber.org/zap"
+	"golang.org/x/exp/rand"
+
+	"github.com/scylladb/gemini/pkg/generators"
+	"github.com/scylladb/gemini/pkg/generators/statements"
+	"github.com/scylladb/gemini/pkg/joberror"
+	"github.com/scylladb/gemini/pkg/status"
+	"github.com/scylladb/gemini/pkg/store"
+	"github.com/scylladb/gemini/pkg/typedef"
+)
+
+type Validation struct {
+	logger       *zap.Logger
+	schema       *typedef.Schema
+	store        store.Store
+	random       *rand.Rand
+	globalStatus *status.GlobalStatus
+	failFast     bool
+}
+
+func NewValidation(
+	logger *zap.Logger,
+	schema *typedef.Schema,
+	store store.Store,
+	random *rand.Rand,
+	globalStatus *status.GlobalStatus,
+	failFast bool,
+) *Validation {
+	return &Validation{
+		logger:       logger.Named("validation"),
+		schema:       schema,
+		store:        store,
+		random:       random,
+		globalStatus: globalStatus,
+		failFast:     failFast,
+	}
+}
+
+func (v *Validation) Name() string {
+	return "Validation"
+}
+
+func (v *Validation) validate(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+	err := v.validation(ctx, table, generator)
+
+	switch {
+	case err == nil:
+		v.globalStatus.ReadOps.Add(1)
+	case errors.Is(err, context.Canceled):
+		return context.Canceled
+	default:
+		v.globalStatus.AddReadError(&joberror.JobError{
+			Timestamp: time.Now(),
+			// StmtType is omitted: the statement is generated inside validation() and is out of scope here.
+			Message: "Validation failed: " + err.Error(),
+		})
+
+		if v.failFast && v.globalStatus.HasErrors() {
+			return err
+		}
+	}
+
+	return err
+}
+
+func (v *Validation) Do(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+	v.logger.Info("starting validation loop")
+	defer v.logger.Info("ending validation loop")
+
+	for {
+		select {
+		case <-ctx.Done():
+			v.logger.Info("context done, ending validation loop")
+			return nil
+		default:
+		}
+
+		if err := v.validate(ctx, generator, table); errors.Is(err, context.Canceled) {
+			return nil
+		}
+
+		if v.failFast && v.globalStatus.HasErrors() {
+			return nil
+		}
+	}
+}
+
+func (v *Validation) validation(
+	ctx context.Context,
+	table *typedef.Table,
+	generator generators.Interface,
+) error {
+	partitionRangeConfig := v.schema.Config.GetPartitionRangeConfig()
+	stmt, cleanup := statements.GenCheckStmt(v.schema, table, generator, v.random, &partitionRangeConfig)
+	defer cleanup()
+
+	if w := v.logger.Check(zap.DebugLevel, "validation statement"); w != nil {
+		prettyCQL, prettyCQLErr := stmt.PrettyCQL()
+		if prettyCQLErr != nil {
+			return PrettyCQLError{
+				PrettyCQL: prettyCQLErr,
+			}
+		}
+
+		w.Write(zap.String("pretty_cql", prettyCQL))
+	}
+
+	maxAttempts := v.schema.Config.AsyncObjectStabilizationAttempts
+	delay := time.Duration(0)
+	if maxAttempts < 1 {
+		maxAttempts = 1
+	}
+
+	var err error
+
+	for attempt := 1; ; attempt++ {
+		select {
+		case <-ctx.Done():
+			v.logger.Info("context done, validation exiting")
+			return context.Canceled
+		case <-time.After(delay):
+			delay = v.schema.Config.AsyncObjectStabilizationDelay
+		}
+
+		err = v.store.Check(ctx, table, stmt, attempt == maxAttempts)
+
+		if err == nil {
+			if attempt > 1 {
+				v.logger.Info(fmt.Sprintf("Validation succeeded on attempt %d.", attempt))
+			}
+			return nil
+		}
+
+		if errors.Is(err, context.Canceled) {
+			return context.Canceled
+		}
+
+		if attempt == maxAttempts {
+			if attempt > 1 {
+				v.logger.Info(fmt.Sprintf("Validation failed after reaching the maximum of %d attempts. Error: %s", maxAttempts, err))
+			} else {
+				v.logger.Info(fmt.Sprintf("Validation failed. Error: %s", err))
+			}
+
+			return err
+		}
+
+		v.logger.Info(fmt.Sprintf("Retrying failed validation. Attempt %d of %d. 
Error: %s", attempt, maxAttempts, err)) + } +} diff --git a/pkg/replication/replication.go b/pkg/replication/replication.go index e0e26ef6..af059e29 100644 --- a/pkg/replication/replication.go +++ b/pkg/replication/replication.go @@ -21,9 +21,30 @@ import ( type Replication map[string]any +var ( + singleQuoteReplacer = strings.NewReplacer("'", "\"") + doubleQuoteReplacer = strings.NewReplacer("\"", "'") +) + func (r *Replication) ToCQL() string { b, _ := json.Marshal(r) - return strings.ReplaceAll(string(b), "\"", "'") + return doubleQuoteReplacer.Replace(string(b)) +} + +func MustParseReplication(rs string) *Replication { + switch rs { + case "network": + return NewNetworkTopologyStrategy() + case "simple": + return NewSimpleStrategy() + default: + var strategy Replication + + if err := json.Unmarshal([]byte(singleQuoteReplacer.Replace(rs)), &strategy); err != nil { + panic("unable to parse replication strategy: " + rs + " Error: " + err.Error()) + } + return &strategy + } } func NewSimpleStrategy() *Replication { diff --git a/pkg/replication/replication_test.go b/pkg/replication/replication_test.go index 82ed471c..9b25e714 100644 --- a/pkg/replication/replication_test.go +++ b/pkg/replication/replication_test.go @@ -17,6 +17,8 @@ package replication_test import ( "testing" + "github.com/google/go-cmp/cmp" + "github.com/scylladb/gemini/pkg/replication" ) @@ -46,3 +48,36 @@ func TestToCQL(t *testing.T) { }) } } + +func TestGetReplicationStrategy(t *testing.T) { + tests := map[string]struct { + strategy string + expected string + }{ + "simple strategy": { + strategy: "{\"class\": \"SimpleStrategy\", \"replication_factor\": \"1\"}", + expected: "{'class':'SimpleStrategy','replication_factor':'1'}", + }, + "simple strategy single quotes": { + strategy: "{'class': 'SimpleStrategy', 'replication_factor': '1'}", + expected: "{'class':'SimpleStrategy','replication_factor':'1'}", + }, + "network topology strategy": { + strategy: "{\"class\": \"NetworkTopologyStrategy\", \"dc1\": 3, \"dc2\": 3}", + expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}", + }, + "network topology strategy single quotes": { + strategy: "{'class': 'NetworkTopologyStrategy', 'dc1': 3, 'dc2': 3}", + expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + got := replication.MustParseReplication(tc.strategy) + if diff := cmp.Diff(got.ToCQL(), tc.expected); diff != "" { + t.Errorf("expected=%s, got=%s,diff=%s", tc.strategy, got.ToCQL(), diff) + } + }) + } +} diff --git a/pkg/status/status.go b/pkg/status/status.go index cbaf1c9d..2568f067 100644 --- a/pkg/status/status.go +++ b/pkg/status/status.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "sync/atomic" + "time" "github.com/pkg/errors" @@ -43,25 +44,21 @@ type GlobalStatus struct { } func (gs *GlobalStatus) AddWriteError(err *joberror.JobError) { - // TODO: https://github.com/scylladb/gemini/issues/302 - Move out and add logging - fmt.Printf("Error detected: %#v", err) gs.Errors.AddError(err) gs.WriteErrors.Add(1) } func (gs *GlobalStatus) AddReadError(err *joberror.JobError) { - // TODO: https://github.com/scylladb/gemini/issues/302 - Move out and add logging - fmt.Printf("Error detected: %#v", err) gs.Errors.AddError(err) gs.ReadErrors.Add(1) } -func (gs *GlobalStatus) PrintResultAsJSON(w io.Writer, schema *typedef.Schema, version string) error { +func (gs *GlobalStatus) PrintResultAsJSON(w io.Writer, schema *typedef.Schema, version string, start time.Time) error { result := 
map[string]any{
 		"result":         gs,
 		"gemini_version": version,
 		"schemaHash":     schema.GetHash(),
-		"schema":         schema,
+		"Time":           time.Since(start).String(),
 	}
 	encoder := json.NewEncoder(w)
 	encoder.SetEscapeHTML(false)
@@ -81,12 +78,13 @@ func (gs *GlobalStatus) HasErrors() bool {
 	return gs.WriteErrors.Load() > 0 || gs.ReadErrors.Load() > 0
 }
 
-func (gs *GlobalStatus) PrintResult(w io.Writer, schema *typedef.Schema, version string) {
-	if err := gs.PrintResultAsJSON(w, schema, version); err != nil {
+func (gs *GlobalStatus) PrintResult(w io.Writer, schema *typedef.Schema, version string, start time.Time) {
+	if err := gs.PrintResultAsJSON(w, schema, version, start); err != nil {
 		// In case it has been a long run, we want to display the results anyway...
 		fmt.Printf("Unable to print result as json, using plain text to stdout, error=%s\n", err)
 		fmt.Printf("Gemini version: %s\n", version)
 		fmt.Printf("Results:\n")
+		fmt.Printf("\ttime: %v\n", time.Since(start).String())
 		fmt.Printf("\twrite ops: %v\n", gs.WriteOps.Load())
 		fmt.Printf("\tread ops: %v\n", gs.ReadOps.Load())
 		fmt.Printf("\twrite errors: %v\n", gs.WriteErrors.Load())
@@ -94,8 +92,6 @@ func (gs *GlobalStatus) PrintResult(w io.Writer, schema *typedef.Schema, version
 	for i, err := range gs.Errors.Errors() {
 		fmt.Printf("Error %d: %s\n", i, err)
 	}
-	jsonSchema, _ := json.MarshalIndent(schema, "", "  ")
-	fmt.Printf("Schema: %v\n", string(jsonSchema))
 	}
 }
diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go
deleted file mode 100644
index 54c6e1f2..00000000
--- a/pkg/stop/flag.go
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright 2019 ScyllaDB
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -package stop - -import ( - "context" - "fmt" - "os" - "os/signal" - "sync" - "sync/atomic" - "syscall" - - "go.uber.org/zap" -) - -const ( - SignalNoop uint32 = iota - SignalSoftStop - SignalHardStop -) - -type SignalChannel chan uint32 - -var closedChan = createClosedChan() - -func createClosedChan() SignalChannel { - ch := make(SignalChannel) - close(ch) - return ch -} - -type SyncList[T any] struct { - children []T - childrenLock sync.RWMutex -} - -func (f *SyncList[T]) Append(el T) { - f.childrenLock.Lock() - defer f.childrenLock.Unlock() - f.children = append(f.children, el) -} - -func (f *SyncList[T]) Get() []T { - f.childrenLock.RLock() - defer f.childrenLock.RUnlock() - return f.children -} - -type logger interface { - Debug(msg string, fields ...zap.Field) -} - -type Flag struct { - name string - log logger - ch atomic.Pointer[SignalChannel] - parent *Flag - children SyncList[*Flag] - stopHandlers SyncList[func(signal uint32)] - val atomic.Uint32 -} - -func (s *Flag) Name() string { - return s.name -} - -func (s *Flag) closeChannel() { - ch := s.ch.Swap(&closedChan) - if ch != &closedChan { - close(*ch) - } -} - -func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool { - s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal))) - s.closeChannel() - out := s.val.CompareAndSwap(SignalNoop, signal) - if !out { - return false - } - - for _, handler := range s.stopHandlers.Get() { - handler(signal) - } - - for _, child := range s.children.Get() { - child.sendSignal(signal, sendToParent) - } - if sendToParent && s.parent != nil { - s.parent.sendSignal(signal, sendToParent) - } - return out -} - -func (s *Flag) SetHard(sendToParent bool) bool { - return s.sendSignal(SignalHardStop, sendToParent) -} - -func (s *Flag) SetSoft(sendToParent bool) bool { - return s.sendSignal(SignalSoftStop, sendToParent) -} - -func (s *Flag) CreateChild(name string) *Flag { - child := newFlag(name, s) - s.children.Append(child) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - child.sendSignal(val, false) - } - return child -} - -func (s *Flag) SignalChannel() SignalChannel { - return *s.ch.Load() -} - -func (s *Flag) IsSoft() bool { - return s.val.Load() == SignalSoftStop -} - -func (s *Flag) IsHard() bool { - return s.val.Load() == SignalHardStop -} - -func (s *Flag) IsHardOrSoft() bool { - return s.val.Load() != SignalNoop -} - -func (s *Flag) AddHandler(handler func(signal uint32)) { - s.stopHandlers.Append(handler) - val := s.val.Load() - switch val { - case SignalSoftStop, SignalHardStop: - handler(val) - } -} - -func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) { - s.AddHandler(func(signal uint32) { - switch expectedSignal { - case SignalNoop: - handler() - default: - if signal == expectedSignal { - handler() - } - } - }) -} - -func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context { - ctx, cancel := context.WithCancel(ctx) - s.AddHandler2(cancel, expectedSignal) - return ctx -} - -func (s *Flag) SetLogger(log logger) { - s.log = log -} - -func NewFlag(name string) *Flag { - return newFlag(name, nil) -} - -func newFlag(name string, parent *Flag) *Flag { - out := Flag{ - name: name, - parent: parent, - log: zap.NewNop(), - } - ch := make(SignalChannel) - out.ch.Store(&ch) - return &out -} - -func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { - graceful := make(chan os.Signal, 1) - signal.Notify(graceful, syscall.SIGTERM, syscall.SIGINT) - go func() { - sig := 
<-graceful - switch sig { - case syscall.SIGINT: - for i := range flags { - flags[i].SetSoft(true) - } - logger.Info("Get SIGINT signal, begin soft stop.") - default: - for i := range flags { - flags[i].SetHard(true) - } - logger.Info("Get SIGTERM signal, begin hard stop.") - } - }() -} - -func GetStateName(state uint32) string { - switch state { - case SignalSoftStop: - return "soft" - case SignalHardStop: - return "hard" - case SignalNoop: - return "no-signal" - default: - panic(fmt.Sprintf("unexpected signal %d", state)) - } -} diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go deleted file mode 100644 index 81a4b344..00000000 --- a/pkg/stop/flag_test.go +++ /dev/null @@ -1,429 +0,0 @@ -// Copyright 2019 ScyllaDB -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stop_test - -import ( - "context" - "errors" - "fmt" - "reflect" - "runtime" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/scylladb/gemini/pkg/stop" -) - -func TestHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHard, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func TestSoftStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } -} - -func TestSoftOrHardStop(t *testing.T) { - t.Parallel() - testFlag, ctx, workersDone := initVars() - workers := 30 - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetSoft) - if ctx.Err() != nil { - t.Error("Error:SetSoft function apply hardStopHandler") - } - - workersDone.Store(uint32(0)) - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() != nil { - t.Error("Error:SetHard function apply hardStopHandler after SetSoft") - } - - testFlag, ctx, workersDone = initVars() - workersDone.Store(uint32(0)) - - testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard) - if ctx.Err() == nil { - t.Error("Error:SetHard function does not apply hardStopHandler") - } -} - -func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) { - testFlagOut := stop.NewFlag("main_test") - ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop) - workersDone = &atomic.Uint32{} - return testFlagOut, ctx, workersDone -} - -func testSignals( - t *testing.T, - workersDone *atomic.Uint32, - workers int, - checkFunc func() bool, - setFunc func(propagation bool) bool, -) { - t.Helper() - for i := 0; i != workers; i++ { - go func() { - for { - if checkFunc() { - workersDone.Add(1) - return - } - time.Sleep(10 * time.Millisecond) - } - }() - } - time.Sleep(200 * time.Millisecond) - setFunc(false) - - for i := 0; i != 10; i++ { - time.Sleep(100 * time.Millisecond) - if 
workersDone.Load() == uint32(workers) { - break - } - } - - setFuncName := runtime.FuncForPC(reflect.ValueOf(setFunc).Pointer()).Name() - setFuncName, _ = strings.CutSuffix(setFuncName, "-fm") - _, setFuncName, _ = strings.Cut(setFuncName, ".(") - setFuncName = strings.ReplaceAll(setFuncName, ").", ".") - checkFuncName := runtime.FuncForPC(reflect.ValueOf(checkFunc).Pointer()).Name() - checkFuncName, _ = strings.CutSuffix(checkFuncName, "-fm") - _, checkFuncName, _ = strings.Cut(checkFuncName, ".(") - checkFuncName = strings.ReplaceAll(checkFuncName, ").", ".") - - if workersDone.Load() != uint32(workers) { - t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc()) - } -} - -func TestSendToParent(t *testing.T) { - t.Parallel() - tcases := []tCase{ - { - testName: "parent-hard-true", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-hard-false", - parentSignal: stop.SignalHardStop, - child1Signal: stop.SignalHardStop, - child11Signal: stop.SignalHardStop, - child12Signal: stop.SignalHardStop, - child2Signal: stop.SignalHardStop, - }, - { - testName: "parent-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "parent-soft-false", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child1-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalNoop, - }, - { - testName: "child11-soft-true", - parentSignal: stop.SignalSoftStop, - child1Signal: stop.SignalSoftStop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalSoftStop, - child2Signal: stop.SignalSoftStop, - }, - { - testName: "child11-soft-false", - parentSignal: stop.SignalNoop, - child1Signal: stop.SignalNoop, - child11Signal: stop.SignalSoftStop, - child12Signal: stop.SignalNoop, - child2Signal: stop.SignalNoop, - }, - } - for id := range tcases { - tcase := tcases[id] - t.Run(tcase.testName, func(t *testing.T) { - t.Parallel() - if err := tcase.runTest(); err != nil { - t.Error(err) - } - }) - } -} - -// nolint: govet -type parentChildInfo struct { - parent *stop.Flag - parentSignal uint32 - child1 *stop.Flag - child1Signal uint32 - child11 *stop.Flag - child11Signal uint32 - child12 *stop.Flag - child12Signal uint32 - child2 *stop.Flag - child2Signal uint32 -} - -func (t *parentChildInfo) getFlag(flagName string) *stop.Flag { - switch flagName { - case "parent": - return t.parent - case "child1": - return t.child1 - case "child2": - return t.child2 - case "child11": - return t.child11 - case "child12": - return t.child12 - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 { - switch flagName { - case "parent": - return t.parentSignal - case "child1": 
- return t.child1Signal - case "child2": - return t.child2Signal - case "child11": - return t.child11Signal - case "child12": - return t.child12Signal - default: - panic(fmt.Sprintf("no such flag %s", flagName)) - } -} - -func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error { - var err error - flagName := flag.Name() - state := t.getFlagHandlerState(flagName) - if state != expectedState { - err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState))) - } - flagState := getFlagState(flag) - if stop.GetStateName(expectedState) != flagState { - err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState))) - } - return err -} - -type tCase struct { - testName string - parentSignal uint32 - child1Signal uint32 - child11Signal uint32 - child12Signal uint32 - child2Signal uint32 -} - -func (t *tCase) runTest() error { - chunk := strings.Split(t.testName, "-") - if len(chunk) != 3 { - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - flagName := chunk[0] - signalTypeName := chunk[1] - sendToParentName := chunk[2] - - var sendToParent bool - switch sendToParentName { - case "true": - sendToParent = true - case "false": - sendToParent = false - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - runt := newParentChildInfo() - flag := runt.getFlag(flagName) - switch signalTypeName { - case "soft": - flag.SetSoft(sendToParent) - case "hard": - flag.SetHard(sendToParent) - default: - panic(fmt.Sprintf("wrong test name %s", t.testName)) - } - var err error - err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal)) - err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal)) - err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal)) - return err -} - -func newParentChildInfo() *parentChildInfo { - parent := stop.NewFlag("parent") - child1 := parent.CreateChild("child1") - out := parentChildInfo{ - parent: parent, - child1: child1, - child11: child1.CreateChild("child11"), - child12: child1.CreateChild("child12"), - child2: parent.CreateChild("child2"), - } - - out.parent.AddHandler(func(signal uint32) { - out.parentSignal = signal - }) - out.child1.AddHandler(func(signal uint32) { - out.child1Signal = signal - }) - out.child11.AddHandler(func(signal uint32) { - out.child11Signal = signal - }) - out.child12.AddHandler(func(signal uint32) { - out.child12Signal = signal - }) - out.child2.AddHandler(func(signal uint32) { - out.child2Signal = signal - }) - return &out -} - -func getFlagState(flag *stop.Flag) string { - switch { - case flag.IsSoft(): - return "soft" - case flag.IsHard(): - return "hard" - default: - return "no-signal" - } -} - -func TestSignalChannel(t *testing.T) { - t.Parallel() - t.Run("single-no-signal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - select { - case <-flag.SignalChannel(): - t.Error("should not get the signal") - case <-time.Tick(200 * time.Millisecond): - } - }) - - t.Run("single-beforehand", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - flag.SetSoft(true) - <-flag.SignalChannel() - }) - - t.Run("single-normal", func(t *testing.T) { - t.Parallel() - flag := stop.NewFlag("parent") - go 
func() { - time.Sleep(200 * time.Millisecond) - flag.SetSoft(true) - }() - <-flag.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - parent.SetSoft(true) - <-child.SignalChannel() - }) - - t.Run("parent-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - parent.SetSoft(true) - child := parent.CreateChild("child") - <-child.SignalChannel() - }) - - t.Run("parent-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - parent.SetSoft(true) - }() - <-child.SignalChannel() - }) - - t.Run("child-beforehand", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - child.SetSoft(true) - <-parent.SignalChannel() - }) - - t.Run("child-normal", func(t *testing.T) { - t.Parallel() - parent := stop.NewFlag("parent") - child := parent.CreateChild("child") - go func() { - time.Sleep(200 * time.Millisecond) - child.SetSoft(true) - }() - <-parent.SignalChannel() - }) -} diff --git a/pkg/store/helpers.go b/pkg/store/helpers.go index 6758851e..d8cc39bb 100644 --- a/pkg/store/helpers.go +++ b/pkg/store/helpers.go @@ -79,7 +79,6 @@ func lt(mi, mj map[string]any) bool { return true default: msg := fmt.Sprintf("unhandled type %T\n", mis) - time.Sleep(time.Second) panic(msg) } } diff --git a/pkg/store/store.go b/pkg/store/store.go index 51095b2d..b434d08b 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -233,9 +233,11 @@ func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, stmt }), cmp.Comparer(func(x, y *inf.Dec) bool { return x.Cmp(y) == 0 - }), cmp.Comparer(func(x, y *big.Int) bool { + }), + cmp.Comparer(func(x, y *big.Int) bool { return x.Cmp(y) == 0 - })) + }), + ) if diff != "" { return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff) } diff --git a/pkg/testutils/mock_generator.go b/pkg/testutils/mock_generator.go index eb6c0184..add33585 100644 --- a/pkg/testutils/mock_generator.go +++ b/pkg/testutils/mock_generator.go @@ -19,10 +19,13 @@ import ( "golang.org/x/exp/rand" + "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/routingkey" "github.com/scylladb/gemini/pkg/typedef" ) +var _ generators.Interface = &MockGenerator{} + type MockGenerator struct { table *typedef.Table rand *rand.Rand @@ -57,9 +60,7 @@ func (g *MockGenerator) GetOld() *typedef.ValueWithToken { return &typedef.ValueWithToken{Token: token, Value: values} } -func (g *MockGenerator) GiveOld(_ *typedef.ValueWithToken) {} - -func (g *MockGenerator) GiveOlds(_ []*typedef.ValueWithToken) {} +func (g *MockGenerator) GiveOld(_ ...*typedef.ValueWithToken) {} func (g *MockGenerator) ReleaseToken(_ uint64) { } diff --git a/pkg/typedef/feature.go b/pkg/typedef/feature.go new file mode 100644 index 00000000..5b0b215c --- /dev/null +++ b/pkg/typedef/feature.go @@ -0,0 +1,28 @@ +// Copyright 2019 ScyllaDB +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typedef
+
+import "strings"
+
+func ParseCQLFeature(feature string) CQLFeature {
+	switch strings.ToLower(feature) {
+	case "all":
+		return CQL_FEATURE_ALL
+	case "normal":
+		return CQL_FEATURE_NORMAL
+	default: // anything unrecognized falls back to the most conservative feature set
+		return CQL_FEATURE_BASIC
+	}
+}
diff --git a/pkg/typedef/simple_type.go b/pkg/typedef/simple_type.go
index 4a02a612..edeff8f8 100644
--- a/pkg/typedef/simple_type.go
+++ b/pkg/typedef/simple_type.go
@@ -295,7 +295,7 @@ func (st SimpleType) genValue(r *rand.Rand, p *PartitionRangeConfig) any {
 	case TYPE_FLOAT:
 		return r.Float32()
 	case TYPE_INET:
-		return net.ParseIP(utils.RandIPV4Address(r, r.Intn(255), 2)).String()
+		return net.ParseIP(utils.RandIPV4Address(r)).String()
 	case TYPE_INT:
 		return r.Int31()
 	case TYPE_SMALLINT:
diff --git a/pkg/typedef/table.go b/pkg/typedef/table.go
index 1af82947..2479d58b 100644
--- a/pkg/typedef/table.go
+++ b/pkg/typedef/table.go
@@ -105,14 +105,13 @@ func (t *Table) ValidColumnsForDelete() Columns {
 			}
 		}
 	}
-	if len(t.MaterializedViews) != 0 {
-		for _, mv := range t.MaterializedViews {
-			if mv.HaveNonPrimaryKey() {
-				for j := range validCols {
-					if validCols[j].Name == mv.NonPrimaryKey.Name {
-						validCols = append(validCols[:j], validCols[j+1:]...)
-						break
-					}
+
+	for _, mv := range t.MaterializedViews {
+		if mv.HaveNonPrimaryKey() {
+			for j := range validCols {
+				if validCols[j].Name == mv.NonPrimaryKey.Name {
+					validCols = append(validCols[:j], validCols[j+1:]...)
+					break
 				}
 			}
 		}
diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go
index 0b775d48..841b30c8 100644
--- a/pkg/utils/utils.go
+++ b/pkg/utils/utils.go
@@ -21,6 +21,9 @@ import (
 	"strings"
 	"time"
 
+	"github.com/pkg/errors"
+	"golang.org/x/exp/constraints"
+
 	"github.com/gocql/gocql"
 	"golang.org/x/exp/rand"
 )
@@ -54,22 +57,34 @@ func RandTime(rnd *rand.Rand) int64 {
 	return rnd.Int63n(86400000000000)
 }
 
-func RandIPV4Address(rnd *rand.Rand, v, pos int) string {
+func ipV4Builder[T constraints.Integer](bytes [4]T) string {
+	var builder strings.Builder
+	builder.Grow(16) // len("255.255.255.255") plus the transient trailing dot
+
+	for _, b := range bytes {
+		builder.WriteString(strconv.FormatUint(uint64(b), 10))
+		builder.WriteRune('.')
+	}
+
+	return builder.String()[:builder.Len()-1]
+}
+
+func RandIPV4Address(rnd *rand.Rand) string {
+	return ipV4Builder([4]int{rnd.Intn(256), rnd.Intn(256), rnd.Intn(256), rnd.Intn(256)})
+}
+
+func RandIPV4AddressPositional(rnd *rand.Rand, v, pos int) string {
 	if pos < 0 || pos > 3 {
 		panic(fmt.Sprintf("invalid position for the desired value of the IP part %d, 0-3 supported", pos))
 	}
 	if v < 0 || v > 255 {
 		panic(fmt.Sprintf("invalid value for the desired position %d of the IP, 0-255 supported", v))
 	}
-	var blocks []string
-	for i := 0; i < 4; i++ {
-		if i == pos {
-			blocks = append(blocks, strconv.Itoa(v))
-		} else {
-			blocks = append(blocks, strconv.Itoa(rnd.Intn(255)))
-		}
-	}
-	return strings.Join(blocks, ".")
+
+	data := [4]int{rnd.Intn(256), rnd.Intn(256), rnd.Intn(256), rnd.Intn(256)}
+	data[pos] = v
+
+	return ipV4Builder(data)
 }
 
 func RandInt2(rnd *rand.Rand, min, max int) int {
@@ -107,3 +122,13 @@ func UUIDFromTime(rnd *rand.Rand) string {
 	}
 	return gocql.UUIDFromTime(RandDate(rnd)).String()
 }
+
+// UnwrapErr walks the wrap chain and returns the innermost error.
+func UnwrapErr(err error) error {
+	nextErr := err
+	for nextErr != nil {
+		err = nextErr
+		nextErr = errors.Unwrap(err)
+	}
+	return err
+}
diff --git a/results/.gitkeep b/results/.gitkeep
new file mode 100644
index 00000000..e69de29b
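The new helpers are small enough to exercise in isolation. A self-contained sketch, assuming only the package paths introduced above (the printed addresses are illustrative, not fixed outputs):

```go
package main

import (
	"fmt"

	"golang.org/x/exp/rand"

	"github.com/scylladb/gemini/pkg/typedef"
	"github.com/scylladb/gemini/pkg/utils"
)

func main() {
	rnd := rand.New(rand.NewSource(42))

	// All four octets random, e.g. "203.40.17.91".
	fmt.Println(utils.RandIPV4Address(rnd))

	// Octet at zero-based position 3 pinned to 17, e.g. "88.112.5.17".
	fmt.Println(utils.RandIPV4AddressPositional(rnd, 17, 3))

	// Feature parsing is case-insensitive; unknown values fall back to BASIC.
	fmt.Println(typedef.ParseCQLFeature("ALL") == typedef.CQL_FEATURE_ALL)     // true
	fmt.Println(typedef.ParseCQLFeature("bogus") == typedef.CQL_FEATURE_BASIC) // true
}
```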