diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 00000000..2be6c4c2
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,12 @@
+version: 2
+
+updates:
+ - package-ecosystem: "github-actions"
+ directory: /
+ schedule:
+ interval: "monthly"
+
+ - package-ecosystem: "gomod"
+ directory: /
+ schedule:
+ interval: "weekly"
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index be86ef16..814bc554 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -29,9 +29,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- gemini-features: ["basic", "normal", "all"]
+ gemini-features: ["basic", "normal"]
gemini-concurrency: [4]
- duration: ["5m"]
+ duration: ["1m"]
dataset-size: [large, small]
oracle-scylla-version: ["6.1"]
test-scylla-version: ["6.2"]
@@ -54,6 +54,7 @@ jobs:
CONCURRENCY=${{ matrix.gemini-concurrency }} \
CQL_FEATURES=${{ matrix.gemini-features }} \
DURATION=${{ matrix.duration }} \
+ WARMUP=30s \
DATASET_SIZE=${{ matrix.dataset-size }} \
- name: Shutdown ScyllaDB
shell: bash
diff --git a/.gitignore b/.gitignore
index a8d923ff..3704f90f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ cmd/gemini/dist/
bin/
coverage.txt
dist/
+results/*.log
diff --git a/.run/Run Gemini Mixed.run.xml b/.run/Run Gemini Mixed.run.xml
new file mode 100644
index 00000000..c6d7c6ca
--- /dev/null
+++ b/.run/Run Gemini Mixed.run.xml
@@ -0,0 +1 @@
+<!-- IntelliJ IDEA run configuration: Run Gemini Mixed -->
\ No newline at end of file
diff --git a/.run/Run Gemini Read.run.xml b/.run/Run Gemini Read.run.xml
new file mode 100644
index 00000000..2c851167
--- /dev/null
+++ b/.run/Run Gemini Read.run.xml
@@ -0,0 +1 @@
+<!-- IntelliJ IDEA run configuration: Run Gemini Read -->
\ No newline at end of file
diff --git a/.run/Run Gemini Write.run.xml b/.run/Run Gemini Write.run.xml
new file mode 100644
index 00000000..a2225e64
--- /dev/null
+++ b/.run/Run Gemini Write.run.xml
@@ -0,0 +1 @@
+<!-- IntelliJ IDEA run configuration: Run Gemini Write -->
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 9c11c89e..660ed522 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,4 @@
FROM golang:1.23-bookworm AS build
-
ENV GO111MODULE=on
ENV GOAMD64=v3
ENV GOARM64=v8.3,crypto
diff --git a/Makefile b/Makefile
index 16778510..3d9da0ae 100644
--- a/Makefile
+++ b/Makefile
@@ -20,7 +20,7 @@ define dl_tgz
chmod +x "$(GOBIN)/$(1)"; \
fi
endef
-
+
$(GOBIN)/golangci-lint: GOLANGCI_VERSION = 1.62.0
$(GOBIN)/golangci-lint: Makefile
$(call dl_tgz,golangci-lint,https://github.com/golangci/golangci-lint/releases/download/v$(GOLANGCI_VERSION)/golangci-lint-$(GOLANGCI_VERSION)-$(GOOS)-amd64.tar.gz)
@@ -67,36 +67,33 @@ scylla-shutdown:
test:
@go test -covermode=atomic -race -coverprofile=coverage.txt -timeout 5m -json -v ./... 2>&1 | gotestfmt -showteststatus
-CQL_FEATURES ?= all
-CONCURRENCY ?= 50
+CQL_FEATURES ?= normal
+CONCURRENCY ?= 1
DURATION ?= 10m
-WARMUP ?= 1m
+WARMUP ?= 0
+MODE ?= mixed
DATASET_SIZE ?= large
SEED ?= $(shell date +%s)
GEMINI_BINARY ?= $(PWD)/bin/gemini
GEMINI_TEST_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Networks.gemini.IPAddress }}' gemini-test)
GEMINI_ORACLE_CLUSTER ?= $(shell docker inspect --format='{{ .NetworkSettings.Networks.gemini.IPAddress }}' gemini-oracle)
GEMINI_DOCKER_NETWORK ?= gemini
-GEMINI_FLAGS = --fail-fast \
+GEMINI_FLAGS = --fail-fast \
--level=info \
--non-interactive \
- --materialized-views=false \
--consistency=LOCAL_QUORUM \
--test-host-selection-policy=token-aware \
--oracle-host-selection-policy=token-aware \
- --mode=mixed \
+ --mode=$(MODE) \
--non-interactive \
--request-timeout=5s \
--connect-timeout=15s \
--use-server-timestamps=false \
--async-objects-stabilization-attempts=10 \
--max-mutation-retries=10 \
- --async-objects-stabilization-backoff=1000ms \
- --max-mutation-retries-backoff=1000ms \
--replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
--oracle-replication-strategy="{'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}" \
--concurrency=$(CONCURRENCY) \
- --use-lwt=true \
--dataset-size=$(DATASET_SIZE) \
--seed=$(SEED) \
--schema-seed=$(SEED) \
@@ -108,19 +105,23 @@ GEMINI_FLAGS = --fail-fast \
.PHONY: pprof-profile
pprof-profile:
- go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/profile
+ go tool pprof -http=:8080 http://localhost:6060/debug/pprof/profile
.PHONY: pprof-heap
pprof-heap:
- go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/heap
+ go tool pprof -http=:8081 http://localhost:6060/debug/pprof/heap
.PHONY: pprof-goroutine
pprof-goroutine:
- go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/goroutine
+ go tool pprof -http=:8082 http://localhost:6060/debug/pprof/goroutine
.PHONY: pprof-block
pprof-block:
- go tool pprof -http=:8080 -intel_syntax -call_tree -seconds 60 http://localhost:6060/debug/pprof/block
+ go tool pprof -http=:8083 http://localhost:6060/debug/pprof/block
+
+.PHONY: pprof-mutex
+pprof-mutex:
+ go tool pprof -http=:8084 http://localhost:6060/debug/pprof/mutex
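+
+# Each pprof target serves its web UI on a distinct local port (8080-8084),
+# so several profile views can stay open at once against the profiler on :6060.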
.PHONY: docker-integration-test
docker-integration-test:
@@ -130,6 +131,7 @@ docker-integration-test:
docker run \
-it \
--rm \
+ --memory=4G \
-p 6060:6060 \
--name gemini \
--network $(GEMINI_DOCKER_NETWORK) \
@@ -138,24 +140,16 @@ docker-integration-test:
scylladb/gemini:$(DOCKER_VERSION) \
--test-cluster=gemini-test \
--oracle-cluster=gemini-oracle \
- --outfile=/results/gemini_result.log \
- --tracing-outfile=/results/gemini_tracing.log \
- --test-statement-log-file=/results/gemini_test_statement.log \
- --oracle-statement-log-file=/results/gemini_oracle_statement.log \
$(GEMINI_FLAGS)
.PHONY: integration-test
integration-test:
- @mkdir -p $(PWD)/results
- @touch $(PWD)/results/gemini_seed
- @echo $(GEMINI_SEED) > $(PWD)/results/gemini_seed
- @$(GEMINI_BINARY) \
+ mkdir -p $(PWD)/results
+ touch $(PWD)/results/gemini_seed
+ echo $(GEMINI_SEED) > $(PWD)/results/gemini_seed
+ $(GEMINI_BINARY) \
--test-cluster=$(GEMINI_TEST_CLUSTER) \
--oracle-cluster=$(GEMINI_ORACLE_CLUSTER) \
- --outfile=$(PWD)/results/gemini_result.log \
- --tracing-outfile=$(PWD)/results/gemini_tracing.log \
- --test-statement-log-file=$(PWD)/results/gemini_test_statement.log \
- --oracle-statement-log-file=$(PWD)/results/gemini_oracle_statement.log \
$(GEMINI_FLAGS)
.PHONY: clean
diff --git a/cmd/gemini/flags.go b/cmd/gemini/flags.go
new file mode 100644
index 00000000..9b6abfa6
--- /dev/null
+++ b/cmd/gemini/flags.go
@@ -0,0 +1,162 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "time"
+
+ "github.com/scylladb/gemini/pkg/generators"
+ "github.com/scylladb/gemini/pkg/jobs"
+)
+
+var (
+ testClusterHost []string
+ testClusterUsername string
+ testClusterPassword string
+ oracleClusterHost []string
+ oracleClusterUsername string
+ oracleClusterPassword string
+ schemaFile string
+ outFileArg string
+ concurrency uint64
+ seed string
+ schemaSeed string
+ dropSchema bool
+ verbose bool
+ mode string
+ failFast bool
+ nonInteractive bool
+ duration time.Duration
+ bind string
+ warmup time.Duration
+ replicationStrategy string
+ tableOptions []string
+ oracleReplicationStrategy string
+ consistency string
+ maxTables int
+ maxPartitionKeys int
+ minPartitionKeys int
+ maxClusteringKeys int
+ minClusteringKeys int
+ maxColumns int
+ minColumns int
+ datasetSize string
+ cqlFeatures string
+ useMaterializedViews bool
+ level string
+ maxRetriesMutate int
+ maxRetriesMutateSleep time.Duration
+ maxErrorsToStore int
+ pkBufferReuseSize uint64
+ partitionCount uint64
+ partitionKeyDistribution string
+ normalDistMean float64
+ normalDistSigma float64
+ tracingOutFile string
+ useCounters bool
+ asyncObjectStabilizationAttempts int
+ asyncObjectStabilizationDelay time.Duration
+ useLWT bool
+ testClusterHostSelectionPolicy string
+ oracleClusterHostSelectionPolicy string
+ useServerSideTimestamps bool
+ requestTimeout time.Duration
+ connectTimeout time.Duration
+ profilingPort int
+ testStatementLogFile string
+ oracleStatementLogFile string
+)
+
+func init() {
+ rootCmd.Version = version + ", commit " + commit + ", date " + date
+ rootCmd.Flags().StringSliceVarP(&testClusterHost, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is the system under test")
+ _ = rootCmd.MarkFlagRequired("test-cluster")
+ rootCmd.Flags().StringVarP(&testClusterUsername, "test-username", "", "", "Username for the test cluster")
+ rootCmd.Flags().StringVarP(&testClusterPassword, "test-password", "", "", "Password for the test cluster")
+ rootCmd.Flags().StringSliceVarP(
+ &oracleClusterHost, "oracle-cluster", "o", []string{},
+ "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used")
+ rootCmd.Flags().StringVarP(&oracleClusterUsername, "oracle-username", "", "", "Username for the oracle cluster")
+ rootCmd.Flags().StringVarP(&oracleClusterPassword, "oracle-password", "", "", "Password for the oracle cluster")
+ rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file")
+ rootCmd.Flags().StringVarP(&mode, "mode", "m", jobs.MixedMode, "Query operation mode. Mode options: write, read, mixed (default)")
+ rootCmd.Flags().Uint64VarP(&concurrency, "concurrency", "c", 10, "Number of threads per table to run concurrently")
+ rootCmd.Flags().StringVarP(&seed, "seed", "s", "random", "Statement seed value")
+ rootCmd.Flags().StringVarP(&schemaSeed, "schema-seed", "", "random", "Schema seed value")
+ rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", true, "Drop the schema before starting the test run")
+ rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output during test run")
+ rootCmd.Flags().BoolVarP(&failFast, "fail-fast", "f", false, "Stop on the first failure")
+ rootCmd.Flags().BoolVarP(&nonInteractive, "non-interactive", "", false, "Run in non-interactive mode (disable progress indicator)")
+ rootCmd.Flags().DurationVarP(&duration, "duration", "", 30*time.Second, "Duration of the test run, for example 30s or 10h")
+ rootCmd.Flags().StringVarP(&outFileArg, "outfile", "", "", "Specify the name of the file where the results should go")
+ rootCmd.Flags().StringVarP(&bind, "bind", "b", ":2112", "Specify the interface and port on which to bind Prometheus metrics. Default is ':2112'")
+ rootCmd.Flags().DurationVarP(&warmup, "warmup", "", 30*time.Second, "Specify the warmup period as a duration, for example 30s or 10h")
+ rootCmd.Flags().StringVarP(
+ &replicationStrategy, "replication-strategy", "", "simple",
+ "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide "+
+ "the entire specification in the form {'class':'....'}")
+ rootCmd.Flags().StringVarP(
+ &oracleReplicationStrategy, "oracle-replication-strategy", "", "simple",
+ "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each "+
+ "type or provide the entire specification in the form {'class':'....'}")
+ rootCmd.Flags().StringArrayVarP(&tableOptions, "table-options", "", []string{}, "Repeatable argument to set table options to be added to the created tables")
+ rootCmd.Flags().StringVarP(&consistency, "consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE")
+ rootCmd.Flags().IntVarP(&maxTables, "max-tables", "", 1, "Maximum number of generated tables")
+ rootCmd.Flags().IntVarP(&maxPartitionKeys, "max-partition-keys", "", 6, "Maximum number of generated partition keys")
+ rootCmd.Flags().IntVarP(&minPartitionKeys, "min-partition-keys", "", 2, "Minimum number of generated partition keys")
+ rootCmd.Flags().IntVarP(&maxClusteringKeys, "max-clustering-keys", "", 4, "Maximum number of generated clustering keys")
+ rootCmd.Flags().IntVarP(&minClusteringKeys, "min-clustering-keys", "", 2, "Minimum number of generated clustering keys")
+ rootCmd.Flags().IntVarP(&maxColumns, "max-columns", "", 16, "Maximum number of generated columns")
+ rootCmd.Flags().IntVarP(&minColumns, "min-columns", "", 8, "Minimum number of generated columns")
+ rootCmd.Flags().StringVarP(&datasetSize, "dataset-size", "", "large", "Specify the type of dataset size to use, small|large")
+ rootCmd.Flags().StringVarP(&cqlFeatures, "cql-features", "", "normal", "Specify the type of cql features to use, basic|normal|all")
+ rootCmd.Flags().BoolVarP(&useMaterializedViews, "materialized-views", "", false, "Run gemini with materialized views support")
+ rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal")
+ rootCmd.Flags().IntVarP(&maxRetriesMutate, "max-mutation-retries", "", 10, "Maximum number of attempts to apply a mutation")
+ rootCmd.Flags().DurationVarP(
+ &maxRetriesMutateSleep, "max-mutation-retries-backoff", "", 10*time.Millisecond,
+ "Duration between attempts to apply a mutation for example 10ms or 1s")
+ rootCmd.Flags().Uint64VarP(&pkBufferReuseSize, "partition-key-buffer-reuse-size", "", 1000, "Number of reused buffered partition keys")
+ rootCmd.Flags().Uint64VarP(&partitionCount, "token-range-slices", "", 1000, "Number of slices to divide the token space into")
+ rootCmd.Flags().StringVarP(
+ &partitionKeyDistribution, "partition-key-distribution", "", "uniform",
+ "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf")
+ rootCmd.Flags().Float64VarP(&normalDistMean, "normal-dist-mean", "", generators.StdDistMean, "Mean of the normal distribution")
+ rootCmd.Flags().Float64VarP(&normalDistSigma, "normal-dist-sigma", "", generators.OneStdDev, "Sigma of the normal distribution, defaults to one standard deviation ~0.341")
+ rootCmd.Flags().StringVarP(
+ &tracingOutFile, "tracing-outfile", "", "",
+ "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.")
+ rootCmd.Flags().BoolVarP(&useCounters, "use-counters", "", false, "Ensure that at least one table is a counter table")
+ rootCmd.Flags().IntVarP(
+ &asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10,
+ "Maximum number of attempts to validate result sets from MV and SI")
+ rootCmd.Flags().DurationVarP(
+ &asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond,
+ "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s")
+ rootCmd.Flags().BoolVarP(&useLWT, "use-lwt", "", false, "Emit LWT based updates")
+ rootCmd.Flags().StringVarP(
+ &oracleClusterHostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware",
+ "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware")
+ rootCmd.Flags().StringVarP(
+ &testClusterHostSelectionPolicy, "test-host-selection-policy", "", "token-aware",
+ "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware")
+ rootCmd.Flags().BoolVarP(&useServerSideTimestamps, "use-server-timestamps", "", false, "Use server-side generated timestamps for writes")
+ rootCmd.Flags().DurationVarP(&requestTimeout, "request-timeout", "", 30*time.Second, "Maximum duration to wait for a request to complete")
+ rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Maximum duration to wait for a connection to be established")
+ rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero, starts the pprof profiler on the given port at 'http://0.0.0.0:<port>/debug/pprof/'")
+ rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end")
+ rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write the test cluster statement flow to")
+ rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write the oracle cluster statement flow to")
+}
diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go
index ed4d4e2c..4c3a7e8c 100644
--- a/cmd/gemini/root.go
+++ b/cmd/gemini/root.go
@@ -15,31 +15,18 @@
package main
import (
- "encoding/json"
"fmt"
"log"
- "math"
"net/http"
"net/http/pprof"
"os"
+ "os/signal"
"strconv"
"strings"
+ "syscall"
"text/tabwriter"
"time"
- "github.com/scylladb/gemini/pkg/auth"
- "github.com/scylladb/gemini/pkg/builders"
- "github.com/scylladb/gemini/pkg/generators"
- "github.com/scylladb/gemini/pkg/jobs"
- "github.com/scylladb/gemini/pkg/realrandom"
- "github.com/scylladb/gemini/pkg/replication"
- "github.com/scylladb/gemini/pkg/store"
- "github.com/scylladb/gemini/pkg/typedef"
- "github.com/scylladb/gemini/pkg/utils"
-
- "github.com/scylladb/gemini/pkg/status"
- "github.com/scylladb/gemini/pkg/stop"
-
"github.com/gocql/gocql"
"github.com/hailocab/go-hostpool"
"github.com/pkg/errors"
@@ -49,94 +36,25 @@ import (
"go.uber.org/zap/zapcore"
"golang.org/x/exp/rand"
"golang.org/x/net/context"
- "gonum.org/v1/gonum/stat/distuv"
-)
-var (
- testClusterHost []string
- testClusterUsername string
- testClusterPassword string
- oracleClusterHost []string
- oracleClusterUsername string
- oracleClusterPassword string
- schemaFile string
- outFileArg string
- concurrency uint64
- seed string
- schemaSeed string
- dropSchema bool
- verbose bool
- mode string
- failFast bool
- nonInteractive bool
- duration time.Duration
- bind string
- warmup time.Duration
- replicationStrategy string
- tableOptions []string
- oracleReplicationStrategy string
- consistency string
- maxTables int
- maxPartitionKeys int
- minPartitionKeys int
- maxClusteringKeys int
- minClusteringKeys int
- maxColumns int
- minColumns int
- datasetSize string
- cqlFeatures string
- useMaterializedViews bool
- level string
- maxRetriesMutate int
- maxRetriesMutateSleep time.Duration
- maxErrorsToStore int
- pkBufferReuseSize uint64
- partitionCount uint64
- partitionKeyDistribution string
- normalDistMean float64
- normalDistSigma float64
- tracingOutFile string
- useCounters bool
- asyncObjectStabilizationAttempts int
- asyncObjectStabilizationDelay time.Duration
- useLWT bool
- testClusterHostSelectionPolicy string
- oracleClusterHostSelectionPolicy string
- useServerSideTimestamps bool
- requestTimeout time.Duration
- connectTimeout time.Duration
- profilingPort int
- testStatementLogFile string
- oracleStatementLogFile string
+ "github.com/scylladb/gemini/pkg/auth"
+ "github.com/scylladb/gemini/pkg/generators"
+ "github.com/scylladb/gemini/pkg/jobs"
+ "github.com/scylladb/gemini/pkg/realrandom"
+ "github.com/scylladb/gemini/pkg/store"
+ "github.com/scylladb/gemini/pkg/typedef"
+ "github.com/scylladb/gemini/pkg/utils"
+
+ "github.com/scylladb/gemini/pkg/status"
)
func interactive() bool {
return !nonInteractive
}
-func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Schema, error) {
- byteValue, err := os.ReadFile(confFile)
- if err != nil {
- return nil, err
- }
-
- var shm typedef.Schema
-
- err = json.Unmarshal(byteValue, &shm)
- if err != nil {
- return nil, err
- }
-
- schemaBuilder := builders.NewSchemaBuilder()
- schemaBuilder.Keyspace(shm.Keyspace).Config(schemaConfig)
- for t, tbl := range shm.Tables {
- shm.Tables[t].LinkIndexAndColumns()
- schemaBuilder.Table(tbl)
- }
- return schemaBuilder.Build(), nil
-}
-
func run(_ *cobra.Command, _ []string) error {
+ start := time.Now()
+
logger := createLogger(level)
globalStatus := status.NewGlobalStatus(1000)
defer utils.IgnoreError(logger.Sync)
@@ -153,11 +71,7 @@ func run(_ *cobra.Command, _ []string) error {
rand.Seed(intSeed)
- cons, err := gocql.ParseConsistencyWrapper(consistency)
- if err != nil {
- logger.Error("Unable parse consistency, error=%s. Falling back on Quorum", zap.Error(err))
- cons = gocql.Quorum
- }
+ cons := gocql.ParseConsistency(consistency)
testHostSelectionPolicy, err := getHostSelectionPolicy(testClusterHostSelectionPolicy, testClusterHost)
if err != nil {
@@ -168,15 +82,23 @@ func run(_ *cobra.Command, _ []string) error {
return err
}
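+ // Derive the root context from OS signals so every stage below (schema
+ // creation, generators, jobs) shuts down through context cancellation
+ // instead of the old stop.Flag machinery.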
+ ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGABRT, syscall.SIGTERM)
+ defer cancel()
+
go func() {
- http.Handle("/metrics", promhttp.Handler())
- _ = http.ListenAndServe(bind, nil)
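+ // Method-qualified patterns such as "GET /metrics" require the
+ // pattern-matching ServeMux shipped with Go 1.22 and later.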
+ mux := http.NewServeMux()
+ mux.Handle("GET /metrics", promhttp.Handler())
+ log.Fatal(http.ListenAndServe(bind, mux))
}()
if profilingPort != 0 {
go func() {
mux := http.NewServeMux()
- mux.HandleFunc("/profile", pprof.Profile)
+ mux.HandleFunc("GET /debug/pprof/", pprof.Index)
+ mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
+ mux.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
+ mux.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
+ mux.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
log.Fatal(http.ListenAndServe(":"+strconv.Itoa(profilingPort), mux))
}()
}
@@ -204,10 +126,7 @@ func run(_ *cobra.Command, _ []string) error {
}
}
- jsonSchema, _ := json.MarshalIndent(schema, "", " ")
-
- printSetup(intSeed, intSchemaSeed)
- fmt.Printf("Schema: %v\n", string(jsonSchema))
+ printSetup(intSeed, intSchemaSeed, mode)
testCluster, oracleCluster := createClusters(cons, testHostSelectionPolicy, oracleHostSelectionPolicy, logger)
storeConfig := store.Config{
@@ -230,7 +149,10 @@ func run(_ *cobra.Command, _ []string) error {
return ioErr
}
tracingFile = tf
- defer utils.IgnoreError(tracingFile.Sync)
+ defer func() {
+ utils.IgnoreError(tracingFile.Sync)
+ utils.IgnoreError(tracingFile.Close)
+ }()
}
}
st, err := store.New(schema, testCluster, oracleCluster, storeConfig, tracingFile, logger)
@@ -250,39 +172,44 @@ func run(_ *cobra.Command, _ []string) error {
testKeyspace, oracleKeyspace := generators.GetCreateKeyspaces(schema)
if err = st.Create(
- context.Background(),
+ ctx,
typedef.SimpleStmt(testKeyspace, typedef.CreateKeyspaceStatementType),
typedef.SimpleStmt(oracleKeyspace, typedef.CreateKeyspaceStatementType)); err != nil {
return errors.Wrap(err, "unable to create keyspace")
}
- for _, stmt := range generators.GetCreateSchema(schema) {
+ for _, stmt := range generators.GetCreateSchema(schema, useMaterializedViews) {
logger.Debug(stmt)
- if err = st.Mutate(context.Background(), typedef.SimpleStmt(stmt, typedef.CreateSchemaStatementType)); err != nil {
+ if err = st.Mutate(ctx, typedef.SimpleStmt(stmt, typedef.CreateSchemaStatementType)); err != nil {
return errors.Wrap(err, "unable to create schema")
}
}
- ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2)
- stopFlag := stop.NewFlag("main")
- warmupStopFlag := stop.NewFlag("warmup")
- stop.StartOsSignalsTransmitter(logger, stopFlag, warmupStopFlag)
- pump := jobs.NewPump(stopFlag, logger)
+ distFunc, err := generators.ParseDistributionDefault(partitionKeyDistribution, partitionCount, intSeed)
+ if err != nil {
+ return err
+ }
- gens, err := createGenerators(schema, schemaConfig, intSeed, partitionCount, logger)
+ gens, err := generators.New(ctx, schema, schemaConfig, intSeed, partitionCount, logger, distFunc, pkBufferReuseSize)
if err != nil {
return err
}
- gens.StartAll(stopFlag)
+
+ defer utils.IgnoreError(gens.Close)
+
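+ // Overall deadline: test duration plus warmup plus a 10-second grace
+ // period for final validation and shutdown.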
+ ctx, done := context.WithTimeout(ctx, duration+warmup+10*time.Second)
+ defer done()
if !nonInteractive {
sp := createSpinner(interactive())
ticker := time.NewTicker(time.Second)
+
go func() {
- defer done()
+ defer ticker.Stop()
+
for {
select {
- case <-stopFlag.SignalChannel():
+ case <-ctx.Done():
return
case <-ticker.C:
sp.Set(" Running Gemini... %v", globalStatus)
@@ -291,68 +218,32 @@ func run(_ *cobra.Command, _ []string) error {
}()
}
- if warmup > 0 && !stopFlag.IsHardOrSoft() {
- jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency)
- if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, warmupStopFlag, failFast, verbose); err != nil {
- logger.Error("warmup encountered an error", zap.Error(err))
- stopFlag.SetHard(true)
- }
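+ // Warmup and workload now run through a single jobs.List; the warmup
+ // duration is passed to jobs.New rather than scheduled as a separate
+ // WarmupMode job list.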
+ jobsList := jobs.New(mode, duration, concurrency, logger, schema, st, globalStatus, intSeed, gens, warmup, failFast)
+ if err = jobsList.Run(ctx); err != nil {
+ logger.Error("error detected", zap.Error(err))
}
- if !stopFlag.IsHardOrSoft() {
- jobsList := jobs.ListFromMode(mode, duration, concurrency)
- if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, intSeed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil {
- logger.Debug("error detected", zap.Error(err))
- }
- }
logger.Info("test finished")
- globalStatus.PrintResult(outFile, schema, version)
+ globalStatus.PrintResult(outFile, schema, version, start)
+
if globalStatus.HasErrors() {
return errors.Errorf("gemini encountered errors, exiting with non zero status")
}
+
return nil
}
-func createFile(fname string, def *os.File) (*os.File, error) {
- if fname != "" {
- f, err := os.Create(fname)
+func createFile(name string, def *os.File) (*os.File, error) {
+ if name != "" {
+ f, err := os.Create(name)
if err != nil {
- return nil, errors.Wrapf(err, "Unable to open output file %s", fname)
+ return nil, errors.Wrapf(err, "Unable to open output file %s", name)
}
+
return f, nil
}
- return def, nil
-}
-const (
- stdDistMean = math.MaxUint64 / 2
- oneStdDev = 0.341 * math.MaxUint64
-)
-
-func createDistributionFunc(distribution string, size, seed uint64, mu, sigma float64) (generators.DistributionFunc, error) {
- switch strings.ToLower(distribution) {
- case "zipf":
- dist := rand.NewZipf(rand.New(rand.NewSource(seed)), 1.1, 1.1, size)
- return func() generators.TokenIndex {
- return generators.TokenIndex(dist.Uint64())
- }, nil
- case "normal":
- dist := distuv.Normal{
- Src: rand.NewSource(seed),
- Mu: mu,
- Sigma: sigma,
- }
- return func() generators.TokenIndex {
- return generators.TokenIndex(dist.Rand())
- }, nil
- case "uniform":
- rnd := rand.New(rand.NewSource(seed))
- return func() generators.TokenIndex {
- return generators.TokenIndex(rnd.Uint64n(size))
- }, nil
- default:
- return nil, errors.Errorf("unsupported distribution: %s", distribution)
- }
+ return def, nil
}
func createLogger(level string) *zap.Logger {
@@ -407,35 +298,8 @@ func createClusters(
return testCluster, oracleCluster
}
-func getReplicationStrategy(rs string, fallback *replication.Replication, logger *zap.Logger) *replication.Replication {
- switch rs {
- case "network":
- return replication.NewNetworkTopologyStrategy()
- case "simple":
- return replication.NewSimpleStrategy()
- default:
- replicationStrategy := &replication.Replication{}
- if err := json.Unmarshal([]byte(strings.ReplaceAll(rs, "'", "\"")), replicationStrategy); err != nil {
- logger.Error("unable to parse replication strategy", zap.String("strategy", rs), zap.Error(err))
- return fallback
- }
- return replicationStrategy
- }
-}
-
-func getCQLFeature(feature string) typedef.CQLFeature {
- switch strings.ToLower(feature) {
- case "all":
- return typedef.CQL_FEATURE_ALL
- case "normal":
- return typedef.CQL_FEATURE_NORMAL
- default:
- return typedef.CQL_FEATURE_BASIC
- }
-}
-
func getHostSelectionPolicy(policy string, hosts []string) (gocql.HostSelectionPolicy, error) {
- switch policy {
+ switch strings.ToLower(policy) {
case "round-robin":
return gocql.RoundRobinHostPolicy(), nil
case "host-pool":
@@ -454,89 +318,7 @@ var rootCmd = &cobra.Command{
SilenceUsage: true,
}
-func init() {
- rootCmd.Version = version + ", commit " + commit + ", date " + date
- rootCmd.Flags().StringSliceVarP(&testClusterHost, "test-cluster", "t", []string{}, "Host names or IPs of the test cluster that is system under test")
- _ = rootCmd.MarkFlagRequired("test-cluster")
- rootCmd.Flags().StringVarP(&testClusterUsername, "test-username", "", "", "Username for the test cluster")
- rootCmd.Flags().StringVarP(&testClusterPassword, "test-password", "", "", "Password for the test cluster")
- rootCmd.Flags().StringSliceVarP(
- &oracleClusterHost, "oracle-cluster", "o", []string{},
- "Host names or IPs of the oracle cluster that provides correct answers. If omitted no oracle will be used")
- rootCmd.Flags().StringVarP(&oracleClusterUsername, "oracle-username", "", "", "Username for the oracle cluster")
- rootCmd.Flags().StringVarP(&oracleClusterPassword, "oracle-password", "", "", "Password for the oracle cluster")
- rootCmd.Flags().StringVarP(&schemaFile, "schema", "", "", "Schema JSON config file")
- rootCmd.Flags().StringVarP(&mode, "mode", "m", jobs.MixedMode, "Query operation mode. Mode options: write, read, mixed (default)")
- rootCmd.Flags().Uint64VarP(&concurrency, "concurrency", "c", 10, "Number of threads per table to run concurrently")
- rootCmd.Flags().StringVarP(&seed, "seed", "s", "random", "Statement seed value")
- rootCmd.Flags().StringVarP(&schemaSeed, "schema-seed", "", "random", "Schema seed value")
- rootCmd.Flags().BoolVarP(&dropSchema, "drop-schema", "d", false, "Drop schema before starting tests run")
- rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output during test run")
- rootCmd.Flags().BoolVarP(&failFast, "fail-fast", "f", false, "Stop on the first failure")
- rootCmd.Flags().BoolVarP(&nonInteractive, "non-interactive", "", false, "Run in non-interactive mode (disable progress indicator)")
- rootCmd.Flags().DurationVarP(&duration, "duration", "", 30*time.Second, "")
- rootCmd.Flags().StringVarP(&outFileArg, "outfile", "", "", "Specify the name of the file where the results should go")
- rootCmd.Flags().StringVarP(&bind, "bind", "b", ":2112", "Specify the interface and port which to bind prometheus metrics on. Default is ':2112'")
- rootCmd.Flags().DurationVarP(&warmup, "warmup", "", 30*time.Second, "Specify the warmup perid as a duration for example 30s or 10h")
- rootCmd.Flags().StringVarP(
- &replicationStrategy, "replication-strategy", "", "simple",
- "Specify the desired replication strategy as either the coded short hand simple|network to get the default for each type or provide "+
- "the entire specification in the form {'class':'....'}")
- rootCmd.Flags().StringVarP(
- &oracleReplicationStrategy, "oracle-replication-strategy", "", "simple",
- "Specify the desired replication strategy of the oracle cluster as either the coded short hand simple|network to get the default for each "+
- "type or provide the entire specification in the form {'class':'....'}")
- rootCmd.Flags().StringArrayVarP(&tableOptions, "table-options", "", []string{}, "Repeatable argument to set table options to be added to the created tables")
- rootCmd.Flags().StringVarP(&consistency, "consistency", "", "LOCAL_QUORUM", "Specify the desired consistency as ANY|ONE|TWO|THREE|QUORUM|LOCAL_QUORUM|EACH_QUORUM|LOCAL_ONE")
- rootCmd.Flags().IntVarP(&maxTables, "max-tables", "", 1, "Maximum number of generated tables")
- rootCmd.Flags().IntVarP(&maxPartitionKeys, "max-partition-keys", "", 6, "Maximum number of generated partition keys")
- rootCmd.Flags().IntVarP(&minPartitionKeys, "min-partition-keys", "", 2, "Minimum number of generated partition keys")
- rootCmd.Flags().IntVarP(&maxClusteringKeys, "max-clustering-keys", "", 4, "Maximum number of generated clustering keys")
- rootCmd.Flags().IntVarP(&minClusteringKeys, "min-clustering-keys", "", 2, "Minimum number of generated clustering keys")
- rootCmd.Flags().IntVarP(&maxColumns, "max-columns", "", 16, "Maximum number of generated columns")
- rootCmd.Flags().IntVarP(&minColumns, "min-columns", "", 8, "Minimum number of generated columns")
- rootCmd.Flags().StringVarP(&datasetSize, "dataset-size", "", "large", "Specify the type of dataset size to use, small|large")
- rootCmd.Flags().StringVarP(&cqlFeatures, "cql-features", "", "basic", "Specify the type of cql features to use, basic|normal|all")
- rootCmd.Flags().BoolVarP(&useMaterializedViews, "materialized-views", "", false, "Run gemini with materialized views support")
- rootCmd.Flags().StringVarP(&level, "level", "", "info", "Specify the logging level, debug|info|warn|error|dpanic|panic|fatal")
- rootCmd.Flags().IntVarP(&maxRetriesMutate, "max-mutation-retries", "", 2, "Maximum number of attempts to apply a mutation")
- rootCmd.Flags().DurationVarP(
- &maxRetriesMutateSleep, "max-mutation-retries-backoff", "", 10*time.Millisecond,
- "Duration between attempts to apply a mutation for example 10ms or 1s")
- rootCmd.Flags().Uint64VarP(&pkBufferReuseSize, "partition-key-buffer-reuse-size", "", 100, "Number of reused buffered partition keys")
- rootCmd.Flags().Uint64VarP(&partitionCount, "token-range-slices", "", 10000, "Number of slices to divide the token space into")
- rootCmd.Flags().StringVarP(
- &partitionKeyDistribution, "partition-key-distribution", "", "uniform",
- "Specify the distribution from which to draw partition keys, supported values are currently uniform|normal|zipf")
- rootCmd.Flags().Float64VarP(&normalDistMean, "normal-dist-mean", "", stdDistMean, "Mean of the normal distribution")
- rootCmd.Flags().Float64VarP(&normalDistSigma, "normal-dist-sigma", "", oneStdDev, "Sigma of the normal distribution, defaults to one standard deviation ~0.341")
- rootCmd.Flags().StringVarP(
- &tracingOutFile, "tracing-outfile", "", "",
- "Specify the file to which tracing information gets written. Two magic names are available, 'stdout' and 'stderr'. By default tracing is disabled.")
- rootCmd.Flags().BoolVarP(&useCounters, "use-counters", "", false, "Ensure that at least one table is a counter table")
- rootCmd.Flags().IntVarP(
- &asyncObjectStabilizationAttempts, "async-objects-stabilization-attempts", "", 10,
- "Maximum number of attempts to validate result sets from MV and SI")
- rootCmd.Flags().DurationVarP(
- &asyncObjectStabilizationDelay, "async-objects-stabilization-backoff", "", 10*time.Millisecond,
- "Duration between attempts to validate result sets from MV and SI for example 10ms or 1s")
- rootCmd.Flags().BoolVarP(&useLWT, "use-lwt", "", false, "Emit LWT based updates")
- rootCmd.Flags().StringVarP(
- &oracleClusterHostSelectionPolicy, "oracle-host-selection-policy", "", "token-aware",
- "Host selection policy used by the driver for the oracle cluster: round-robin|host-pool|token-aware")
- rootCmd.Flags().StringVarP(
- &testClusterHostSelectionPolicy, "test-host-selection-policy", "", "token-aware",
- "Host selection policy used by the driver for the test cluster: round-robin|host-pool|token-aware")
- rootCmd.Flags().BoolVarP(&useServerSideTimestamps, "use-server-timestamps", "", false, "Use server-side generated timestamps for writes")
- rootCmd.Flags().DurationVarP(&requestTimeout, "request-timeout", "", 30*time.Second, "Duration of waiting request execution")
- rootCmd.Flags().DurationVarP(&connectTimeout, "connect-timeout", "", 30*time.Second, "Duration of waiting connection established")
- rootCmd.Flags().IntVarP(&profilingPort, "profiling-port", "", 0, "If non-zero starts pprof profiler on given port at 'http://0.0.0.0:/profile'")
- rootCmd.Flags().IntVarP(&maxErrorsToStore, "max-errors-to-store", "", 1000, "Maximum number of errors to store and output at the end")
- rootCmd.Flags().StringVarP(&testStatementLogFile, "test-statement-log-file", "", "", "File to write statements flow to")
- rootCmd.Flags().StringVarP(&oracleStatementLogFile, "oracle-statement-log-file", "", "", "File to write statements flow to")
-}
-
-func printSetup(seed, schemaSeed uint64) {
+func printSetup(seed, schemaSeed uint64, mode string) {
tw := new(tabwriter.Writer)
tw.Init(os.Stdout, 0, 8, 2, '\t', tabwriter.AlignRight)
fmt.Fprintf(tw, "Seed:\t%d\n", seed)
@@ -546,6 +328,7 @@ func printSetup(seed, schemaSeed uint64) {
fmt.Fprintf(tw, "Concurrency:\t%d\n", concurrency)
fmt.Fprintf(tw, "Test cluster:\t%s\n", testClusterHost)
fmt.Fprintf(tw, "Oracle cluster:\t%s\n", oracleClusterHost)
+ fmt.Fprintf(tw, "Mode:\t%s\n", mode)
if outFileArg == "" {
fmt.Fprintf(tw, "Output file:\t%s\n", "")
} else {
diff --git a/cmd/gemini/schema.go b/cmd/gemini/schema.go
index 68884a62..1f8a5e42 100644
--- a/cmd/gemini/schema.go
+++ b/cmd/gemini/schema.go
@@ -15,13 +15,25 @@
package main
import (
+ "encoding/json"
+ "os"
"strings"
+ "go.uber.org/zap"
+
+ "github.com/scylladb/gemini/pkg/builders"
"github.com/scylladb/gemini/pkg/replication"
"github.com/scylladb/gemini/pkg/tableopts"
"github.com/scylladb/gemini/pkg/typedef"
+)
- "go.uber.org/zap"
+const (
+ MaxBlobLength = 1e4
+ MinBlobLength = 0
+ MaxStringLength = 1000
+ MinStringLength = 0
+ MaxTupleParts = 20
+ MaxUDTParts = 20
)
func createSchemaConfig(logger *zap.Logger) typedef.SchemaConfig {
@@ -43,11 +55,13 @@ func createSchemaConfig(logger *zap.Logger) typedef.SchemaConfig {
MaxTupleParts: 2,
MaxBlobLength: 20,
MaxStringLength: 20,
+ MinBlobLength: 0,
+ MinStringLength: 0,
UseCounters: defaultConfig.UseCounters,
UseLWT: defaultConfig.UseLWT,
+ UseMaterializedViews: defaultConfig.UseMaterializedViews,
CQLFeature: defaultConfig.CQLFeature,
AsyncObjectStabilizationAttempts: defaultConfig.AsyncObjectStabilizationAttempts,
- UseMaterializedViews: defaultConfig.UseMaterializedViews,
AsyncObjectStabilizationDelay: defaultConfig.AsyncObjectStabilizationDelay,
}
default:
@@ -56,19 +70,9 @@ func createSchemaConfig(logger *zap.Logger) typedef.SchemaConfig {
}
func createDefaultSchemaConfig(logger *zap.Logger) typedef.SchemaConfig {
- const (
- MaxBlobLength = 1e4
- MinBlobLength = 0
- MaxStringLength = 1000
- MinStringLength = 0
- MaxTupleParts = 20
- MaxUDTParts = 20
- )
- rs := getReplicationStrategy(replicationStrategy, replication.NewSimpleStrategy(), logger)
- ors := getReplicationStrategy(oracleReplicationStrategy, rs, logger)
return typedef.SchemaConfig{
- ReplicationStrategy: rs,
- OracleReplicationStrategy: ors,
+ ReplicationStrategy: replication.MustParseReplication(replicationStrategy),
+ OracleReplicationStrategy: replication.MustParseReplication(oracleReplicationStrategy),
TableOptions: tableopts.CreateTableOptions(tableOptions, logger),
MaxTables: maxTables,
MaxPartitionKeys: maxPartitionKeys,
@@ -85,9 +89,31 @@ func createDefaultSchemaConfig(logger *zap.Logger) typedef.SchemaConfig {
MinStringLength: MinStringLength,
UseCounters: useCounters,
UseLWT: useLWT,
- CQLFeature: getCQLFeature(cqlFeatures),
+ CQLFeature: typedef.ParseCQLFeature(cqlFeatures),
UseMaterializedViews: useMaterializedViews,
AsyncObjectStabilizationAttempts: asyncObjectStabilizationAttempts,
AsyncObjectStabilizationDelay: asyncObjectStabilizationDelay,
}
}
+
+func readSchema(confFile string, schemaConfig typedef.SchemaConfig) (*typedef.Schema, error) {
+ byteValue, err := os.ReadFile(confFile)
+ if err != nil {
+ return nil, err
+ }
+
+ var shm typedef.Schema
+
+ err = json.Unmarshal(byteValue, &shm)
+ if err != nil {
+ return nil, err
+ }
+
+ schemaBuilder := builders.NewSchemaBuilder()
+ schemaBuilder.Keyspace(shm.Keyspace).Config(schemaConfig)
+ for t, tbl := range shm.Tables {
+ shm.Tables[t].LinkIndexAndColumns()
+ schemaBuilder.Table(tbl)
+ }
+ return schemaBuilder.Build(), nil
+}
diff --git a/cmd/gemini/strategies_test.go b/cmd/gemini/strategies_test.go
index 6859024e..29d0e7e8 100644
--- a/cmd/gemini/strategies_test.go
+++ b/cmd/gemini/strategies_test.go
@@ -20,46 +20,9 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "go.uber.org/zap"
-
- "github.com/scylladb/gemini/pkg/replication"
"github.com/scylladb/gemini/pkg/typedef"
)
-func TestGetReplicationStrategy(t *testing.T) {
- tests := map[string]struct {
- strategy string
- expected string
- }{
- "simple strategy": {
- strategy: "{\"class\": \"SimpleStrategy\", \"replication_factor\": \"1\"}",
- expected: "{'class':'SimpleStrategy','replication_factor':'1'}",
- },
- "simple strategy single quotes": {
- strategy: "{'class': 'SimpleStrategy', 'replication_factor': '1'}",
- expected: "{'class':'SimpleStrategy','replication_factor':'1'}",
- },
- "network topology strategy": {
- strategy: "{\"class\": \"NetworkTopologyStrategy\", \"dc1\": 3, \"dc2\": 3}",
- expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}",
- },
- "network topology strategy single quotes": {
- strategy: "{'class': 'NetworkTopologyStrategy', 'dc1': 3, 'dc2': 3}",
- expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}",
- },
- }
- logger := zap.NewNop()
- fallback := replication.NewSimpleStrategy()
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- got := getReplicationStrategy(tc.strategy, fallback, logger)
- if diff := cmp.Diff(got.ToCQL(), tc.expected); diff != "" {
- t.Errorf("expected=%s, got=%s,diff=%s", tc.strategy, got.ToCQL(), diff)
- }
- })
- }
-}
-
// TestReadExampleSchema mainly verifies that the example schema (schema.json) is correct and marshals and unmarshals properly
func TestReadExampleSchema(t *testing.T) {
filePath := "schema.json"
diff --git a/docker/docker-compose-scylla.yml b/docker/docker-compose-scylla.yml
index d9e266f9..521c3161 100644
--- a/docker/docker-compose-scylla.yml
+++ b/docker/docker-compose-scylla.yml
@@ -21,7 +21,7 @@ services:
image: scylladb/scylla:${SCYLLA_TEST_VERSION:-6.2}
container_name: gemini-test
restart: unless-stopped
- command: --smp 2 --memory 1024M --api-address 0.0.0.0
+ command: --smp 1 --memory 1024M --api-address 0.0.0.0
networks:
gemini:
ipv4_address: 192.168.100.3
diff --git a/go.mod b/go.mod
index 4b5efcad..c3499609 100644
--- a/go.mod
+++ b/go.mod
@@ -17,7 +17,6 @@ require (
go.uber.org/zap v1.27.0
golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f
golang.org/x/net v0.31.0
- golang.org/x/sync v0.9.0
gonum.org/v1/gonum v0.15.1
gopkg.in/inf.v0 v0.9.1
)
@@ -42,4 +41,4 @@ require (
google.golang.org/protobuf v1.35.2 // indirect
)
-replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.3
+replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.4
diff --git a/go.sum b/go.sum
index 8a28967a..b15a1666 100644
--- a/go.sum
+++ b/go.sum
@@ -65,8 +65,8 @@ github.com/scylladb/go-reflectx v1.0.1 h1:b917wZM7189pZdlND9PbIJ6NQxfDPfBvUaQ7cj
github.com/scylladb/go-reflectx v1.0.1/go.mod h1:rWnOfDIRWBGN0miMLIcoPt/Dhi2doCMZqwMCJ3KupFc=
github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE=
github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs=
-github.com/scylladb/gocql v1.14.3 h1:f6ZFxM9plyAk0h7NZcXfZ1aJu3cGk0Mjy/X293gqIFA=
-github.com/scylladb/gocql v1.14.3/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0=
+github.com/scylladb/gocql v1.14.4 h1:MhevwCfyAraQ6RvZYFO3pF4Lt0YhvQlfg8Eo2HEqVQA=
+github.com/scylladb/gocql v1.14.4/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0=
github.com/scylladb/gocqlx/v2 v2.8.0 h1:f/oIgoEPjKDKd+RIoeHqexsIQVIbalVmT+axwvUqQUg=
github.com/scylladb/gocqlx/v2 v2.8.0/go.mod h1:4/+cga34PVqjhgSoo5Nr2fX1MQIqZB5eCE5DK4xeDig=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
diff --git a/pkg/jobs/pump.go b/pkg/burst/pump.go
similarity index 52%
rename from pkg/jobs/pump.go
rename to pkg/burst/pump.go
index c929f8ce..9cfbc61b 100644
--- a/pkg/jobs/pump.go
+++ b/pkg/burst/pump.go
@@ -12,40 +12,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package jobs
+package burst
import (
+ "context"
+ "math/rand/v2"
"time"
+)
- "github.com/scylladb/gemini/pkg/stop"
+const ChannelSize = 10000
- "go.uber.org/zap"
- "golang.org/x/exp/rand"
-)
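+
+// work feeds the pump with per-operation sleep hints: most iterations send
+// zero, and roughly one in every `chance` iterations sends sleepDuration,
+// creating occasional short stalls. The channel is closed when ctx ends.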
+func work(ctx context.Context, pump chan<- time.Duration, chance int, sleepDuration time.Duration) {
+ defer close(pump)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ sleep := time.Duration(0)
-func NewPump(stopFlag *stop.Flag, logger *zap.Logger) chan time.Duration {
- pump := make(chan time.Duration, 10000)
- logger = logger.Named("Pump")
- go func() {
- logger.Debug("pump channel opened")
- defer func() {
- close(pump)
- logger.Debug("pump channel closed")
- }()
- for !stopFlag.IsHardOrSoft() {
- pump <- newHeartBeat()
- }
- }()
+ if rand.Int()%chance == 0 {
+ sleep = sleepDuration
+ }
- return pump
+ pump <- sleep
+ }
+ }
}
-func newHeartBeat() time.Duration {
- r := rand.Intn(10)
- switch r {
- case 0:
- return 10 * time.Millisecond
- default:
- return 0
- }
+func New(ctx context.Context, chance int, sleepDuration time.Duration) chan time.Duration {
+ pump := make(chan time.Duration, ChannelSize)
+ go work(ctx, pump, chance, sleepDuration)
+ return pump
}
diff --git a/pkg/generators/distribution.go b/pkg/generators/distribution.go
new file mode 100644
index 00000000..0d9c8d49
--- /dev/null
+++ b/pkg/generators/distribution.go
@@ -0,0 +1,59 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package generators
+
+import (
+ "math"
+ "strings"
+
+ "github.com/pkg/errors"
+ "golang.org/x/exp/rand"
+ "gonum.org/v1/gonum/stat/distuv"
+)
+
+const (
+ StdDistMean = math.MaxUint64 / 2
+ OneStdDev = 0.341 * math.MaxUint64
+)
+
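+// ParseDistributionDefault builds a distribution over the token space using
+// the package defaults: a mean centred in the uint64 range and a spread of
+// one standard deviation (~0.341 of the range).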
+func ParseDistributionDefault(distribution string, size, seed uint64) (DistributionFunc, error) {
+ return ParseDistribution(distribution, size, seed, StdDistMean, OneStdDev)
+}
+
+func ParseDistribution(distribution string, size, seed uint64, mu, sigma float64) (DistributionFunc, error) {
+ switch strings.ToLower(distribution) {
+ case "zipf":
+ dist := rand.NewZipf(rand.New(rand.NewSource(seed)), 1.1, 1.1, size)
+ return func() TokenIndex {
+ return TokenIndex(dist.Uint64())
+ }, nil
+ case "normal":
+ dist := distuv.Normal{
+ Src: rand.NewSource(seed),
+ Mu: mu,
+ Sigma: sigma,
+ }
+ return func() TokenIndex {
+ return TokenIndex(dist.Rand())
+ }, nil
+ case "uniform":
+ rnd := rand.New(rand.NewSource(seed))
+ return func() TokenIndex {
+ return TokenIndex(rnd.Uint64n(size))
+ }, nil
+ default:
+ return nil, errors.Errorf("unsupported distribution: %s", distribution)
+ }
+}
diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go
index 500fc759..f3c1e8f5 100644
--- a/pkg/generators/generator.go
+++ b/pkg/generators/generator.go
@@ -15,12 +15,14 @@
package generators
import (
+ "context"
+ "sync/atomic"
+
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/exp/rand"
"github.com/scylladb/gemini/pkg/routingkey"
- "github.com/scylladb/gemini/pkg/stop"
"github.com/scylladb/gemini/pkg/typedef"
)
@@ -37,14 +39,15 @@ type TokenIndex uint64
type DistributionFunc func() TokenIndex
-type GeneratorInterface interface {
+type Interface interface {
Get() *typedef.ValueWithToken
GetOld() *typedef.ValueWithToken
- GiveOld(_ *typedef.ValueWithToken)
- GiveOlds(_ []*typedef.ValueWithToken)
- ReleaseToken(_ uint64)
+ GiveOld(...*typedef.ValueWithToken)
+ ReleaseToken(uint64)
}
+var _ Interface = &Generator{}
+
type Generator struct {
logger *zap.Logger
table *typedef.Table
@@ -56,22 +59,14 @@ type Generator struct {
partitionsConfig typedef.PartitionRangeConfig
partitionCount uint64
- cntCreated uint64
- cntEmitted uint64
+ cntCreated atomic.Uint64
+ cntEmitted atomic.Uint64
}
func (g *Generator) PartitionCount() uint64 {
return g.partitionCount
}
-type Generators []*Generator
-
-func (g Generators) StartAll(stopFlag *stop.Flag) {
- for _, gen := range g {
- gen.Start(stopFlag)
- }
-}
-
type Config struct {
PartitionsDistributionFunc DistributionFunc
PartitionsRangeConfig typedef.PartitionRangeConfig
@@ -80,8 +75,9 @@ type Config struct {
PkUsedBufferSize uint64
}
-func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator {
- wakeUpSignal := make(chan struct{})
+func NewGenerator(table *typedef.Table, config Config, logger *zap.Logger) *Generator {
+ wakeUpSignal := make(chan struct{}, int(config.PartitionsCount))
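+ // Buffer one wake-up slot per partition so a partition that runs dry can
+ // always signal the generator without blocking the request path.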
+
return &Generator{
partitions: NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal),
partitionCount: config.PartitionsCount,
@@ -96,12 +92,18 @@ func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Gen
}
func (g *Generator) Get() *typedef.ValueWithToken {
- targetPart := g.GetPartitionForToken(g.idxFunc())
- for targetPart.Stale() {
- targetPart = g.GetPartitionForToken(g.idxFunc())
+ var out *typedef.ValueWithToken
+
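+ // Retry until a usable value appears: a partition may be stale (skip it)
+ // or momentarily empty, in which case get() returns nil and a fresh
+ // token is drawn.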
+ for out == nil {
+ targetPart := g.GetPartitionForToken(g.idxFunc())
+ for targetPart.Stale() {
+ targetPart = g.GetPartitionForToken(g.idxFunc())
+ }
+ out = targetPart.get()
}
- out := targetPart.get()
+
return out
}
func (g *Generator) GetPartitionForToken(token TokenIndex) *Partition {
@@ -111,22 +113,23 @@ func (g *Generator) GetPartitionForToken(token TokenIndex) *Partition {
// GetOld returns a previously used value and token or a new if
// the old queue is empty.
func (g *Generator) GetOld() *typedef.ValueWithToken {
- targetPart := g.GetPartitionForToken(g.idxFunc())
- for targetPart.Stale() {
- targetPart = g.GetPartitionForToken(g.idxFunc())
+ var out *typedef.ValueWithToken
+
+ for out == nil {
+ targetPart := g.GetPartitionForToken(g.idxFunc())
+ for targetPart.Stale() {
+ targetPart = g.GetPartitionForToken(g.idxFunc())
+ }
+ out = targetPart.getOld()
}
- return targetPart.getOld()
-}
-// GiveOld returns the supplied value for later reuse unless
-func (g *Generator) GiveOld(v *typedef.ValueWithToken) {
- g.GetPartitionForToken(TokenIndex(v.Token)).giveOld(v)
+ return out
}
-// GiveOlds returns the supplied values for later reuse unless
-func (g *Generator) GiveOlds(tokens []*typedef.ValueWithToken) {
+// GiveOld returns the supplied value for later reuse
+func (g *Generator) GiveOld(tokens ...*typedef.ValueWithToken) {
for _, token := range tokens {
- g.GiveOld(token)
+ g.GetPartitionForToken(TokenIndex(token.Token)).giveOld(token)
}
}
@@ -135,22 +138,27 @@ func (g *Generator) ReleaseToken(token uint64) {
g.GetPartitionForToken(TokenIndex(token)).releaseToken(token)
}
-func (g *Generator) Start(stopFlag *stop.Flag) {
- go func() {
- g.logger.Info("starting partition key generation loop")
- defer g.partitions.CloseAll()
- for {
- g.fillAllPartitions(stopFlag)
- select {
- case <-stopFlag.SignalChannel():
- g.logger.Debug("stopping partition key generation loop",
- zap.Uint64("keys_created", g.cntCreated),
- zap.Uint64("keys_emitted", g.cntEmitted))
- return
- case <-g.wakeUpSignal:
- }
+func (g *Generator) start(ctx context.Context) {
+ g.logger.Info("starting partition key generation loop")
+ defer func() {
+ g.logger.Info("stopping partition key generation loop",
+ zap.Uint64("keys_created", g.cntCreated.Load()),
+ zap.Uint64("keys_emitted", g.cntEmitted.Load()),
+ )
+
+ if err := g.partitions.Close(); err != nil {
+ g.logger.Error("failed to close partitions", zap.Error(err))
}
}()
+
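+ // Sleep until a partition signals that it ran dry, then top up all
+ // partitions; context cancellation ends the loop.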
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-g.wakeUpSignal:
+ g.fillAllPartitions(ctx)
+ }
+ }
}
func (g *Generator) FindAndMarkStalePartitions() {
@@ -172,45 +180,41 @@ func (g *Generator) FindAndMarkStalePartitions() {
}
}
+// fillPartition creates one candidate partition key, hashes it into a token,
+// and offers the value to the owning partition's buffer. Values for stale
+// partitions, in-flight tokens, or full buffers are dropped.
+func (g *Generator) fillPartition() {
+
+ values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig)
+ token, err := g.routingKeyCreator.GetHash(g.table, values)
+ if err != nil {
+ g.logger.Panic("failed to get primary key hash", zap.Error(err))
+ }
+ g.cntCreated.Add(1)
+ idx := token % g.partitionCount
+ partition := g.partitions[idx]
+ if partition.Stale() || partition.inFlight.Has(token) {
+ return
+ }
+ select {
+ case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}:
+ g.cntEmitted.Add(1)
+ default:
+ }
+}
+
// fillAllPartitions guarantees that each partition was tested to be full
// at least once since the function started and before it ended.
// In other words no partition will be starved.
-func (g *Generator) fillAllPartitions(stopFlag *stop.Flag) {
- pFilled := make([]bool, len(g.partitions))
- allFilled := func() bool {
- for idx, filled := range pFilled {
- if !filled {
- if g.partitions[idx].Stale() {
- continue
- }
- return false
- }
- }
- return true
- }
- for !stopFlag.IsHardOrSoft() {
- values := CreatePartitionKeyValues(g.table, g.r, &g.partitionsConfig)
- token, err := g.routingKeyCreator.GetHash(g.table, values)
- if err != nil {
- g.logger.Panic(errors.Wrap(err, "failed to get primary key hash").Error())
- }
- g.cntCreated++
- idx := token % g.partitionCount
- partition := g.partitions[idx]
- if partition.Stale() || partition.inFlight.Has(token) {
- continue
- }
+func (g *Generator) fillAllPartitions(ctx context.Context) {
+ for {
select {
- case partition.values <- &typedef.ValueWithToken{Token: token, Value: values}:
- g.cntEmitted++
+ case <-ctx.Done():
+ return
default:
- if !pFilled[idx] {
- pFilled[idx] = true
- if allFilled() {
- return
- }
- }
}
+
+ g.fillPartition()
}
}
diff --git a/pkg/generators/generator_test.go b/pkg/generators/generator_test.go
index 3a46551c..26240ae9 100644
--- a/pkg/generators/generator_test.go
+++ b/pkg/generators/generator_test.go
@@ -15,24 +15,28 @@
package generators_test
import (
+ "context"
"sync/atomic"
"testing"
"go.uber.org/zap"
"github.com/scylladb/gemini/pkg/generators"
- "github.com/scylladb/gemini/pkg/stop"
"github.com/scylladb/gemini/pkg/typedef"
)
func TestGenerator(t *testing.T) {
t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ t.Cleanup(cancel)
+
table := &typedef.Table{
Name: "tbl",
PartitionKeys: generators.CreatePkColumns(1, "pk"),
}
var current uint64
- cfg := &generators.Config{
+ cfg := generators.Config{
PartitionsRangeConfig: typedef.PartitionRangeConfig{
MaxStringLength: 10,
MinStringLength: 0,
@@ -47,7 +51,7 @@ func TestGenerator(t *testing.T) {
}
logger, _ := zap.NewDevelopment()
generator := generators.NewGenerator(table, cfg, logger)
- generator.Start(stop.NewFlag("main_test"))
+ go generator.start(ctx)
for i := uint64(0); i < cfg.PartitionsCount; i++ {
atomic.StoreUint64(&current, i)
v := generator.Get()
diff --git a/cmd/gemini/generators.go b/pkg/generators/generators.go
similarity index 56%
rename from cmd/gemini/generators.go
rename to pkg/generators/generators.go
index 9b4f51de..f6e6323a 100644
--- a/cmd/gemini/generators.go
+++ b/pkg/generators/generators.go
@@ -11,47 +11,79 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+
+package generators
import (
- "github.com/scylladb/gemini/pkg/generators"
- "github.com/scylladb/gemini/pkg/typedef"
+ "context"
+ "sync"
"go.uber.org/zap"
+
+ "github.com/scylladb/gemini/pkg/typedef"
)
-func createGenerators(
+type Generators struct {
+ wg sync.WaitGroup
+ generators []*Generator
+ cancel context.CancelFunc
+ idx int
+}
+
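+// Get hands out generators in round-robin order. idx is not synchronized,
+// so Get is expected to be called from a single goroutine.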
+func (g *Generators) Get() *Generator {
+ gen := g.generators[g.idx%len(g.generators)]
+ g.idx++
+ return gen
+}
+
+func (g *Generators) Close() error {
+ g.cancel()
+ g.wg.Wait()
+
+ return nil
+}
+
+func New(
+ ctx context.Context,
schema *typedef.Schema,
schemaConfig typedef.SchemaConfig,
- seed, distributionSize uint64,
+ seed, partitionsCount uint64,
logger *zap.Logger,
-) (generators.Generators, error) {
+ distFunc DistributionFunc,
+ pkBufferReuseSize uint64,
+) (*Generators, error) {
partitionRangeConfig := schemaConfig.GetPartitionRangeConfig()
+ ctx, cancel := context.WithCancel(ctx)
- var gs []*generators.Generator
- for id := range schema.Tables {
- table := schema.Tables[id]
- pkVariations := table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig)
+ gens := &Generators{
+ generators: make([]*Generator, 0, len(schema.Tables)),
+ cancel: cancel,
+ }
+ gens.wg.Add(len(schema.Tables))
- distFunc, err := createDistributionFunc(partitionKeyDistribution, distributionSize, seed, stdDistMean, oneStdDev)
- if err != nil {
- return nil, err
- }
+ for _, table := range schema.Tables {
+ pkVariations := table.PartitionKeys.ValueVariationsNumber(&partitionRangeConfig)
- tablePartConfig := &generators.Config{
+ tablePartConfig := Config{
PartitionsRangeConfig: partitionRangeConfig,
- PartitionsCount: distributionSize,
+ PartitionsCount: partitionsCount,
PartitionsDistributionFunc: distFunc,
Seed: seed,
PkUsedBufferSize: pkBufferReuseSize,
}
- g := generators.NewGenerator(table, tablePartConfig, logger.Named("generators"))
+ g := NewGenerator(table, tablePartConfig, logger.Named("generators"))
+ go func() {
+ defer gens.wg.Done()
+ g.start(ctx)
+ }()
		if pkVariations < 1<<32 {
			// Low partition key variation can lead to stale partitions,
			// so detect and mark them before running the test.
g.FindAndMarkStalePartitions()
}
- gs = append(gs, g)
+
+ gens.generators = append(gens.generators, g)
}
- return gs, nil
+
+ return gens, nil
}
diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go
index e70d46c2..dd6c3817 100644
--- a/pkg/generators/partition.go
+++ b/pkg/generators/partition.go
@@ -15,7 +15,7 @@
package generators
import (
- "sync"
+ "sync/atomic"
"github.com/scylladb/gemini/pkg/inflight"
"github.com/scylladb/gemini/pkg/typedef"
@@ -26,29 +26,36 @@ type Partition struct {
oldValues chan *typedef.ValueWithToken
inFlight inflight.InFlight
wakeUpSignal chan<- struct{} // wakes up generator
- closed bool
- lock sync.RWMutex
- isStale bool
+ closed atomic.Bool
+ isStale atomic.Bool
+}
+
+func NewPartition(wakeUpSignal chan<- struct{}, pkBufferSize int) *Partition {
+ return &Partition{
+ values: make(chan *typedef.ValueWithToken, pkBufferSize),
+ oldValues: make(chan *typedef.ValueWithToken, pkBufferSize),
+ inFlight: inflight.New(),
+ wakeUpSignal: wakeUpSignal,
+ }
}
func (s *Partition) MarkStale() {
- s.isStale = true
+ s.isStale.Store(true)
s.Close()
}
func (s *Partition) Stale() bool {
- return s.isStale
+ return s.isStale.Load()
}
// get returns a new value and ensures that its corresponding token
// is not already in-flight.
func (s *Partition) get() *typedef.ValueWithToken {
- for {
- v := s.pick()
- if v == nil || s.inFlight.AddIfNotPresent(v.Token) {
- return v
- }
+ if v := s.pick(); v != nil && !s.inFlight.Has(v.Token) {
+ return v
}
+
+ return nil
}
// getOld returns a previously used value and token or a new if
@@ -66,12 +73,12 @@ func (s *Partition) getOld() *typedef.ValueWithToken {
// is empty in which case it removes the corresponding token from the
// in-flight tracking.
func (s *Partition) giveOld(v *typedef.ValueWithToken) {
- ch := s.safelyGetOldValuesChannel()
- if ch == nil {
+ if s.closed.Load() {
return
}
+
select {
- case ch <- v:
+ case s.oldValues <- v:
default:
// Old partition buffer is full, just drop the value
}
@@ -98,55 +105,15 @@ func (s *Partition) pick() *typedef.ValueWithToken {
return val
default:
s.wakeUp() // channel empty, need to wait for new values
- return <-s.values
- }
-}
-
-func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken {
- s.lock.RLock()
- if s.closed {
- // Since only giveOld could have been potentially called after partition is closed
- // we need to protect it against writing to closed channel
return nil
}
- defer s.lock.RUnlock()
- return s.oldValues
}
-func (s *Partition) Close() {
- s.lock.RLock()
- if s.closed {
- s.lock.RUnlock()
- return
- }
- s.lock.RUnlock()
- s.lock.Lock()
- if s.closed {
- return
- }
- s.closed = true
- close(s.values)
- close(s.oldValues)
- s.lock.Unlock()
-}
-
-type Partitions []*Partition
-
-func (p Partitions) CloseAll() {
- for _, part := range p {
- part.Close()
+func (s *Partition) Close() error {
+ if !s.closed.Swap(true) {
+ close(s.values)
+ close(s.oldValues)
}
-}
-func NewPartitions(count, pkBufferSize int, wakeUpSignal chan struct{}) Partitions {
- partitions := make(Partitions, count)
- for i := 0; i < len(partitions); i++ {
- partitions[i] = &Partition{
- values: make(chan *typedef.ValueWithToken, pkBufferSize),
- oldValues: make(chan *typedef.ValueWithToken, pkBufferSize),
- inFlight: inflight.New(),
- wakeUpSignal: wakeUpSignal,
- }
- }
- return partitions
+ return nil
}
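
A short sketch of the lock-free close-once idiom partition.go switches to: atomic.Bool.Swap returns the previous value, so exactly one caller observes false and closes the channels. Standard library only; the caveat in the comments is worth noting.

package example

import "sync/atomic"

type buffer struct {
	closed atomic.Bool
	values chan int
}

func (b *buffer) Close() error {
	// Swap flips the flag and reports the old value: only the first
	// caller sees false, so the channel is closed exactly once.
	if !b.closed.Swap(true) {
		close(b.values)
	}
	return nil
}

func (b *buffer) give(v int) {
	if b.closed.Load() {
		return
	}
	// Caveat: between Load and the send below, a concurrent Close can slip
	// in, so this shape (like giveOld above) tolerates a small race window;
	// a mutex or a dedicated done channel would close that gap.
	select {
	case b.values <- v:
	default: // buffer full: drop the value
	}
}
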
diff --git a/pkg/generators/partitions.go b/pkg/generators/partitions.go
new file mode 100644
index 00000000..b0eec362
--- /dev/null
+++ b/pkg/generators/partitions.go
@@ -0,0 +1,25 @@
+package generators
+
+import "go.uber.org/multierr"
+
+type Partitions []*Partition
+
+func (p Partitions) Close() error {
+ var err error
+
+ for _, part := range p {
+ err = multierr.Append(err, part.Close())
+ }
+
+ return err
+}
+
+func NewPartitions(count, pkBufferSize int, wakeUpSignal chan<- struct{}) Partitions {
+ partitions := make(Partitions, 0, count)
+
+ for i := 0; i < count; i++ {
+ partitions = append(partitions, NewPartition(wakeUpSignal, pkBufferSize))
+ }
+
+ return partitions
+}
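
The new Partitions.Close relies on error aggregation rather than failing fast. A minimal sketch with go.uber.org/multierr (the only assumption beyond the standard library): every Close runs, and all failures fold into one returned error.

package example

import (
	"io"

	"go.uber.org/multierr"
)

// closeAll closes every closer and returns the combined error, so one
// failing Close does not prevent the rest from being released.
func closeAll(closers []io.Closer) error {
	var err error
	for _, c := range closers {
		err = multierr.Append(err, c.Close())
	}
	return err
}
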
diff --git a/pkg/generators/statement_generator.go b/pkg/generators/statement_generator.go
index 25503a89..902af751 100644
--- a/pkg/generators/statement_generator.go
+++ b/pkg/generators/statement_generator.go
@@ -87,7 +87,7 @@ func genTable(sc typedef.SchemaConfig, tableName string, r *rand.Rand) *typedef.
table.Indexes = indexes
var mvs []typedef.MaterializedView
- if sc.CQLFeature > typedef.CQL_FEATURE_BASIC && sc.UseMaterializedViews && len(clusteringKeys) > 0 && columns.ValidColumnsForPrimaryKey().Len() != 0 {
+ if sc.UseMaterializedViews && sc.CQLFeature > typedef.CQL_FEATURE_BASIC && len(clusteringKeys) > 0 && columns.ValidColumnsForPrimaryKey().Len() != 0 {
mvs = CreateMaterializedViews(columns, table.Name, partitionKeys, clusteringKeys, r)
}
@@ -101,7 +101,7 @@ func GetCreateKeyspaces(s *typedef.Schema) (string, string) {
fmt.Sprintf("CREATE KEYSPACE IF NOT EXISTS %s WITH REPLICATION = %s", s.Keyspace.Name, s.Keyspace.OracleReplication.ToCQL())
}
-func GetCreateSchema(s *typedef.Schema) []string {
+func GetCreateSchema(s *typedef.Schema, enableMV bool) []string {
var stmts []string
for _, t := range s.Tables {
@@ -124,17 +124,21 @@ func GetCreateSchema(s *typedef.Schema) []string {
for _, ck := range mv.ClusteringKeys {
mvPrimaryKeysNotNull = append(mvPrimaryKeysNotNull, fmt.Sprintf("%s IS NOT NULL", ck.Name))
}
- var createMaterializedView string
- if len(mv.PartitionKeys) == 1 {
- createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY (%s"
- } else {
- createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY ((%s)"
+
+ if enableMV {
+ var createMaterializedView string
+ if len(mv.PartitionKeys) == 1 {
+ createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY (%s"
+ } else {
+ createMaterializedView = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY ((%s)"
+ }
+ createMaterializedView += ",%s)"
+ stmts = append(stmts, fmt.Sprintf(createMaterializedView,
+ s.Keyspace.Name, mv.Name, s.Keyspace.Name, t.Name,
+ strings.Join(mvPrimaryKeysNotNull, " AND "),
+ strings.Join(mvPartitionKeys, ","), strings.Join(t.ClusteringKeys.Names(), ",")),
+ )
}
- createMaterializedView += ",%s)"
- stmts = append(stmts, fmt.Sprintf(createMaterializedView,
- s.Keyspace.Name, mv.Name, s.Keyspace.Name, t.Name,
- strings.Join(mvPrimaryKeysNotNull, " AND "),
- strings.Join(mvPartitionKeys, ","), strings.Join(t.ClusteringKeys.Names(), ",")))
}
}
return stmts
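
The MV statements above are now emitted only when enableMV is set, and the format string differs for single versus composite partition keys. A runnable sketch of that formatting (keyspace, table, and key names are made up):

package main

import (
	"fmt"
	"strings"
)

func main() {
	pks := []string{"pk1", "pk2"}
	cks := []string{"ck1"}

	// Composite partition keys get an extra pair of parentheses.
	format := "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY ((%s)"
	if len(pks) == 1 {
		format = "CREATE MATERIALIZED VIEW IF NOT EXISTS %s.%s AS SELECT * FROM %s.%s WHERE %s PRIMARY KEY (%s"
	}
	format += ",%s)"

	notNull := make([]string, 0, len(pks)+len(cks))
	for _, k := range append(append([]string{}, pks...), cks...) {
		notNull = append(notNull, k+" IS NOT NULL")
	}

	stmt := fmt.Sprintf(format,
		"ks", "tbl_mv", "ks", "tbl",
		strings.Join(notNull, " AND "),
		strings.Join(pks, ","), strings.Join(cks, ","))
	fmt.Println(stmt)
	// Prints: CREATE MATERIALIZED VIEW IF NOT EXISTS ks.tbl_mv AS SELECT *
	// FROM ks.tbl WHERE pk1 IS NOT NULL AND pk2 IS NOT NULL AND ck1 IS NOT
	// NULL PRIMARY KEY ((pk1,pk2),ck1)
}
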
diff --git a/pkg/jobs/gen_check_stmt.go b/pkg/generators/statements/gen_check_stmt.go
similarity index 77%
rename from pkg/jobs/gen_check_stmt.go
rename to pkg/generators/statements/gen_check_stmt.go
index 3822b409..9dea12a3 100644
--- a/pkg/jobs/gen_check_stmt.go
+++ b/pkg/generators/statements/gen_check_stmt.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package jobs
+package statements
import (
"math"
@@ -28,94 +28,105 @@ import (
func GenCheckStmt(
s *typedef.Schema,
table *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
rnd *rand.Rand,
p *typedef.PartitionRangeConfig,
-) *typedef.Stmt {
- n := 0
- mvNum := -1
- maxClusteringRels := 0
- numQueryPKs := 0
- if len(table.MaterializedViews) > 0 && rnd.Int()%2 == 0 {
- mvNum = utils.RandInt2(rnd, 0, len(table.MaterializedViews))
- }
-
- switch mvNum {
- case -1:
- if len(table.Indexes) > 0 {
- n = rnd.Intn(5)
- } else {
- n = rnd.Intn(4)
- }
- switch n {
- case 0:
- return genSinglePartitionQuery(s, table, g)
- case 1:
- numQueryPKs = utils.RandInt2(rnd, 1, table.PartitionKeys.Len())
- multiplier := int(math.Pow(float64(numQueryPKs), float64(table.PartitionKeys.Len())))
- if multiplier > 100 {
- numQueryPKs = 1
- }
- return genMultiplePartitionQuery(s, table, g, numQueryPKs)
- case 2:
- maxClusteringRels = utils.RandInt2(rnd, 0, table.ClusteringKeys.Len())
- return genClusteringRangeQuery(s, table, g, rnd, p, maxClusteringRels)
- case 3:
- numQueryPKs = utils.RandInt2(rnd, 1, table.PartitionKeys.Len())
- multiplier := int(math.Pow(float64(numQueryPKs), float64(table.PartitionKeys.Len())))
- if multiplier > 100 {
- numQueryPKs = 1
- }
- maxClusteringRels = utils.RandInt2(rnd, 0, table.ClusteringKeys.Len())
- return genMultiplePartitionClusteringRangeQuery(s, table, g, rnd, p, numQueryPKs, maxClusteringRels)
- case 4:
- // Reducing the probability to hit these since they often take a long time to run
- switch rnd.Intn(5) {
- case 0:
- idxCount := utils.RandInt2(rnd, 1, len(table.Indexes))
- return genSingleIndexQuery(s, table, g, rnd, p, idxCount)
- default:
- return genSinglePartitionQuery(s, table, g)
- }
- }
- default:
- n = rnd.Intn(4)
- switch n {
- case 0:
- return genSinglePartitionQueryMv(s, table, g, rnd, p, mvNum)
- case 1:
- lenPartitionKeys := table.MaterializedViews[mvNum].PartitionKeys.Len()
- numQueryPKs = utils.RandInt2(rnd, 1, lenPartitionKeys)
- multiplier := int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys)))
- if multiplier > 100 {
- numQueryPKs = 1
- }
- return genMultiplePartitionQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs)
- case 2:
- lenClusteringKeys := table.MaterializedViews[mvNum].ClusteringKeys.Len()
- maxClusteringRels = utils.RandInt2(rnd, 0, lenClusteringKeys)
- return genClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, maxClusteringRels)
- case 3:
- lenPartitionKeys := table.MaterializedViews[mvNum].PartitionKeys.Len()
- numQueryPKs = utils.RandInt2(rnd, 1, lenPartitionKeys)
- multiplier := int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys)))
- if multiplier > 100 {
- numQueryPKs = 1
+) (*typedef.Stmt, func()) {
+ var stmt *typedef.Stmt
+
+ if shouldGenerateCheckStatementForMV(table, rnd) {
+ stmt = genCheckStmtMV(s, table, g, rnd, p)
+ } else {
+ stmt = genCheckTableStmt(s, table, g, rnd, p)
+ }
+
+ return stmt, func() {
+ if stmt != nil && stmt.ValuesWithToken != nil {
+ for _, v := range stmt.ValuesWithToken {
+ g.ReleaseToken(v.Token)
}
- lenClusteringKeys := table.MaterializedViews[mvNum].ClusteringKeys.Len()
- maxClusteringRels = utils.RandInt2(rnd, 0, lenClusteringKeys)
- return genMultiplePartitionClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs, maxClusteringRels)
}
}
+}
+
+// shouldGenerateCheckStatementForMV reports whether the check statement should
+// target a materialized view: the table must have at least one materialized
+// view and a random draw must be even, giving MV checks a 50% chance.
+func shouldGenerateCheckStatementForMV(table *typedef.Table, rnd *rand.Rand) bool {
+ return len(table.MaterializedViews) > 0 && rnd.Int()%2 == 0
+}
- return nil
+func genCheckStmtMV(s *typedef.Schema, table *typedef.Table, g generators.Interface, rnd *rand.Rand, p *typedef.PartitionRangeConfig) *typedef.Stmt {
+ mvNum := utils.RandInt2(rnd, 0, len(table.MaterializedViews))
+ lenClusteringKeys := table.MaterializedViews[mvNum].ClusteringKeys.Len()
+ lenPartitionKeys := table.MaterializedViews[mvNum].PartitionKeys.Len()
+
+ maxClusteringRels := utils.RandInt2(rnd, 0, lenClusteringKeys)
+ numQueryPKs := utils.RandInt2(rnd, 1, lenPartitionKeys)
+ if int(math.Pow(float64(numQueryPKs), float64(lenPartitionKeys))) > 100 {
+ numQueryPKs = 1
+ }
+
+ switch rnd.Intn(4) {
+ case 0:
+ return genSinglePartitionQueryMv(s, table, g, rnd, p, mvNum)
+ case 1:
+ return genMultiplePartitionQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs)
+ case 2:
+ return genClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, maxClusteringRels)
+ case 3:
+ return genMultiplePartitionClusteringRangeQueryMv(s, table, g, rnd, p, mvNum, numQueryPKs, maxClusteringRels)
+ default:
+ panic("random number generator does not work correctly, unreachable statement")
+ }
}
-func genSinglePartitionQuery(
+func genCheckTableStmt(
s *typedef.Schema,
- t *typedef.Table,
- g generators.GeneratorInterface,
+ table *typedef.Table,
+ g generators.Interface,
+ rnd *rand.Rand,
+ p *typedef.PartitionRangeConfig,
) *typedef.Stmt {
+ var n int
+
+ if len(table.Indexes) > 0 {
+ n = rnd.Intn(5)
+ } else {
+ n = rnd.Intn(4)
+ }
+
+ maxClusteringRels := utils.RandInt2(rnd, 0, table.ClusteringKeys.Len())
+ numQueryPKs := utils.RandInt2(rnd, 1, table.PartitionKeys.Len())
+ multiplier := int(math.Pow(float64(numQueryPKs), float64(table.PartitionKeys.Len())))
+ if multiplier > 100 {
+ numQueryPKs = 1
+ }
+
+ switch n {
+ case 0:
+ return genSinglePartitionQuery(s, table, g)
+ case 1:
+ return genMultiplePartitionQuery(s, table, g, numQueryPKs)
+ case 2:
+ return genClusteringRangeQuery(s, table, g, rnd, p, maxClusteringRels)
+ case 3:
+ return genMultiplePartitionClusteringRangeQuery(s, table, g, rnd, p, numQueryPKs, maxClusteringRels)
+ case 4:
+		// Index queries often take a long time to run, so reduce the
+		// probability of hitting them to one in five.
+ if rnd.Intn(5) == 0 {
+ idxCount := utils.RandInt2(rnd, 1, len(table.Indexes))
+ return genSingleIndexQuery(s, table, g, rnd, p, idxCount)
+ }
+
+ return genSinglePartitionQuery(s, table, g)
+ default:
+ panic("random number generator does not work correctly, unreachable statement")
+ }
+}
+
+func genSinglePartitionQuery(s *typedef.Schema, t *typedef.Table, g generators.Interface) *typedef.Stmt {
t.RLock()
defer t.RUnlock()
valuesWithToken := g.GetOld()
@@ -124,7 +135,8 @@ func genSinglePartitionQuery(
}
values := valuesWithToken.Value.Copy()
builder := qb.Select(s.Keyspace.Name + "." + t.Name)
- typs := make([]typedef.Type, 0, 10)
+ typs := make([]typedef.Type, 0, len(t.PartitionKeys))
+
for _, pk := range t.PartitionKeys {
builder = builder.Where(qb.Eq(pk.Name))
typs = append(typs, pk.Type)
@@ -144,7 +156,7 @@ func genSinglePartitionQuery(
func genSinglePartitionQueryMv(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
mvNum int,
@@ -183,7 +195,7 @@ func genSinglePartitionQueryMv(
func genMultiplePartitionQuery(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
numQueryPKs int,
) *typedef.Stmt {
t.RLock()
@@ -197,7 +209,7 @@ func genMultiplePartitionQuery(
for j := 0; j < numQueryPKs; j++ {
vs := g.GetOld()
if vs == nil {
- g.GiveOlds(tokens)
+ g.GiveOld(tokens...)
return nil
}
tokens = append(tokens, vs)
@@ -223,7 +235,7 @@ func genMultiplePartitionQuery(
func genMultiplePartitionQueryMv(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
mvNum, numQueryPKs int,
@@ -241,7 +253,7 @@ func genMultiplePartitionQueryMv(
for j := 0; j < numQueryPKs; j++ {
vs := g.GetOld()
if vs == nil {
- g.GiveOlds(tokens)
+ g.GiveOld(tokens...)
return nil
}
tokens = append(tokens, vs)
@@ -274,7 +286,7 @@ func genMultiplePartitionQueryMv(
func genClusteringRangeQuery(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
maxClusteringRels int,
@@ -321,7 +333,7 @@ func genClusteringRangeQuery(
func genClusteringRangeQueryMv(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
mvNum, maxClusteringRels int,
@@ -374,7 +386,7 @@ func genClusteringRangeQueryMv(
func genMultiplePartitionClusteringRangeQuery(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
numQueryPKs, maxClusteringRels int,
@@ -397,7 +409,7 @@ func genMultiplePartitionClusteringRangeQuery(
for j := 0; j < numQueryPKs; j++ {
vs := g.GetOld()
if vs == nil {
- g.GiveOlds(tokens)
+ g.GiveOld(tokens...)
return nil
}
tokens = append(tokens, vs)
@@ -435,7 +447,7 @@ func genMultiplePartitionClusteringRangeQuery(
func genMultiplePartitionClusteringRangeQueryMv(
s *typedef.Schema,
t *typedef.Table,
- g generators.GeneratorInterface,
+ g generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
mvNum, numQueryPKs, maxClusteringRels int,
@@ -478,7 +490,7 @@ func genMultiplePartitionClusteringRangeQueryMv(
for j := 0; j < numQueryPKs; j++ {
vs := g.GetOld()
if vs == nil {
- g.GiveOlds(tokens)
+ g.GiveOld(tokens...)
return nil
}
tokens = append(tokens, vs)
@@ -516,7 +528,7 @@ func genMultiplePartitionClusteringRangeQueryMv(
func genSingleIndexQuery(
s *typedef.Schema,
t *typedef.Table,
- _ generators.GeneratorInterface,
+ _ generators.Interface,
r *rand.Rand,
p *typedef.PartitionRangeConfig,
idxCount int,
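
GenCheckStmt now returns the statement together with a release closure, so the caller frees the reserved partition tokens exactly once after the check completes. A tiny sketch of that shape (types are illustrative stand-ins):

package example

type stmt struct{ tokens []uint64 }

type releaser interface{ ReleaseToken(uint64) }

// genStmt returns a statement plus a callback that returns every token
// the statement reserved, mirroring the (stmt, func()) pair above.
func genStmt(g releaser) (*stmt, func()) {
	s := &stmt{tokens: []uint64{1, 2, 3}}
	return s, func() {
		if s == nil {
			return
		}
		for _, tok := range s.tokens {
			g.ReleaseToken(tok)
		}
	}
}

// Typical call site:
//
//	st, release := genStmt(gen)
//	defer release()
//	runCheck(st)
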
diff --git a/pkg/generators/statements/gen_check_stmt_test.go b/pkg/generators/statements/gen_check_stmt_test.go
new file mode 100644
index 00000000..55b28468
--- /dev/null
+++ b/pkg/generators/statements/gen_check_stmt_test.go
@@ -0,0 +1,332 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//nolint:thelper
+package statements
+
+import (
+ "path"
+ "testing"
+
+ "github.com/scylladb/gemini/pkg/testutils"
+)
+
+const checkDataPath = "./test_expected_data/check/"
+
+func TestGenSinglePartitionQuery(t *testing.T) {
+ RunStmtTest[results](t, path.Join(checkDataPath, "single_partition.json"), genSinglePartitionQueryCases,
+ func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+ schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
+ stmt := genSinglePartitionQuery(schema, schema.Tables[0], gen)
+ validateStmt(subT, stmt, nil)
+ expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+ })
+}
+
+// func TestGenSinglePartitionQueryMv(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "single_partition_mv.json"), genSinglePartitionQueryMvCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// stmt := genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1)
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenMultiplePartitionQuery(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition.json"), genMultiplePartitionQueryCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
+// options := testutils.GetOptionsFromCaseName(caseName)
+// stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenMultiplePartitionQueryMv(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_mv.json"), genMultiplePartitionQueryMvCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
+// stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenClusteringRangeQuery(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range.json"), genClusteringRangeQueryCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// options := testutils.GetOptionsFromCaseName(caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// stmt := genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenClusteringRangeQueryMv(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range_mv.json"), genClusteringRangeQueryMvCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// options := testutils.GetOptionsFromCaseName(caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// stmt := genClusteringRangeQueryMv(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// len(schema.Tables[0].MaterializedViews)-1,
+// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenMultiplePartitionClusteringRangeQuery(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range.json"), genMultiplePartitionClusteringRangeQueryCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// options := testutils.GetOptionsFromCaseName(caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// stmt := genMultiplePartitionClusteringRangeQuery(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
+// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenMultiplePartitionClusteringRangeQueryMv(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range_mv.json"), genMultiplePartitionClusteringRangeQueryMvCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// stmt := genMultiplePartitionClusteringRangeQueryMv(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// len(schema.Tables[0].MaterializedViews)-1,
+// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
+// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func TestGenSingleIndexQuery(t *testing.T) {
+// RunStmtTest[results](t, path.Join(checkDataPath, "single_index.json"), genSingleIndexQueryCases,
+// func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// stmt := genSingleIndexQuery(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].Indexes))
+// validateStmt(subT, stmt, nil)
+// expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
+// })
+// }
+
+// func BenchmarkGenSinglePartitionQuery(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genSinglePartitionQueryCases {
+// caseName := genSinglePartitionQueryCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genSinglePartitionQuery(schema, schema.Tables[0], gen)
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenSinglePartitionQueryMv(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genSinglePartitionQueryMvCases {
+// caseName := genSinglePartitionQueryMvCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1)
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenMultiplePartitionQuery(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genMultiplePartitionQueryCases {
+// caseName := genMultiplePartitionQueryCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenMultiplePartitionQueryMv(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genMultiplePartitionQueryMvCases {
+// caseName := genMultiplePartitionQueryMvCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genMultiplePartitionQueryMv(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// len(schema.Tables[0].MaterializedViews)-1,
+// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenClusteringRangeQuery(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genClusteringRangeQueryCases {
+// caseName := genClusteringRangeQueryCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenClusteringRangeQueryMv(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genClusteringRangeQueryMvCases {
+// caseName := genClusteringRangeQueryMvCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genClusteringRangeQueryMv(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// len(schema.Tables[0].MaterializedViews)-1,
+// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenMultiplePartitionClusteringRangeQuery(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genMultiplePartitionClusteringRangeQueryCases {
+// caseName := genMultiplePartitionClusteringRangeQueryCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genMultiplePartitionClusteringRangeQuery(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
+// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenMultiplePartitionClusteringRangeQueryMv(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genMultiplePartitionClusteringRangeQueryMvCases {
+// caseName := genMultiplePartitionClusteringRangeQueryMvCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// options := testutils.GetOptionsFromCaseName(caseName)
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genMultiplePartitionClusteringRangeQueryMv(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// len(schema.Tables[0].MaterializedViews)-1,
+// GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
+// GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
+// }
+// })
+// }
+// }
+
+// func BenchmarkGenSingleIndexQuery(t *testing.B) {
+// utils.SetUnderTest()
+// for idx := range genSingleIndexQueryCases {
+// caseName := genSingleIndexQueryCases[idx]
+// t.Run(caseName,
+// func(subT *testing.B) {
+// schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
+// prc := schema.Config.GetPartitionRangeConfig()
+// subT.ResetTimer()
+// for x := 0; x < subT.N; x++ {
+// _ = genSingleIndexQuery(
+// schema,
+// schema.Tables[0],
+// gen,
+// rnd,
+// &prc,
+// len(schema.Tables[0].Indexes))
+// }
+// })
+// }
+// }
diff --git a/pkg/jobs/gen_const_test.go b/pkg/generators/statements/gen_const_test.go
similarity index 99%
rename from pkg/jobs/gen_const_test.go
rename to pkg/generators/statements/gen_const_test.go
index 1cf5922c..f602f3ef 100644
--- a/pkg/jobs/gen_const_test.go
+++ b/pkg/generators/statements/gen_const_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package jobs
+package statements
var (
genInsertStmtCases = []string{
diff --git a/pkg/jobs/gen_ddl_stmt.go b/pkg/generators/statements/gen_ddl_stmt.go
similarity index 96%
rename from pkg/jobs/gen_ddl_stmt.go
rename to pkg/generators/statements/gen_ddl_stmt.go
index a8866fd4..9ec0ec08 100644
--- a/pkg/jobs/gen_ddl_stmt.go
+++ b/pkg/generators/statements/gen_ddl_stmt.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package jobs
+package statements
import (
"fmt"
@@ -25,7 +25,7 @@ import (
"github.com/scylladb/gemini/pkg/typedef"
)
-func GenDDLStmt(s *typedef.Schema, t *typedef.Table, r *rand.Rand, _ *typedef.PartitionRangeConfig, sc *typedef.SchemaConfig) (*typedef.Stmts, error) {
+func GenDDLStmt(s *typedef.Schema, t *typedef.Table, r *rand.Rand, sc *typedef.SchemaConfig) (*typedef.Stmts, error) {
maxVariant := 1
validCols := t.ValidColumnsForDelete()
if validCols.Len() > 0 {
diff --git a/pkg/jobs/gen_ddl_stmt_test.go b/pkg/generators/statements/gen_ddl_stmt_test.go
similarity index 99%
rename from pkg/jobs/gen_ddl_stmt_test.go
rename to pkg/generators/statements/gen_ddl_stmt_test.go
index e5f81a73..6643a49c 100644
--- a/pkg/jobs/gen_ddl_stmt_test.go
+++ b/pkg/generators/statements/gen_ddl_stmt_test.go
@@ -13,7 +13,7 @@
// limitations under the License.
//nolint:thelper
-package jobs
+package statements
import (
"fmt"
diff --git a/pkg/jobs/gen_mutate_stmt.go b/pkg/generators/statements/gen_mutate_stmt.go
similarity index 97%
rename from pkg/jobs/gen_mutate_stmt.go
rename to pkg/generators/statements/gen_mutate_stmt.go
index 6f65caab..f727da77 100644
--- a/pkg/jobs/gen_mutate_stmt.go
+++ b/pkg/generators/statements/gen_mutate_stmt.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package jobs
+package statements
import (
"encoding/json"
@@ -26,7 +26,7 @@ import (
"github.com/scylladb/gemini/pkg/typedef"
)
-func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.GeneratorInterface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) {
+func GenMutateStmt(s *typedef.Schema, t *typedef.Table, g generators.Interface, r *rand.Rand, p *typedef.PartitionRangeConfig, deletes bool) (*typedef.Stmt, error) {
t.RLock()
defer t.RUnlock()
diff --git a/pkg/jobs/gen_mutate_stmt_test.go b/pkg/generators/statements/gen_mutate_stmt_test.go
similarity index 99%
rename from pkg/jobs/gen_mutate_stmt_test.go
rename to pkg/generators/statements/gen_mutate_stmt_test.go
index 006f5329..1fc1e9e2 100644
--- a/pkg/jobs/gen_mutate_stmt_test.go
+++ b/pkg/generators/statements/gen_mutate_stmt_test.go
@@ -13,7 +13,7 @@
// limitations under the License.
//nolint:thelper
-package jobs
+package statements
import (
"path"
diff --git a/pkg/generators/statements/gen_utils_test.go b/pkg/generators/statements/gen_utils_test.go
new file mode 100644
index 00000000..bbce4bbb
--- /dev/null
+++ b/pkg/generators/statements/gen_utils_test.go
@@ -0,0 +1,265 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package statements
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/scylladb/gemini/pkg/testutils"
+ "github.com/scylladb/gemini/pkg/typedef"
+ "github.com/scylladb/gemini/pkg/utils"
+)
+
+type resultToken struct {
+ Token string
+ TokenValues string
+}
+
+func (r resultToken) Equal(received resultToken) bool {
+ return r.Token == received.Token && r.TokenValues == received.TokenValues
+}
+
+type resultTokens []resultToken
+
+func (r resultTokens) Equal(received resultTokens) bool {
+ if len(r) != len(received) {
+ return false
+ }
+ for id, expectedToken := range r {
+ if !expectedToken.Equal(received[id]) {
+ return false
+ }
+ }
+ return true
+}
+
+func (r resultTokens) Diff(received resultTokens) string {
+ var out []string
+ maxIdx := len(r)
+ if maxIdx < len(received) {
+ maxIdx = len(received)
+ }
+ var expected, found *resultToken
+ for idx := 0; idx < maxIdx; idx++ {
+ if idx < len(r) {
+ expected = &r[idx]
+ } else {
+ expected = &resultToken{}
+ }
+
+ if idx < len(received) {
+ found = &received[idx]
+ } else {
+ found = &resultToken{}
+ }
+
+		out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+			expected.Token, found.Token, " error: value stmt.ValuesWithToken.Token expected and received are different:"))
+ out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+ expected.TokenValues, found.TokenValues, " error: value stmt.ValuesWithToken.Value expected and received are different:"))
+ }
+ return strings.Join(out, "\n")
+}
+
+type result struct {
+ Query string
+ Names string
+ Values string
+ Types string
+ QueryType string
+ TokenValues resultTokens
+}
+
+func (r *result) Equal(t *result) bool {
+ var expected result
+ if r != nil {
+ expected = *r
+ }
+
+ var provided result
+ if t != nil {
+ provided = *t
+ }
+ return expected.Query == provided.Query &&
+ expected.Names == provided.Names &&
+ expected.Values == provided.Values &&
+ expected.Types == provided.Types &&
+ expected.QueryType == provided.QueryType &&
+ expected.TokenValues.Equal(provided.TokenValues)
+}
+
+func (r *result) Diff(received *result) string {
+ var out []string
+ out = testutils.AppendIfNotEmpty(out, r.TokenValues.Diff(received.TokenValues))
+ out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+ r.Query, received.Query, " error: value stmt.Query.ToCql().stmt expected and received are different:"))
+ out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+ r.Names, received.Names, " error: value stmt.Query.ToCql().Names expected and received are different:"))
+ out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+ r.Values, received.Values, " error: value stmt.Values expected and received are different:"))
+ out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+ r.Types, received.Types, " error: value stmt.Types expected and received are different:"))
+ out = testutils.AppendIfNotEmpty(out, testutils.GetErrorMsgIfDifferent(
+ r.QueryType, received.QueryType, " error: value stmt.QueryType expected and received are different:"))
+ return strings.Join(out, "\n")
+}
+
+type results []*result
+
+func (r results) Equal(t results) bool {
+ return r.Diff(t) == ""
+}
+
+func (r results) Diff(t results) string {
+ var out []string
+ maxIdx := len(r)
+ if maxIdx < len(t) {
+ maxIdx = len(t)
+ }
+ var expected, found *result
+ for idx := 0; idx < maxIdx; idx++ {
+ if idx < len(r) {
+ expected = r[idx]
+ } else {
+ expected = &result{}
+ }
+
+ if idx < len(t) {
+ found = t[idx]
+ } else {
+ found = &result{}
+ }
+
+ out = testutils.AppendIfNotEmpty(out, expected.Diff(found))
+ }
+ return strings.Join(out, "\n")
+}
+
+func convertStmtsToResults(stmt any) results {
+ var out results
+ switch stmts := stmt.(type) {
+ case *typedef.Stmts:
+ for idx := range stmts.List {
+ out = append(out, convertStmtToResults(stmts.List[idx]))
+ }
+ case *typedef.Stmt:
+ out = append(out, convertStmtToResults(stmts))
+ }
+ return out
+}
+
+func convertStmtToResults(stmt *typedef.Stmt) *result {
+ types := ""
+ for idx := range stmt.Types {
+ types = fmt.Sprintf("%s %s", types, stmt.Types[idx].Name())
+ }
+ query, names := stmt.Query.ToCql()
+ var tokens []resultToken
+ for _, valueToken := range stmt.ValuesWithToken {
+ tokens = append(tokens, resultToken{
+ Token: fmt.Sprintf("%v", valueToken.Token),
+ TokenValues: strings.TrimSpace(fmt.Sprintf("%v", valueToken.Value)),
+ })
+ }
+
+ return &result{
+ TokenValues: tokens,
+ Query: strings.TrimSpace(query),
+ Names: strings.TrimSpace(fmt.Sprintf("%s", names)),
+ Values: strings.TrimSpace(fmt.Sprintf("%v", stmt.Values)),
+ Types: types,
+ QueryType: fmt.Sprintf("%v", stmt.QueryType),
+ }
+}
+
+func RunStmtTest[T testutils.ExpectedEntry[T]](
+ t *testing.T,
+ filePath string,
+ cases []string,
+ testBody func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[T]),
+) {
+ t.Helper()
+ utils.SetUnderTest()
+ t.Parallel()
+ expected := testutils.LoadExpectedFromFile[T](t, filePath, cases, *testutils.UpdateExpectedFlag)
+ if *testutils.UpdateExpectedFlag {
+ t.Cleanup(func() {
+ expected.UpdateExpected(t)
+ })
+ }
+ for idx := range cases {
+ caseName := cases[idx]
+ t.Run(caseName,
+ func(subT *testing.T) {
+ subT.Parallel()
+ testBody(subT, caseName, expected)
+ })
+ }
+}
+
+func GetPkCountFromOptions(options testutils.TestCaseOptions, allValue int) int {
+ pkCount := 0
+ options.HandleOption("cpk", func(option string) {
+ switch option {
+ case "cpkAll":
+ pkCount = allValue
+ case "cpk1":
+ pkCount = 1
+ }
+ })
+ return pkCount
+}
+
+func GetCkCountFromOptions(options testutils.TestCaseOptions, allValue int) int {
+ ckCount := -1
+ options.HandleOption("cck", func(option string) {
+ switch option {
+ case "cckAll":
+ ckCount = allValue
+ case "cck1":
+ ckCount = 0
+ }
+ })
+ return ckCount
+}
+
+func validateStmt(t *testing.T, stmt any, err error) {
+ t.Helper()
+ if err != nil {
+ t.Fatalf("error: get an error on create test inputs:%v", err)
+ }
+ if stmt == nil {
+ t.Fatalf("error: stmt is nil")
+ }
+ switch stmts := stmt.(type) {
+ case *typedef.Stmts:
+ if stmts == nil || stmts.List == nil || len(stmts.List) == 0 {
+ t.Fatalf("error: stmts is empty")
+ }
+ for i := range stmts.List {
+ if stmts.List[i] == nil || stmts.List[i].Query == nil {
+ t.Fatalf("error: stmts has nil stmt #%d", i)
+ }
+ }
+ case *typedef.Stmt:
+ if stmts == nil || stmts.Query == nil {
+ t.Fatalf("error: stmt is empty")
+ }
+ default:
+ t.Fatalf("error: unkwon type of stmt")
+ }
+}
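
A sketch of the "compare or store" golden-file flow RunStmtTest builds on: with an update flag set, the test rewrites the snapshot; otherwise it compares against it. Standard library only; the flag name and file layout are hypothetical, not the gemini testutils API.

package example

import (
	"bytes"
	"encoding/json"
	"flag"
	"os"
	"testing"
)

var updateGolden = flag.Bool("update-golden", false, "rewrite expected files")

func compareOrStore(t *testing.T, path string, got any) {
	t.Helper()
	data, err := json.MarshalIndent(got, "", "  ")
	if err != nil {
		t.Fatalf("marshal result: %v", err)
	}
	if *updateGolden {
		if err := os.WriteFile(path, data, 0o644); err != nil {
			t.Fatalf("store golden file: %v", err)
		}
		return
	}
	want, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("load golden file: %v", err)
	}
	if !bytes.Equal(bytes.TrimSpace(want), bytes.TrimSpace(data)) {
		t.Fatalf("result differs from %s:\nwant: %s\ngot:  %s", path, want, data)
	}
}
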
diff --git a/pkg/jobs/test_expected_data/check/clustering_range.json b/pkg/generators/statements/test_expected_data/check/clustering_range.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/clustering_range.json
rename to pkg/generators/statements/test_expected_data/check/clustering_range.json
diff --git a/pkg/jobs/test_expected_data/check/clustering_range_mv.json b/pkg/generators/statements/test_expected_data/check/clustering_range_mv.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/clustering_range_mv.json
rename to pkg/generators/statements/test_expected_data/check/clustering_range_mv.json
diff --git a/pkg/jobs/test_expected_data/check/multiple_partition.json b/pkg/generators/statements/test_expected_data/check/multiple_partition.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/multiple_partition.json
rename to pkg/generators/statements/test_expected_data/check/multiple_partition.json
diff --git a/pkg/jobs/test_expected_data/check/multiple_partition_clustering_range.json b/pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/multiple_partition_clustering_range.json
rename to pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range.json
diff --git a/pkg/jobs/test_expected_data/check/multiple_partition_clustering_range_mv.json b/pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range_mv.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/multiple_partition_clustering_range_mv.json
rename to pkg/generators/statements/test_expected_data/check/multiple_partition_clustering_range_mv.json
diff --git a/pkg/jobs/test_expected_data/check/multiple_partition_mv.json b/pkg/generators/statements/test_expected_data/check/multiple_partition_mv.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/multiple_partition_mv.json
rename to pkg/generators/statements/test_expected_data/check/multiple_partition_mv.json
diff --git a/pkg/jobs/test_expected_data/check/single_index.json b/pkg/generators/statements/test_expected_data/check/single_index.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/single_index.json
rename to pkg/generators/statements/test_expected_data/check/single_index.json
diff --git a/pkg/jobs/test_expected_data/check/single_partition.json b/pkg/generators/statements/test_expected_data/check/single_partition.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/single_partition.json
rename to pkg/generators/statements/test_expected_data/check/single_partition.json
diff --git a/pkg/jobs/test_expected_data/check/single_partition_mv.json b/pkg/generators/statements/test_expected_data/check/single_partition_mv.json
similarity index 100%
rename from pkg/jobs/test_expected_data/check/single_partition_mv.json
rename to pkg/generators/statements/test_expected_data/check/single_partition_mv.json
diff --git a/pkg/jobs/test_expected_data/ddl/add_column.json b/pkg/generators/statements/test_expected_data/ddl/add_column.json
similarity index 100%
rename from pkg/jobs/test_expected_data/ddl/add_column.json
rename to pkg/generators/statements/test_expected_data/ddl/add_column.json
diff --git a/pkg/jobs/test_expected_data/ddl/drop_column.json b/pkg/generators/statements/test_expected_data/ddl/drop_column.json
similarity index 100%
rename from pkg/jobs/test_expected_data/ddl/drop_column.json
rename to pkg/generators/statements/test_expected_data/ddl/drop_column.json
diff --git a/pkg/jobs/test_expected_data/mutate/delete.json b/pkg/generators/statements/test_expected_data/mutate/delete.json
similarity index 100%
rename from pkg/jobs/test_expected_data/mutate/delete.json
rename to pkg/generators/statements/test_expected_data/mutate/delete.json
diff --git a/pkg/jobs/test_expected_data/mutate/insert.json b/pkg/generators/statements/test_expected_data/mutate/insert.json
similarity index 100%
rename from pkg/jobs/test_expected_data/mutate/insert.json
rename to pkg/generators/statements/test_expected_data/mutate/insert.json
diff --git a/pkg/jobs/test_expected_data/mutate/insert_j.json b/pkg/generators/statements/test_expected_data/mutate/insert_j.json
similarity index 100%
rename from pkg/jobs/test_expected_data/mutate/insert_j.json
rename to pkg/generators/statements/test_expected_data/mutate/insert_j.json
diff --git a/pkg/jobs/test_expected_data/mutate/update.json b/pkg/generators/statements/test_expected_data/mutate/update.json
similarity index 100%
rename from pkg/jobs/test_expected_data/mutate/update.json
rename to pkg/generators/statements/test_expected_data/mutate/update.json
diff --git a/pkg/jobs/ddl.go b/pkg/jobs/ddl.go
new file mode 100644
index 00000000..82166baa
--- /dev/null
+++ b/pkg/jobs/ddl.go
@@ -0,0 +1,83 @@
+package jobs
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/pkg/errors"
+ "go.uber.org/zap"
+
+ "github.com/scylladb/gemini/pkg/generators/statements"
+ "github.com/scylladb/gemini/pkg/joberror"
+ "github.com/scylladb/gemini/pkg/typedef"
+)
+
+func (m *mutation) DDL(ctx context.Context, table *typedef.Table) error {
+ table.RLock()
+	// Scylla does not allow schema changes (DDL) on a table that has materialized views.
+ if len(table.MaterializedViews) > 0 {
+ table.RUnlock()
+ return nil
+ }
+ table.RUnlock()
+
+ table.Lock()
+ defer table.Unlock()
+ ddlStmts, err := statements.GenDDLStmt(m.schema, table, m.random, &m.schema.Config)
+ if err != nil {
+ m.logger.Error("Failed! DDL Mutation statement generation failed", zap.Error(err))
+ m.globalStatus.WriteErrors.Add(1)
+ return err
+ }
+ if ddlStmts == nil {
+ if w := m.logger.Check(zap.DebugLevel, "no statement generated"); w != nil {
+ w.Write(zap.String("job", "ddl"))
+ }
+ return nil
+ }
+ for _, ddlStmt := range ddlStmts.List {
+ if w := m.logger.Check(zap.DebugLevel, "ddl statement"); w != nil {
+ prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL()
+ if prettyCQLErr != nil {
+ return PrettyCQLError{
+ PrettyCQL: prettyCQLErr,
+ Stmt: ddlStmt,
+ }
+ }
+
+ w.Write(zap.String("pretty_cql", prettyCQL))
+ }
+ if err = m.store.Mutate(ctx, ddlStmt); err != nil {
+ if errors.Is(err, context.Canceled) {
+ return nil
+ }
+
+ prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL()
+ if prettyCQLErr != nil {
+ return PrettyCQLError{
+ PrettyCQL: prettyCQLErr,
+ Stmt: ddlStmt,
+ Err: err,
+ }
+ }
+
+ m.globalStatus.AddWriteError(&joberror.JobError{
+ Timestamp: time.Now(),
+ StmtType: ddlStmts.QueryType.String(),
+ Message: "DDL failed: " + err.Error(),
+ Query: prettyCQL,
+ })
+
+ return err
+ }
+ m.globalStatus.WriteOps.Add(1)
+ }
+ ddlStmts.PostStmtHook()
+
+ jsonSchema, _ := json.MarshalIndent(m.schema, "", " ")
+ fmt.Printf("New schema: %v\n", string(jsonSchema))
+
+ return nil
+}
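
A sketch of the zap Check idiom used throughout DDL above: the costly pretty-printing of the statement runs only when debug logging is actually enabled. The render callback is an illustrative stand-in for PrettyCQL.

package example

import "go.uber.org/zap"

func logStmt(logger *zap.Logger, render func() (string, error)) error {
	if ce := logger.Check(zap.DebugLevel, "ddl statement"); ce != nil {
		cql, err := render() // pay the rendering cost only when needed
		if err != nil {
			return err
		}
		ce.Write(zap.String("pretty_cql", cql))
	}
	return nil
}
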
diff --git a/pkg/jobs/gen_check_stmt_test.go b/pkg/jobs/gen_check_stmt_test.go
deleted file mode 100644
index 912c9cf2..00000000
--- a/pkg/jobs/gen_check_stmt_test.go
+++ /dev/null
@@ -1,333 +0,0 @@
-// Copyright 2019 ScyllaDB
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//nolint:thelper
-package jobs
-
-import (
- "path"
- "testing"
-
- "github.com/scylladb/gemini/pkg/testutils"
- "github.com/scylladb/gemini/pkg/utils"
-)
-
-var checkDataPath = "./test_expected_data/check/"
-
-func TestGenSinglePartitionQuery(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "single_partition.json"), genSinglePartitionQueryCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
- stmt := genSinglePartitionQuery(schema, schema.Tables[0], gen)
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenSinglePartitionQueryMv(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "single_partition_mv.json"), genSinglePartitionQueryMvCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- stmt := genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1)
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenMultiplePartitionQuery(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition.json"), genMultiplePartitionQueryCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
- options := testutils.GetOptionsFromCaseName(caseName)
- stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenMultiplePartitionQueryMv(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_mv.json"), genMultiplePartitionQueryMvCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
- stmt := genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenClusteringRangeQuery(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range.json"), genClusteringRangeQueryCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- options := testutils.GetOptionsFromCaseName(caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- stmt := genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenClusteringRangeQueryMv(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "clustering_range_mv.json"), genClusteringRangeQueryMvCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- options := testutils.GetOptionsFromCaseName(caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- stmt := genClusteringRangeQueryMv(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- len(schema.Tables[0].MaterializedViews)-1,
- GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenMultiplePartitionClusteringRangeQuery(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range.json"), genMultiplePartitionClusteringRangeQueryCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- options := testutils.GetOptionsFromCaseName(caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- stmt := genMultiplePartitionClusteringRangeQuery(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
- GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenMultiplePartitionClusteringRangeQueryMv(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "multiple_partition_clustering_range_mv.json"), genMultiplePartitionClusteringRangeQueryMvCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- stmt := genMultiplePartitionClusteringRangeQueryMv(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- len(schema.Tables[0].MaterializedViews)-1,
- GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
- GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func TestGenSingleIndexQuery(t *testing.T) {
- RunStmtTest[results](t, path.Join(checkDataPath, "single_index.json"), genSingleIndexQueryCases,
- func(subT *testing.T, caseName string, expected *testutils.ExpectedStore[results]) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- stmt := genSingleIndexQuery(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].Indexes))
- validateStmt(subT, stmt, nil)
- expected.CompareOrStore(subT, caseName, convertStmtsToResults(stmt))
- })
-}
-
-func BenchmarkGenSinglePartitionQuery(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genSinglePartitionQueryCases {
- caseName := genSinglePartitionQueryCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genSinglePartitionQuery(schema, schema.Tables[0], gen)
- }
- })
- }
-}
-
-func BenchmarkGenSinglePartitionQueryMv(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genSinglePartitionQueryMvCases {
- caseName := genSinglePartitionQueryMvCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genSinglePartitionQueryMv(schema, schema.Tables[0], gen, rnd, &prc, len(schema.Tables[0].MaterializedViews)-1)
- }
- })
- }
-}
-
-func BenchmarkGenMultiplePartitionQuery(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genMultiplePartitionQueryCases {
- caseName := genMultiplePartitionQueryCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, _ := testutils.GetAllForTestStmt(subT, caseName)
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genMultiplePartitionQuery(schema, schema.Tables[0], gen, GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
- }
- })
- }
-}
-
-func BenchmarkGenMultiplePartitionQueryMv(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genMultiplePartitionQueryMvCases {
- caseName := genMultiplePartitionQueryMvCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genMultiplePartitionQueryMv(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- len(schema.Tables[0].MaterializedViews)-1,
- GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)))
- }
- })
- }
-}
-
-func BenchmarkGenClusteringRangeQuery(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genClusteringRangeQueryCases {
- caseName := genClusteringRangeQueryCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genClusteringRangeQuery(schema, schema.Tables[0], gen, rnd, &prc, GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- }
- })
- }
-}
-
-func BenchmarkGenClusteringRangeQueryMv(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genClusteringRangeQueryMvCases {
- caseName := genClusteringRangeQueryMvCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genClusteringRangeQueryMv(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- len(schema.Tables[0].MaterializedViews)-1,
- GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- }
- })
- }
-}
-
-func BenchmarkGenMultiplePartitionClusteringRangeQuery(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genMultiplePartitionClusteringRangeQueryCases {
- caseName := genMultiplePartitionClusteringRangeQueryCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genMultiplePartitionClusteringRangeQuery(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
- GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- }
- })
- }
-}
-
-func BenchmarkGenMultiplePartitionClusteringRangeQueryMv(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genMultiplePartitionClusteringRangeQueryMvCases {
- caseName := genMultiplePartitionClusteringRangeQueryMvCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- options := testutils.GetOptionsFromCaseName(caseName)
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genMultiplePartitionClusteringRangeQueryMv(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- len(schema.Tables[0].MaterializedViews)-1,
- GetPkCountFromOptions(options, len(schema.Tables[0].PartitionKeys)),
- GetCkCountFromOptions(options, len(schema.Tables[0].ClusteringKeys)-1))
- }
- })
- }
-}
-
-func BenchmarkGenSingleIndexQuery(t *testing.B) {
- utils.SetUnderTest()
- for idx := range genSingleIndexQueryCases {
- caseName := genSingleIndexQueryCases[idx]
- t.Run(caseName,
- func(subT *testing.B) {
- schema, gen, rnd := testutils.GetAllForTestStmt(subT, caseName)
- prc := schema.Config.GetPartitionRangeConfig()
- subT.ResetTimer()
- for x := 0; x < subT.N; x++ {
- _ = genSingleIndexQuery(
- schema,
- schema.Tables[0],
- gen,
- rnd,
- &prc,
- len(schema.Tables[0].Indexes))
- }
- })
- }
-}
diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go
index 3730b7de..3f38c6c5 100644
--- a/pkg/jobs/jobs.go
+++ b/pkg/jobs/jobs.go
@@ -16,528 +16,166 @@ package jobs
import (
"context"
- "encoding/json"
- "fmt"
+ "errors"
+ "sync"
"time"
- "github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/exp/rand"
- "golang.org/x/sync/errgroup"
"github.com/scylladb/gemini/pkg/generators"
- "github.com/scylladb/gemini/pkg/joberror"
"github.com/scylladb/gemini/pkg/status"
- "github.com/scylladb/gemini/pkg/stop"
"github.com/scylladb/gemini/pkg/store"
"github.com/scylladb/gemini/pkg/typedef"
)
-const (
- WriteMode = "write"
- ReadMode = "read"
- MixedMode = "mixed"
- WarmupMode = "warmup"
-)
-
-const (
- warmupName = "Warmup"
- validateName = "Validation"
- mutateName = "Mutation"
-)
-
-var (
- warmup = job{name: warmupName, function: warmupJob}
- validate = job{name: validateName, function: validationJob}
- mutate = job{name: mutateName, function: mutationJob}
-)
-
-type List struct {
- name string
- jobs []job
- duration time.Duration
- workers uint64
-}
-
-type job struct {
- function func(
- context.Context,
- <-chan time.Duration,
- *typedef.Schema,
- typedef.SchemaConfig,
- *typedef.Table,
- store.Store,
- *rand.Rand,
- *typedef.PartitionRangeConfig,
- *generators.Generator,
- *status.GlobalStatus,
- *zap.Logger,
- *stop.Flag,
- bool,
- bool,
- ) error
- name string
-}
-
-func ListFromMode(mode string, duration time.Duration, workers uint64) List {
- jobs := make([]job, 0, 2)
- name := "work cycle"
- switch mode {
- case WriteMode:
- jobs = append(jobs, mutate)
- case ReadMode:
- jobs = append(jobs, validate)
- case WarmupMode:
- jobs = append(jobs, warmup)
- name = "warmup cycle"
- default:
- jobs = append(jobs, mutate, validate)
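+// Runner drives a full test run: an optional warmup phase followed by
+// mutation and/or validation jobs spread across tables and workers.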
+type (
+ Runner struct {
+ duration time.Duration
+ logger *zap.Logger
+ random *rand.Rand
+ workers uint64
+ generators *generators.Generators
+ schema *typedef.Schema
+ warmup time.Duration
+ globalStatus *status.GlobalStatus
+ store store.Store
+ mode Mode
+ failFast bool
+ }
+ Job interface {
+ Name() string
+ Do(context.Context, generators.Interface, *typedef.Table) error
}
- return List{
- name: name,
- jobs: jobs,
- duration: duration,
- workers: workers,
- }
-}
+)
-func (l List) Run(
- ctx context.Context,
- schema *typedef.Schema,
- schemaConfig typedef.SchemaConfig,
- s store.Store,
- pump <-chan time.Duration,
- generators []*generators.Generator,
- globalStatus *status.GlobalStatus,
+func New(
+ mode string,
+ duration time.Duration,
+ workers uint64,
logger *zap.Logger,
- seed uint64,
- stopFlag *stop.Flag,
- failFast, verbose bool,
-) error {
- logger = logger.Named(l.name)
- ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop)
- g, gCtx := errgroup.WithContext(ctx)
- time.AfterFunc(l.duration, func() {
- logger.Info("jobs time is up, begins jobs completion")
- stopFlag.SetSoft(true)
- })
-
- partitionRangeConfig := schemaConfig.GetPartitionRangeConfig()
- logger.Info("start jobs")
- for j := range schema.Tables {
- gen := generators[j]
- table := schema.Tables[j]
- for i := 0; i < int(l.workers); i++ {
- for idx := range l.jobs {
- jobF := l.jobs[idx].function
- r := rand.New(rand.NewSource(seed))
- g.Go(func() error {
- return jobF(gCtx, pump, schema, schemaConfig, table, s, r, &partitionRangeConfig, gen, globalStatus, logger, stopFlag, failFast, verbose)
- })
- }
- }
- }
- return g.Wait()
-}
-
-// mutationJob continuously applies mutations against the database
-// for as long as the pump is active.
-func mutationJob(
- ctx context.Context,
- pump <-chan time.Duration,
schema *typedef.Schema,
- schemaCfg typedef.SchemaConfig,
- table *typedef.Table,
- s store.Store,
- r *rand.Rand,
- p *typedef.PartitionRangeConfig,
- g *generators.Generator,
+ store store.Store,
globalStatus *status.GlobalStatus,
- logger *zap.Logger,
- stopFlag *stop.Flag,
- failFast, verbose bool,
-) error {
- schemaConfig := &schemaCfg
- logger = logger.Named("mutation_job")
- logger.Info("starting mutation loop")
- defer func() {
- logger.Info("ending mutation loop")
- }()
- for {
- if stopFlag.IsHardOrSoft() {
- return nil
- }
- select {
- case <-stopFlag.SignalChannel():
- logger.Debug("mutation job terminated")
- return nil
- case hb := <-pump:
- time.Sleep(hb)
- }
- ind := r.Intn(1000000)
- if ind%100000 == 0 {
- err := ddl(ctx, schema, schemaConfig, table, s, r, p, globalStatus, logger, verbose)
- if err != nil {
- return err
- }
- } else {
- err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, true, logger)
- if err != nil {
- return err
- }
- }
- if failFast && globalStatus.HasErrors() {
- stopFlag.SetSoft(true)
- return nil
- }
+ seed uint64,
+ gens *generators.Generators,
+ warmup time.Duration,
+ failFast bool,
+) *Runner {
+ return &Runner{
+ warmup: warmup,
+ globalStatus: globalStatus,
+ store: store,
+ mode: ModeFromString(mode),
+ logger: logger,
+ duration: duration,
+ workers: workers,
+ failFast: failFast,
+ random: rand.New(rand.NewSource(seed)),
+ generators: gens,
+ schema: schema,
}
}
-// validationJob continuously applies validations against the database
-// for as long as the pump is active.
-func validationJob(
- ctx context.Context,
- pump <-chan time.Duration,
- schema *typedef.Schema,
- schemaCfg typedef.SchemaConfig,
- table *typedef.Table,
- s store.Store,
- r *rand.Rand,
- p *typedef.PartitionRangeConfig,
- g *generators.Generator,
- globalStatus *status.GlobalStatus,
- logger *zap.Logger,
- stopFlag *stop.Flag,
- failFast, _ bool,
-) error {
- schemaConfig := &schemaCfg
- logger = logger.Named("validation_job")
- logger.Info("starting validation loop")
- defer func() {
- logger.Info("ending validation loop")
- }()
+func (l *Runner) Run(ctx context.Context) error {
+ l.logger.Info("start jobs")
+ var wg sync.WaitGroup
- for {
- if stopFlag.IsHardOrSoft() {
- return nil
- }
- select {
- case <-stopFlag.SignalChannel():
- return nil
- case hb := <-pump:
- time.Sleep(hb)
- }
- stmt := GenCheckStmt(schema, table, g, r, p)
- if stmt == nil {
- logger.Info("Validation. No statement generated from GenCheckStmt.")
- continue
- }
- err := validation(ctx, schemaConfig, table, s, stmt, logger)
- if stmt.ValuesWithToken != nil {
- for _, token := range stmt.ValuesWithToken {
- g.ReleaseToken(token.Token)
- }
- }
- switch {
- case err == nil:
- globalStatus.ReadOps.Add(1)
- case errors.Is(err, context.Canceled):
- return nil
- default:
- query, prettyErr := stmt.PrettyCQL()
- if prettyErr != nil {
- return PrettyCQLError{
- PrettyCQL: prettyErr,
- Stmt: stmt,
- Err: err,
- }
- }
+ if l.warmup > 0 {
+ l.logger.Info("Warmup Job Started",
+ zap.Int("duration", int(l.warmup.Seconds())),
+ zap.Int("workers", int(l.workers)),
+ )
- globalStatus.AddReadError(&joberror.JobError{
- Timestamp: time.Now(),
- StmtType: stmt.QueryType.String(),
- Message: "Validation failed: " + err.Error(),
- Query: query,
- })
- }
-
- if failFast && globalStatus.HasErrors() {
- stopFlag.SetSoft(true)
- return nil
- }
+ warmupCtx, cancel := context.WithTimeout(ctx, l.warmup)
+ defer cancel()
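+ // Warmup issues mutations only; the two trailing false arguments
+ // disable deletes and DDL for this phase.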
+ l.startMutation(warmupCtx, cancel, &wg, l.random, "Warmup", false, false)
+ wg.Wait()
}
-}
-// warmupJob continuously applies mutations against the database
-// for as long as the pump is active or the supplied duration expires.
-func warmupJob(
- ctx context.Context,
- _ <-chan time.Duration,
- schema *typedef.Schema,
- schemaCfg typedef.SchemaConfig,
- table *typedef.Table,
- s store.Store,
- r *rand.Rand,
- p *typedef.PartitionRangeConfig,
- g *generators.Generator,
- globalStatus *status.GlobalStatus,
- logger *zap.Logger,
- stopFlag *stop.Flag,
- failFast, _ bool,
-) error {
- schemaConfig := &schemaCfg
- logger = logger.Named("warmup")
- logger.Info("starting warmup loop")
- defer func() {
- logger.Info("ending warmup loop")
- }()
- for {
- if stopFlag.IsHardOrSoft() {
- logger.Debug("warmup job terminated")
- return nil
- }
- // Do we care about errors during warmup?
- err := mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger)
- if err != nil {
- return err
- }
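+ // Grant one extra second beyond the configured duration so in-flight
+ // statements get a chance to finish before the context is cancelled.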
+ ctx, cancel := context.WithTimeout(ctx, l.duration+1*time.Second)
+ defer cancel()
- if failFast && globalStatus.HasErrors() {
- stopFlag.SetSoft(true)
- return nil
- }
- }
-}
+ src := rand.NewSource(l.random.Uint64())
-func ddl(
- ctx context.Context,
- schema *typedef.Schema,
- sc *typedef.SchemaConfig,
- table *typedef.Table,
- s store.Store,
- r *rand.Rand,
- p *typedef.PartitionRangeConfig,
- globalStatus *status.GlobalStatus,
- logger *zap.Logger,
- verbose bool,
-) error {
- if sc.CQLFeature != typedef.CQL_FEATURE_ALL {
- logger.Debug("ddl statements disabled")
- return nil
- }
- if len(table.MaterializedViews) > 0 {
- // Scylla does not allow changing the DDL of a table with materialized views.
- return nil
- }
- table.Lock()
- defer table.Unlock()
- ddlStmts, err := GenDDLStmt(schema, table, r, p, sc)
- if err != nil {
- logger.Error("Failed! DDL Mutation statement generation failed", zap.Error(err))
- globalStatus.WriteErrors.Add(1)
- return err
+ if l.mode.IsRead() {
+ l.startValidation(ctx, &wg, cancel, src)
}
- if ddlStmts == nil {
- if w := logger.Check(zap.DebugLevel, "no statement generated"); w != nil {
- w.Write(zap.String("job", "ddl"))
- }
- return nil
+ if l.mode.IsWrite() {
+ l.startMutation(ctx, cancel, &wg, src, "Mutation", true, true)
}
- for _, ddlStmt := range ddlStmts.List {
- if w := logger.Check(zap.DebugLevel, "ddl statement"); w != nil {
- prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL()
- if prettyCQLErr != nil {
- return PrettyCQLError{
- PrettyCQL: prettyCQLErr,
- Stmt: ddlStmt,
- }
- }
-
- w.Write(zap.String("pretty_cql", prettyCQL))
- }
-
- if err = s.Mutate(ctx, ddlStmt); err != nil {
- if errors.Is(err, context.Canceled) {
- return nil
- }
+ wg.Wait()
- prettyCQL, prettyCQLErr := ddlStmt.PrettyCQL()
- if prettyCQLErr != nil {
- return PrettyCQLError{
- PrettyCQL: prettyCQLErr,
- Stmt: ddlStmt,
- Err: err,
- }
- }
-
- globalStatus.AddWriteError(&joberror.JobError{
- Timestamp: time.Now(),
- StmtType: ddlStmts.QueryType.String(),
- Message: "DDL failed: " + err.Error(),
- Query: prettyCQL,
- })
-
- return err
- }
- globalStatus.WriteOps.Add(1)
- }
- ddlStmts.PostStmtHook()
- if verbose {
- jsonSchema, _ := json.MarshalIndent(schema, "", " ")
- fmt.Printf("New schema: %v\n", string(jsonSchema))
- }
return nil
}
-func mutation(
- ctx context.Context,
- schema *typedef.Schema,
- _ *typedef.SchemaConfig,
- table *typedef.Table,
- s store.Store,
- r *rand.Rand,
- p *typedef.PartitionRangeConfig,
- g *generators.Generator,
- globalStatus *status.GlobalStatus,
- deletes bool,
- logger *zap.Logger,
-) error {
- mutateStmt, err := GenMutateStmt(schema, table, g, r, p, deletes)
+func (l *Runner) startMutation(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, src rand.Source, name string, deletes, ddl bool) {
+ logger := l.logger.Named(name)
+
+ err := l.start(ctx, cancel, wg, rand.New(src), func(rnd *rand.Rand) Job {
+ return NewMutation(
+ logger,
+ l.schema,
+ l.store,
+ l.globalStatus,
+ rnd,
+ l.failFast,
+ deletes,
+ ddl,
+ )
+ })
if err != nil {
- logger.Error("Failed! Mutation statement generation failed", zap.Error(err))
- globalStatus.WriteErrors.Add(1)
- return err
- }
- if mutateStmt == nil {
- if w := logger.Check(zap.DebugLevel, "no statement generated"); w != nil {
- w.Write(zap.String("job", "mutation"))
- }
- return err
- }
-
- if w := logger.Check(zap.DebugLevel, "mutation statement"); w != nil {
- prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL()
- if prettyCQLErr != nil {
- return PrettyCQLError{
- PrettyCQL: prettyCQLErr,
- Stmt: mutateStmt,
- Err: err,
- }
- }
-
- w.Write(zap.String("pretty_cql", prettyCQL))
- }
- if err = s.Mutate(ctx, mutateStmt); err != nil {
- if errors.Is(err, context.Canceled) {
- return nil
- }
-
- prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL()
- if prettyCQLErr != nil {
- return PrettyCQLError{
- PrettyCQL: prettyCQLErr,
- Stmt: mutateStmt,
- Err: err,
- }
+ logger.Error("Mutation job failed", zap.Error(err))
+ if l.failFast {
+ cancel()
}
-
- globalStatus.AddWriteError(&joberror.JobError{
- Timestamp: time.Now(),
- StmtType: mutateStmt.QueryType.String(),
- Message: "Mutation failed: " + err.Error(),
- Query: prettyCQL,
- })
-
- return err
}
-
- globalStatus.WriteOps.Add(1)
- g.GiveOlds(mutateStmt.ValuesWithToken)
-
- return nil
}
-func validation(
- ctx context.Context,
- sc *typedef.SchemaConfig,
- table *typedef.Table,
- s store.Store,
- stmt *typedef.Stmt,
- logger *zap.Logger,
-) error {
- if w := logger.Check(zap.DebugLevel, "validation statement"); w != nil {
- prettyCQL, prettyCQLErr := stmt.PrettyCQL()
- if prettyCQLErr != nil {
- return PrettyCQLError{
- PrettyCQL: prettyCQLErr,
- Stmt: stmt,
- }
- }
-
- w.Write(zap.String("pretty_cql", prettyCQL))
- }
-
- maxAttempts := 1
- delay := 10 * time.Millisecond
- if stmt.QueryType.PossibleAsyncOperation() {
- maxAttempts = sc.AsyncObjectStabilizationAttempts
- if maxAttempts < 1 {
- maxAttempts = 1
+func (l *Runner) startValidation(ctx context.Context, wg *sync.WaitGroup, cancel context.CancelFunc, src rand.Source) {
+ err := l.start(ctx, cancel, wg, rand.New(src), func(rnd *rand.Rand) Job {
+ return NewValidation(
+ l.logger,
+ l.schema,
+ l.store,
+ rnd,
+ l.globalStatus,
+ l.failFast,
+ )
+ })
+ if err != nil {
+ l.logger.Error("Validation job failed", zap.Error(err))
+ if l.failFast {
+ cancel()
}
- delay = sc.AsyncObjectStabilizationDelay
}
+}
- var lastErr, err error
- attempt := 1
- for {
- lastErr = err
- err = s.Check(ctx, table, stmt, attempt == maxAttempts)
-
- if err == nil {
- if attempt > 1 {
- logger.Info(fmt.Sprintf("Validation successfully completed on %d attempt.", attempt))
- }
- return nil
- }
- if errors.Is(err, context.Canceled) {
- // When context is canceled it means that test was commanded to stop
- // to skip logging part it is returned here
- return err
- }
- if attempt == maxAttempts {
- break
- }
- if errors.Is(err, unWrapErr(lastErr)) {
- logger.Info(fmt.Sprintf("Retring failed validation. %d attempt from %d attempts. Error same as at attempt before. ", attempt, maxAttempts))
- } else {
- logger.Info(fmt.Sprintf("Retring failed validation. %d attempt from %d attempts. Error: %s", attempt, maxAttempts, err))
- }
-
- select {
- case <-time.After(delay):
- case <-ctx.Done():
- logger.Info(fmt.Sprintf("Retring failed validation stoped by done context. %d attempt from %d attempts. Error: %s", attempt, maxAttempts, err))
- return nil
+func (l *Runner) start(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, rnd *rand.Rand, job func(*rand.Rand) Job) error {
+ // One goroutine per worker per table must be accounted for.
+ wg.Add(int(l.workers) * len(l.schema.Tables))
+
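+ // Each table takes the next generator, and every worker gets its own
+ // RNG seeded from the shared source, keeping runs reproducible per seed.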
+ for _, table := range l.schema.Tables {
+ gen := l.generators.Get()
+ for range l.workers {
+ j := job(rand.New(rand.NewSource(rnd.Uint64())))
+ go func(j Job) {
+ defer wg.Done()
+ if err := j.Do(ctx, gen, table); err != nil {
+ if errors.Is(err, context.Canceled) {
+ return
+ }
+
+ l.logger.Error("job failed", zap.String("table", table.Name), zap.Error(err))
+
+ if l.failFast {
+ cancel()
+ }
+ }
+ }(j)
}
- attempt++
}
- if attempt > 1 {
- logger.Info(fmt.Sprintf("Retring failed validation stoped by reach of max attempts %d. Error: %s", maxAttempts, err))
- } else {
- logger.Info(fmt.Sprintf("Validation failed. Error: %s", err))
- }
-
- return err
-}
-
-func unWrapErr(err error) error {
- nextErr := err
- for nextErr != nil {
- err = nextErr
- nextErr = errors.Unwrap(err)
- }
- return err
+ return nil
}
diff --git a/pkg/jobs/mode.go b/pkg/jobs/mode.go
new file mode 100644
index 00000000..cc2cd109
--- /dev/null
+++ b/pkg/jobs/mode.go
@@ -0,0 +1,36 @@
+package jobs
+
+import "strings"
+
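+// Mode lists the statement kinds a run executes. ModeFromString keeps
+// WriteMode first when it is present, an ordering IsWrite relies on.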
+type Mode []string
+
+const (
+ WriteMode = "write"
+ ReadMode = "read"
+ MixedMode = "mixed"
+)
+
+func (m Mode) IsWrite() bool {
+ return m[0] == WriteMode
+}
+
+func (m Mode) IsRead() bool {
+ if len(m) == 1 {
+ return m[0] == ReadMode
+ }
+
+ return m[0] == ReadMode || m[1] == ReadMode
+}
+
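+// ModeFromString maps a mode string ("write", "read" or "mixed") to a
+// Mode, panicking on unknown input so a misconfigured run fails fast.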
+func ModeFromString(m string) Mode {
+ switch strings.ToLower(m) {
+ case WriteMode:
+ return Mode{WriteMode}
+ case ReadMode:
+ return Mode{ReadMode}
+ case MixedMode:
+ return Mode{WriteMode, ReadMode}
+ default:
+ panic("unknown mode " + m)
+ }
+}
diff --git a/pkg/jobs/mutation.go b/pkg/jobs/mutation.go
new file mode 100644
index 00000000..6b2068d5
--- /dev/null
+++ b/pkg/jobs/mutation.go
@@ -0,0 +1,155 @@
+package jobs
+
+import (
+ "context"
+ "time"
+
+ "github.com/pkg/errors"
+ "go.uber.org/zap"
+ "golang.org/x/exp/rand"
+
+ "github.com/scylladb/gemini/pkg/generators"
+ "github.com/scylladb/gemini/pkg/generators/statements"
+ "github.com/scylladb/gemini/pkg/joberror"
+ "github.com/scylladb/gemini/pkg/status"
+ "github.com/scylladb/gemini/pkg/store"
+ "github.com/scylladb/gemini/pkg/typedef"
+)
+
+type (
+ Mutation struct {
+ logger *zap.Logger
+ mutation mutation
+ failFast bool
+ }
+
+ mutation struct {
+ logger *zap.Logger
+ schema *typedef.Schema
+ store store.Store
+ globalStatus *status.GlobalStatus
+ random *rand.Rand
+ deletes bool
+ ddl bool
+ }
+)
+
+func NewMutation(
+ logger *zap.Logger,
+ schema *typedef.Schema,
+ store store.Store,
+ globalStatus *status.GlobalStatus,
+ rnd *rand.Rand,
+ failFast bool,
+ deletes bool,
+ ddl bool,
+) *Mutation {
+ return &Mutation{
+ logger: logger,
+ mutation: mutation{
+ logger: logger.Named("mutation-with-deletes"),
+ schema: schema,
+ store: store,
+ globalStatus: globalStatus,
+ deletes: deletes,
+ random: rnd,
+ ddl: ddl,
+ },
+ failFast: failFast,
+ }
+}
+
+func (m *Mutation) Name() string {
+ return "Mutation"
+}
+
+func (m *Mutation) Do(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+ m.logger.Info("starting mutation loop")
+ defer m.logger.Info("ending mutation loop")
+
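+ // Run until the context is cancelled or a statement fails; DDL is
+ // interleaved with regular mutations at a low rate (see ShouldDoDDL).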
+ for {
+ select {
+ case <-ctx.Done():
+ m.logger.Debug("mutation job terminated")
+ return context.Canceled
+ default:
+ }
+
+ var err error
+
+ if m.mutation.ShouldDoDDL() {
+ err = m.mutation.DDL(ctx, table)
+ } else {
+ err = m.mutation.Statement(ctx, generator, table)
+ }
+
+ if err != nil {
+ return errors.WithStack(err)
+ }
+ }
+}
+
+func (m *mutation) Statement(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+ partitionRangeConfig := m.schema.Config.GetPartitionRangeConfig()
+ mutateStmt, err := statements.GenMutateStmt(m.schema, table, generator, m.random, &partitionRangeConfig, m.deletes)
+ if err != nil {
+ m.logger.Error("Failed! Mutation statement generation failed", zap.Error(err))
+ m.globalStatus.WriteErrors.Add(1)
+ return errors.WithStack(err)
+ }
+
+ if w := m.logger.Check(zap.DebugLevel, "mutation statement"); w != nil {
+ prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL()
+ if prettyCQLErr != nil {
+ return errors.WithStack(PrettyCQLError{
+ PrettyCQL: prettyCQLErr,
+ Stmt: mutateStmt,
+ Err: err,
+ })
+ }
+
+ w.Write(zap.String("pretty_cql", prettyCQL))
+ }
+
+ if err = m.store.Mutate(ctx, mutateStmt); err != nil {
+ if errors.Is(err, context.Canceled) {
+ return nil
+ }
+
+ prettyCQL, prettyCQLErr := mutateStmt.PrettyCQL()
+ if prettyCQLErr != nil {
+ return errors.WithStack(PrettyCQLError{
+ PrettyCQL: prettyCQLErr,
+ Stmt: mutateStmt,
+ Err: err,
+ })
+ }
+
+ m.globalStatus.AddWriteError(&joberror.JobError{
+ Timestamp: time.Now(),
+ StmtType: mutateStmt.QueryType.String(),
+ Message: "Mutation failed: " + err.Error(),
+ Query: prettyCQL,
+ })
+
+ return err
+ }
+
+ m.globalStatus.WriteOps.Add(1)
+ generator.GiveOld(mutateStmt.ValuesWithToken...)
+
+ return nil
+}
+
+func (m *mutation) HasErrors() bool {
+ return m.globalStatus.HasErrors()
+}
+
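+// ShouldDoDDL fires for roughly 0.1% of iterations (100 out of 100000),
+// and only when DDL is enabled and the full CQL feature set is in use.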
+func (m *mutation) ShouldDoDDL() bool {
+ if m.ddl && m.schema.Config.CQLFeature == typedef.CQL_FEATURE_ALL {
+ ind := m.random.Intn(100000)
+ return ind < 100
+ }
+
+ return false
+}
diff --git a/pkg/jobs/validation.go b/pkg/jobs/validation.go
new file mode 100644
index 00000000..a8e0fb20
--- /dev/null
+++ b/pkg/jobs/validation.go
@@ -0,0 +1,159 @@
+package jobs
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/pkg/errors"
+ "go.uber.org/zap"
+ "golang.org/x/exp/rand"
+
+ "github.com/scylladb/gemini/pkg/generators"
+ "github.com/scylladb/gemini/pkg/generators/statements"
+ "github.com/scylladb/gemini/pkg/joberror"
+ "github.com/scylladb/gemini/pkg/status"
+ "github.com/scylladb/gemini/pkg/store"
+ "github.com/scylladb/gemini/pkg/typedef"
+)
+
+type Validation struct {
+ logger *zap.Logger
+ schema *typedef.Schema
+ store store.Store
+ random *rand.Rand
+ globalStatus *status.GlobalStatus
+ failFast bool
+}
+
+func NewValidation(
+ logger *zap.Logger,
+ schema *typedef.Schema,
+ store store.Store,
+ random *rand.Rand,
+ globalStatus *status.GlobalStatus,
+ failFast bool,
+) *Validation {
+ return &Validation{
+ logger: logger.Named("validation"),
+ schema: schema,
+ store: store,
+ random: random,
+ globalStatus: globalStatus,
+ failFast: failFast,
+ }
+}
+
+func (v *Validation) Name() string {
+ return "Validation"
+}
+
+func (v *Validation) validate(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+ err := v.validation(ctx, table, generator)
+
+ switch {
+ case err == nil:
+ v.globalStatus.ReadOps.Add(1)
+ case errors.Is(err, context.Canceled):
+ return context.Canceled
+ default:
+ v.globalStatus.AddReadError(&joberror.JobError{
+ Timestamp: time.Now(),
+ // StmtType is unavailable here: the statement is created inside validation() and is not in scope.
+ Message: "Validation failed: " + err.Error(),
+ })
+ }
+
+ return err
+}
+
+func (v *Validation) Do(ctx context.Context, generator generators.Interface, table *typedef.Table) error {
+ v.logger.Info("starting validation loop")
+ defer v.logger.Info("ending validation loop")
+
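+ // Validate in a loop until the context is cancelled or, with
+ // fail-fast enabled, the first error is recorded.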
+ for {
+ select {
+ case <-ctx.Done():
+ v.logger.Info("Context Done...")
+ return nil
+ default:
+ }
+
+ if err := v.validate(ctx, generator, table); errors.Is(err, context.Canceled) {
+ return nil
+ }
+
+ if v.failFast && v.globalStatus.HasErrors() {
+ return nil
+ }
+ }
+}
+
+func (v *Validation) validation(
+ ctx context.Context,
+ table *typedef.Table,
+ generator generators.Interface,
+) error {
+ partitionRangeConfig := v.schema.Config.GetPartitionRangeConfig()
+ stmt, cleanup := statements.GenCheckStmt(v.schema, table, generator, v.random, &partitionRangeConfig)
+ defer cleanup()
+
+ if w := v.logger.Check(zap.DebugLevel, "validation statement"); w != nil {
+ prettyCQL, prettyCQLErr := stmt.PrettyCQL()
+ if prettyCQLErr != nil {
+ return PrettyCQLError{
+ PrettyCQL: prettyCQLErr,
+ Stmt: stmt,
+ }
+ }
+
+ w.Write(zap.String("pretty_cql", prettyCQL))
+ }
+
+ maxAttempts := v.schema.Config.AsyncObjectStabilizationAttempts
+ delay := time.Duration(0)
+ if maxAttempts < 1 {
+ maxAttempts = 1
+ }
+
+ var err error
+
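+ // The first attempt runs immediately (delay starts at zero); retries
+ // then wait for the configured stabilization delay between attempts.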
+ for attempt := 1; ; attempt++ {
+ select {
+ case <-ctx.Done():
+ v.logger.Info("Context Done... validation exiting")
+ return context.Canceled
+ case <-time.After(delay):
+ delay = v.schema.Config.AsyncObjectStabilizationDelay
+ }
+
+ err = v.store.Check(ctx, table, stmt, attempt == maxAttempts)
+
+ if err == nil {
+ if attempt > 1 {
+ v.logger.Info(fmt.Sprintf("Validation successfully completed on %d attempt.", attempt))
+ }
+ return nil
+ }
+
+ if errors.Is(err, context.Canceled) {
+ return context.Canceled
+ }
+
+ if attempt == maxAttempts {
+ if attempt > 1 {
+ v.logger.Info(fmt.Sprintf("Retring failed validation stoped by reach of max attempts %d. Error: %s", maxAttempts, err))
+ } else {
+ v.logger.Info(fmt.Sprintf("Validation failed. Error: %s", err))
+ }
+
+ return err
+ }
+
+ v.logger.Info(fmt.Sprintf("Retring failed validation. %d attempt from %d attempts. Error: %s", attempt, maxAttempts, err))
+ }
+}
diff --git a/pkg/replication/replication.go b/pkg/replication/replication.go
index e0e26ef6..af059e29 100644
--- a/pkg/replication/replication.go
+++ b/pkg/replication/replication.go
@@ -21,9 +21,30 @@ import (
type Replication map[string]any
+var (
+ singleQuoteReplacer = strings.NewReplacer("'", "\"")
+ doubleQuoteReplacer = strings.NewReplacer("\"", "'")
+)
+
func (r *Replication) ToCQL() string {
b, _ := json.Marshal(r)
- return strings.ReplaceAll(string(b), "\"", "'")
+ return doubleQuoteReplacer.Replace(string(b))
+}
+
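+// MustParseReplication accepts the shorthands "network" and "simple", or a
+// JSON object in single- or double-quoted form; it panics on malformed input.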
+func MustParseReplication(rs string) *Replication {
+ switch rs {
+ case "network":
+ return NewNetworkTopologyStrategy()
+ case "simple":
+ return NewSimpleStrategy()
+ default:
+ var strategy Replication
+
+ if err := json.Unmarshal([]byte(singleQuoteReplacer.Replace(rs)), &strategy); err != nil {
+ panic("unable to parse replication strategy: " + rs + " Error: " + err.Error())
+ }
+ return &strategy
+ }
}
func NewSimpleStrategy() *Replication {
diff --git a/pkg/replication/replication_test.go b/pkg/replication/replication_test.go
index 82ed471c..9b25e714 100644
--- a/pkg/replication/replication_test.go
+++ b/pkg/replication/replication_test.go
@@ -17,6 +17,8 @@ package replication_test
import (
"testing"
+ "github.com/google/go-cmp/cmp"
+
"github.com/scylladb/gemini/pkg/replication"
)
@@ -46,3 +48,36 @@ func TestToCQL(t *testing.T) {
})
}
}
+
+func TestGetReplicationStrategy(t *testing.T) {
+ tests := map[string]struct {
+ strategy string
+ expected string
+ }{
+ "simple strategy": {
+ strategy: "{\"class\": \"SimpleStrategy\", \"replication_factor\": \"1\"}",
+ expected: "{'class':'SimpleStrategy','replication_factor':'1'}",
+ },
+ "simple strategy single quotes": {
+ strategy: "{'class': 'SimpleStrategy', 'replication_factor': '1'}",
+ expected: "{'class':'SimpleStrategy','replication_factor':'1'}",
+ },
+ "network topology strategy": {
+ strategy: "{\"class\": \"NetworkTopologyStrategy\", \"dc1\": 3, \"dc2\": 3}",
+ expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}",
+ },
+ "network topology strategy single quotes": {
+ strategy: "{'class': 'NetworkTopologyStrategy', 'dc1': 3, 'dc2': 3}",
+ expected: "{'class':'NetworkTopologyStrategy','dc1':3,'dc2':3}",
+ },
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ got := replication.MustParseReplication(tc.strategy)
+ if diff := cmp.Diff(tc.expected, got.ToCQL()); diff != "" {
+ t.Errorf("expected=%s, got=%s, diff=%s", tc.expected, got.ToCQL(), diff)
+ }
+ })
+ }
+}
diff --git a/pkg/status/status.go b/pkg/status/status.go
index cbaf1c9d..2568f067 100644
--- a/pkg/status/status.go
+++ b/pkg/status/status.go
@@ -19,6 +19,7 @@ import (
"fmt"
"io"
"sync/atomic"
+ "time"
"github.com/pkg/errors"
@@ -43,25 +44,21 @@ type GlobalStatus struct {
}
func (gs *GlobalStatus) AddWriteError(err *joberror.JobError) {
- // TODO: https://github.com/scylladb/gemini/issues/302 - Move out and add logging
- fmt.Printf("Error detected: %#v", err)
gs.Errors.AddError(err)
gs.WriteErrors.Add(1)
}
func (gs *GlobalStatus) AddReadError(err *joberror.JobError) {
- // TODO: https://github.com/scylladb/gemini/issues/302 - Move out and add logging
- fmt.Printf("Error detected: %#v", err)
gs.Errors.AddError(err)
gs.ReadErrors.Add(1)
}
-func (gs *GlobalStatus) PrintResultAsJSON(w io.Writer, schema *typedef.Schema, version string) error {
+func (gs *GlobalStatus) PrintResultAsJSON(w io.Writer, schema *typedef.Schema, version string, start time.Time) error {
result := map[string]any{
"result": gs,
"gemini_version": version,
"schemaHash": schema.GetHash(),
- "schema": schema,
+ "Time": time.Since(start).String(),
}
encoder := json.NewEncoder(w)
encoder.SetEscapeHTML(false)
@@ -81,12 +78,13 @@ func (gs *GlobalStatus) HasErrors() bool {
return gs.WriteErrors.Load() > 0 || gs.ReadErrors.Load() > 0
}
-func (gs *GlobalStatus) PrintResult(w io.Writer, schema *typedef.Schema, version string) {
- if err := gs.PrintResultAsJSON(w, schema, version); err != nil {
+func (gs *GlobalStatus) PrintResult(w io.Writer, schema *typedef.Schema, version string, start time.Time) {
+ if err := gs.PrintResultAsJSON(w, schema, version, start); err != nil {
- // In case there has been it has been a long run we want to display it anyway...
+ // Even if it has been a long run, we still want to display the result anyway...
fmt.Printf("Unable to print result as json, using plain text to stdout, error=%s\n", err)
fmt.Printf("Gemini version: %s\n", version)
fmt.Printf("Results:\n")
+ fmt.Printf("\ttime: %v\n", time.Since(start).String())
fmt.Printf("\twrite ops: %v\n", gs.WriteOps.Load())
fmt.Printf("\tread ops: %v\n", gs.ReadOps.Load())
fmt.Printf("\twrite errors: %v\n", gs.WriteErrors.Load())
@@ -94,8 +92,6 @@ func (gs *GlobalStatus) PrintResult(w io.Writer, schema *typedef.Schema, version
for i, err := range gs.Errors.Errors() {
fmt.Printf("Error %d: %s\n", i, err)
}
- jsonSchema, _ := json.MarshalIndent(schema, "", " ")
- fmt.Printf("Schema: %v\n", string(jsonSchema))
}
}
diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go
deleted file mode 100644
index 54c6e1f2..00000000
--- a/pkg/stop/flag.go
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright 2019 ScyllaDB
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stop
-
-import (
- "context"
- "fmt"
- "os"
- "os/signal"
- "sync"
- "sync/atomic"
- "syscall"
-
- "go.uber.org/zap"
-)
-
-const (
- SignalNoop uint32 = iota
- SignalSoftStop
- SignalHardStop
-)
-
-type SignalChannel chan uint32
-
-var closedChan = createClosedChan()
-
-func createClosedChan() SignalChannel {
- ch := make(SignalChannel)
- close(ch)
- return ch
-}
-
-type SyncList[T any] struct {
- children []T
- childrenLock sync.RWMutex
-}
-
-func (f *SyncList[T]) Append(el T) {
- f.childrenLock.Lock()
- defer f.childrenLock.Unlock()
- f.children = append(f.children, el)
-}
-
-func (f *SyncList[T]) Get() []T {
- f.childrenLock.RLock()
- defer f.childrenLock.RUnlock()
- return f.children
-}
-
-type logger interface {
- Debug(msg string, fields ...zap.Field)
-}
-
-type Flag struct {
- name string
- log logger
- ch atomic.Pointer[SignalChannel]
- parent *Flag
- children SyncList[*Flag]
- stopHandlers SyncList[func(signal uint32)]
- val atomic.Uint32
-}
-
-func (s *Flag) Name() string {
- return s.name
-}
-
-func (s *Flag) closeChannel() {
- ch := s.ch.Swap(&closedChan)
- if ch != &closedChan {
- close(*ch)
- }
-}
-
-func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool {
- s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal)))
- s.closeChannel()
- out := s.val.CompareAndSwap(SignalNoop, signal)
- if !out {
- return false
- }
-
- for _, handler := range s.stopHandlers.Get() {
- handler(signal)
- }
-
- for _, child := range s.children.Get() {
- child.sendSignal(signal, sendToParent)
- }
- if sendToParent && s.parent != nil {
- s.parent.sendSignal(signal, sendToParent)
- }
- return out
-}
-
-func (s *Flag) SetHard(sendToParent bool) bool {
- return s.sendSignal(SignalHardStop, sendToParent)
-}
-
-func (s *Flag) SetSoft(sendToParent bool) bool {
- return s.sendSignal(SignalSoftStop, sendToParent)
-}
-
-func (s *Flag) CreateChild(name string) *Flag {
- child := newFlag(name, s)
- s.children.Append(child)
- val := s.val.Load()
- switch val {
- case SignalSoftStop, SignalHardStop:
- child.sendSignal(val, false)
- }
- return child
-}
-
-func (s *Flag) SignalChannel() SignalChannel {
- return *s.ch.Load()
-}
-
-func (s *Flag) IsSoft() bool {
- return s.val.Load() == SignalSoftStop
-}
-
-func (s *Flag) IsHard() bool {
- return s.val.Load() == SignalHardStop
-}
-
-func (s *Flag) IsHardOrSoft() bool {
- return s.val.Load() != SignalNoop
-}
-
-func (s *Flag) AddHandler(handler func(signal uint32)) {
- s.stopHandlers.Append(handler)
- val := s.val.Load()
- switch val {
- case SignalSoftStop, SignalHardStop:
- handler(val)
- }
-}
-
-func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) {
- s.AddHandler(func(signal uint32) {
- switch expectedSignal {
- case SignalNoop:
- handler()
- default:
- if signal == expectedSignal {
- handler()
- }
- }
- })
-}
-
-func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context {
- ctx, cancel := context.WithCancel(ctx)
- s.AddHandler2(cancel, expectedSignal)
- return ctx
-}
-
-func (s *Flag) SetLogger(log logger) {
- s.log = log
-}
-
-func NewFlag(name string) *Flag {
- return newFlag(name, nil)
-}
-
-func newFlag(name string, parent *Flag) *Flag {
- out := Flag{
- name: name,
- parent: parent,
- log: zap.NewNop(),
- }
- ch := make(SignalChannel)
- out.ch.Store(&ch)
- return &out
-}
-
-func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) {
- graceful := make(chan os.Signal, 1)
- signal.Notify(graceful, syscall.SIGTERM, syscall.SIGINT)
- go func() {
- sig := <-graceful
- switch sig {
- case syscall.SIGINT:
- for i := range flags {
- flags[i].SetSoft(true)
- }
- logger.Info("Get SIGINT signal, begin soft stop.")
- default:
- for i := range flags {
- flags[i].SetHard(true)
- }
- logger.Info("Get SIGTERM signal, begin hard stop.")
- }
- }()
-}
-
-func GetStateName(state uint32) string {
- switch state {
- case SignalSoftStop:
- return "soft"
- case SignalHardStop:
- return "hard"
- case SignalNoop:
- return "no-signal"
- default:
- panic(fmt.Sprintf("unexpected signal %d", state))
- }
-}
diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go
deleted file mode 100644
index 81a4b344..00000000
--- a/pkg/stop/flag_test.go
+++ /dev/null
@@ -1,429 +0,0 @@
-// Copyright 2019 ScyllaDB
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stop_test
-
-import (
- "context"
- "errors"
- "fmt"
- "reflect"
- "runtime"
- "strings"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/scylladb/gemini/pkg/stop"
-)
-
-func TestHardStop(t *testing.T) {
- t.Parallel()
- testFlag, ctx, workersDone := initVars()
- workers := 30
-
- testSignals(t, workersDone, workers, testFlag.IsHard, testFlag.SetHard)
- if ctx.Err() == nil {
- t.Error("Error:SetHard function does not apply hardStopHandler")
- }
-}
-
-func TestSoftStop(t *testing.T) {
- t.Parallel()
- testFlag, ctx, workersDone := initVars()
- workers := 30
-
- testSignals(t, workersDone, workers, testFlag.IsSoft, testFlag.SetSoft)
- if ctx.Err() != nil {
- t.Error("Error:SetSoft function apply hardStopHandler")
- }
-}
-
-func TestSoftOrHardStop(t *testing.T) {
- t.Parallel()
- testFlag, ctx, workersDone := initVars()
- workers := 30
-
- testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetSoft)
- if ctx.Err() != nil {
- t.Error("Error:SetSoft function apply hardStopHandler")
- }
-
- workersDone.Store(uint32(0))
- testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard)
- if ctx.Err() != nil {
- t.Error("Error:SetHard function apply hardStopHandler after SetSoft")
- }
-
- testFlag, ctx, workersDone = initVars()
- workersDone.Store(uint32(0))
-
- testSignals(t, workersDone, workers, testFlag.IsHardOrSoft, testFlag.SetHard)
- if ctx.Err() == nil {
- t.Error("Error:SetHard function does not apply hardStopHandler")
- }
-}
-
-func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) {
- testFlagOut := stop.NewFlag("main_test")
- ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop)
- workersDone = &atomic.Uint32{}
- return testFlagOut, ctx, workersDone
-}
-
-func testSignals(
- t *testing.T,
- workersDone *atomic.Uint32,
- workers int,
- checkFunc func() bool,
- setFunc func(propagation bool) bool,
-) {
- t.Helper()
- for i := 0; i != workers; i++ {
- go func() {
- for {
- if checkFunc() {
- workersDone.Add(1)
- return
- }
- time.Sleep(10 * time.Millisecond)
- }
- }()
- }
- time.Sleep(200 * time.Millisecond)
- setFunc(false)
-
- for i := 0; i != 10; i++ {
- time.Sleep(100 * time.Millisecond)
- if workersDone.Load() == uint32(workers) {
- break
- }
- }
-
- setFuncName := runtime.FuncForPC(reflect.ValueOf(setFunc).Pointer()).Name()
- setFuncName, _ = strings.CutSuffix(setFuncName, "-fm")
- _, setFuncName, _ = strings.Cut(setFuncName, ".(")
- setFuncName = strings.ReplaceAll(setFuncName, ").", ".")
- checkFuncName := runtime.FuncForPC(reflect.ValueOf(checkFunc).Pointer()).Name()
- checkFuncName, _ = strings.CutSuffix(checkFuncName, "-fm")
- _, checkFuncName, _ = strings.Cut(checkFuncName, ".(")
- checkFuncName = strings.ReplaceAll(checkFuncName, ").", ".")
-
- if workersDone.Load() != uint32(workers) {
- t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc())
- }
-}
-
-func TestSendToParent(t *testing.T) {
- t.Parallel()
- tcases := []tCase{
- {
- testName: "parent-hard-true",
- parentSignal: stop.SignalHardStop,
- child1Signal: stop.SignalHardStop,
- child11Signal: stop.SignalHardStop,
- child12Signal: stop.SignalHardStop,
- child2Signal: stop.SignalHardStop,
- },
- {
- testName: "parent-hard-false",
- parentSignal: stop.SignalHardStop,
- child1Signal: stop.SignalHardStop,
- child11Signal: stop.SignalHardStop,
- child12Signal: stop.SignalHardStop,
- child2Signal: stop.SignalHardStop,
- },
- {
- testName: "parent-soft-true",
- parentSignal: stop.SignalSoftStop,
- child1Signal: stop.SignalSoftStop,
- child11Signal: stop.SignalSoftStop,
- child12Signal: stop.SignalSoftStop,
- child2Signal: stop.SignalSoftStop,
- },
- {
- testName: "parent-soft-false",
- parentSignal: stop.SignalSoftStop,
- child1Signal: stop.SignalSoftStop,
- child11Signal: stop.SignalSoftStop,
- child12Signal: stop.SignalSoftStop,
- child2Signal: stop.SignalSoftStop,
- },
- {
- testName: "child1-soft-true",
- parentSignal: stop.SignalSoftStop,
- child1Signal: stop.SignalSoftStop,
- child11Signal: stop.SignalSoftStop,
- child12Signal: stop.SignalSoftStop,
- child2Signal: stop.SignalSoftStop,
- },
- {
- testName: "child1-soft-false",
- parentSignal: stop.SignalNoop,
- child1Signal: stop.SignalSoftStop,
- child11Signal: stop.SignalSoftStop,
- child12Signal: stop.SignalSoftStop,
- child2Signal: stop.SignalNoop,
- },
- {
- testName: "child11-soft-true",
- parentSignal: stop.SignalSoftStop,
- child1Signal: stop.SignalSoftStop,
- child11Signal: stop.SignalSoftStop,
- child12Signal: stop.SignalSoftStop,
- child2Signal: stop.SignalSoftStop,
- },
- {
- testName: "child11-soft-false",
- parentSignal: stop.SignalNoop,
- child1Signal: stop.SignalNoop,
- child11Signal: stop.SignalSoftStop,
- child12Signal: stop.SignalNoop,
- child2Signal: stop.SignalNoop,
- },
- }
- for id := range tcases {
- tcase := tcases[id]
- t.Run(tcase.testName, func(t *testing.T) {
- t.Parallel()
- if err := tcase.runTest(); err != nil {
- t.Error(err)
- }
- })
- }
-}
-
-// nolint: govet
-type parentChildInfo struct {
- parent *stop.Flag
- parentSignal uint32
- child1 *stop.Flag
- child1Signal uint32
- child11 *stop.Flag
- child11Signal uint32
- child12 *stop.Flag
- child12Signal uint32
- child2 *stop.Flag
- child2Signal uint32
-}
-
-func (t *parentChildInfo) getFlag(flagName string) *stop.Flag {
- switch flagName {
- case "parent":
- return t.parent
- case "child1":
- return t.child1
- case "child2":
- return t.child2
- case "child11":
- return t.child11
- case "child12":
- return t.child12
- default:
- panic(fmt.Sprintf("no such flag %s", flagName))
- }
-}
-
-func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 {
- switch flagName {
- case "parent":
- return t.parentSignal
- case "child1":
- return t.child1Signal
- case "child2":
- return t.child2Signal
- case "child11":
- return t.child11Signal
- case "child12":
- return t.child12Signal
- default:
- panic(fmt.Sprintf("no such flag %s", flagName))
- }
-}
-
-func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error {
- var err error
- flagName := flag.Name()
- state := t.getFlagHandlerState(flagName)
- if state != expectedState {
- err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState)))
- }
- flagState := getFlagState(flag)
- if stop.GetStateName(expectedState) != flagState {
- err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState)))
- }
- return err
-}
-
-type tCase struct {
- testName string
- parentSignal uint32
- child1Signal uint32
- child11Signal uint32
- child12Signal uint32
- child2Signal uint32
-}
-
-func (t *tCase) runTest() error {
- chunk := strings.Split(t.testName, "-")
- if len(chunk) != 3 {
- panic(fmt.Sprintf("wrong test name %s", t.testName))
- }
- flagName := chunk[0]
- signalTypeName := chunk[1]
- sendToParentName := chunk[2]
-
- var sendToParent bool
- switch sendToParentName {
- case "true":
- sendToParent = true
- case "false":
- sendToParent = false
- default:
- panic(fmt.Sprintf("wrong test name %s", t.testName))
- }
- runt := newParentChildInfo()
- flag := runt.getFlag(flagName)
- switch signalTypeName {
- case "soft":
- flag.SetSoft(sendToParent)
- case "hard":
- flag.SetHard(sendToParent)
- default:
- panic(fmt.Sprintf("wrong test name %s", t.testName))
- }
- var err error
- err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal))
- err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal))
- err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal))
- err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal))
- err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal))
- return err
-}
-
-func newParentChildInfo() *parentChildInfo {
- parent := stop.NewFlag("parent")
- child1 := parent.CreateChild("child1")
- out := parentChildInfo{
- parent: parent,
- child1: child1,
- child11: child1.CreateChild("child11"),
- child12: child1.CreateChild("child12"),
- child2: parent.CreateChild("child2"),
- }
-
- out.parent.AddHandler(func(signal uint32) {
- out.parentSignal = signal
- })
- out.child1.AddHandler(func(signal uint32) {
- out.child1Signal = signal
- })
- out.child11.AddHandler(func(signal uint32) {
- out.child11Signal = signal
- })
- out.child12.AddHandler(func(signal uint32) {
- out.child12Signal = signal
- })
- out.child2.AddHandler(func(signal uint32) {
- out.child2Signal = signal
- })
- return &out
-}
-
-func getFlagState(flag *stop.Flag) string {
- switch {
- case flag.IsSoft():
- return "soft"
- case flag.IsHard():
- return "hard"
- default:
- return "no-signal"
- }
-}
-
-func TestSignalChannel(t *testing.T) {
- t.Parallel()
- t.Run("single-no-signal", func(t *testing.T) {
- t.Parallel()
- flag := stop.NewFlag("parent")
- select {
- case <-flag.SignalChannel():
- t.Error("should not get the signal")
- case <-time.Tick(200 * time.Millisecond):
- }
- })
-
- t.Run("single-beforehand", func(t *testing.T) {
- t.Parallel()
- flag := stop.NewFlag("parent")
- flag.SetSoft(true)
- <-flag.SignalChannel()
- })
-
- t.Run("single-normal", func(t *testing.T) {
- t.Parallel()
- flag := stop.NewFlag("parent")
- go func() {
- time.Sleep(200 * time.Millisecond)
- flag.SetSoft(true)
- }()
- <-flag.SignalChannel()
- })
-
- t.Run("parent-beforehand", func(t *testing.T) {
- t.Parallel()
- parent := stop.NewFlag("parent")
- child := parent.CreateChild("child")
- parent.SetSoft(true)
- <-child.SignalChannel()
- })
-
- t.Run("parent-beforehand", func(t *testing.T) {
- t.Parallel()
- parent := stop.NewFlag("parent")
- parent.SetSoft(true)
- child := parent.CreateChild("child")
- <-child.SignalChannel()
- })
-
- t.Run("parent-normal", func(t *testing.T) {
- t.Parallel()
- parent := stop.NewFlag("parent")
- child := parent.CreateChild("child")
- go func() {
- time.Sleep(200 * time.Millisecond)
- parent.SetSoft(true)
- }()
- <-child.SignalChannel()
- })
-
- t.Run("child-beforehand", func(t *testing.T) {
- t.Parallel()
- parent := stop.NewFlag("parent")
- child := parent.CreateChild("child")
- child.SetSoft(true)
- <-parent.SignalChannel()
- })
-
- t.Run("child-normal", func(t *testing.T) {
- t.Parallel()
- parent := stop.NewFlag("parent")
- child := parent.CreateChild("child")
- go func() {
- time.Sleep(200 * time.Millisecond)
- child.SetSoft(true)
- }()
- <-parent.SignalChannel()
- })
-}
diff --git a/pkg/store/helpers.go b/pkg/store/helpers.go
index 6758851e..d8cc39bb 100644
--- a/pkg/store/helpers.go
+++ b/pkg/store/helpers.go
@@ -79,7 +79,6 @@ func lt(mi, mj map[string]any) bool {
return true
default:
msg := fmt.Sprintf("unhandled type %T\n", mis)
- time.Sleep(time.Second)
panic(msg)
}
}
diff --git a/pkg/store/store.go b/pkg/store/store.go
index 51095b2d..b434d08b 100644
--- a/pkg/store/store.go
+++ b/pkg/store/store.go
@@ -233,9 +233,11 @@ func (ds delegatingStore) Check(ctx context.Context, table *typedef.Table, stmt
}),
cmp.Comparer(func(x, y *inf.Dec) bool {
return x.Cmp(y) == 0
- }), cmp.Comparer(func(x, y *big.Int) bool {
+ }),
+ cmp.Comparer(func(x, y *big.Int) bool {
return x.Cmp(y) == 0
- }))
+ }),
+ )
if diff != "" {
return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff)
}
diff --git a/pkg/testutils/mock_generator.go b/pkg/testutils/mock_generator.go
index eb6c0184..add33585 100644
--- a/pkg/testutils/mock_generator.go
+++ b/pkg/testutils/mock_generator.go
@@ -19,10 +19,13 @@ import (
"golang.org/x/exp/rand"
+ "github.com/scylladb/gemini/pkg/generators"
"github.com/scylladb/gemini/pkg/routingkey"
"github.com/scylladb/gemini/pkg/typedef"
)
+var _ generators.Interface = &MockGenerator{}
+
type MockGenerator struct {
table *typedef.Table
rand *rand.Rand
@@ -57,9 +60,7 @@ func (g *MockGenerator) GetOld() *typedef.ValueWithToken {
return &typedef.ValueWithToken{Token: token, Value: values}
}
-func (g *MockGenerator) GiveOld(_ *typedef.ValueWithToken) {}
-
-func (g *MockGenerator) GiveOlds(_ []*typedef.ValueWithToken) {}
+func (g *MockGenerator) GiveOld(_ ...*typedef.ValueWithToken) {}
func (g *MockGenerator) ReleaseToken(_ uint64) {
}
diff --git a/pkg/typedef/feature.go b/pkg/typedef/feature.go
new file mode 100644
index 00000000..5b0b215c
--- /dev/null
+++ b/pkg/typedef/feature.go
@@ -0,0 +1,28 @@
+// Copyright 2019 ScyllaDB
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typedef
+
+import "strings"
+
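+// ParseCQLFeature maps a feature string to a CQLFeature, falling back to
+// CQL_FEATURE_BASIC for unrecognized values.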
+func ParseCQLFeature(feature string) CQLFeature {
+ switch strings.ToLower(feature) {
+ case "all":
+ return CQL_FEATURE_ALL
+ case "normal":
+ return CQL_FEATURE_NORMAL
+ default:
+ return CQL_FEATURE_BASIC
+ }
+}
diff --git a/pkg/typedef/simple_type.go b/pkg/typedef/simple_type.go
index 4a02a612..edeff8f8 100644
--- a/pkg/typedef/simple_type.go
+++ b/pkg/typedef/simple_type.go
@@ -295,7 +295,7 @@ func (st SimpleType) genValue(r *rand.Rand, p *PartitionRangeConfig) any {
case TYPE_FLOAT:
return r.Float32()
case TYPE_INET:
- return net.ParseIP(utils.RandIPV4Address(r, r.Intn(255), 2)).String()
+ return net.ParseIP(utils.RandIPV4Address(r)).String()
case TYPE_INT:
return r.Int31()
case TYPE_SMALLINT:
diff --git a/pkg/typedef/table.go b/pkg/typedef/table.go
index 1af82947..2479d58b 100644
--- a/pkg/typedef/table.go
+++ b/pkg/typedef/table.go
@@ -105,14 +105,13 @@ func (t *Table) ValidColumnsForDelete() Columns {
}
}
}
- if len(t.MaterializedViews) != 0 {
- for _, mv := range t.MaterializedViews {
- if mv.HaveNonPrimaryKey() {
- for j := range validCols {
- if validCols[j].Name == mv.NonPrimaryKey.Name {
- validCols = append(validCols[:j], validCols[j+1:]...)
- break
- }
+
+ for _, mv := range t.MaterializedViews {
+ if mv.HaveNonPrimaryKey() {
+ for j := range validCols {
+ if validCols[j].Name == mv.NonPrimaryKey.Name {
+ validCols = append(validCols[:j], validCols[j+1:]...)
+ break
}
}
}
diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go
index 0b775d48..841b30c8 100644
--- a/pkg/utils/utils.go
+++ b/pkg/utils/utils.go
@@ -21,6 +21,9 @@ import (
"strings"
"time"
+ "github.com/pkg/errors"
+ "golang.org/x/exp/constraints"
+
"github.com/gocql/gocql"
"golang.org/x/exp/rand"
)
@@ -54,22 +57,34 @@ func RandTime(rnd *rand.Rand) int64 {
return rnd.Int63n(86400000000000)
}
-func RandIPV4Address(rnd *rand.Rand, v, pos int) string {
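+// ipV4Builder renders four octets as a dotted quad, trimming the trailing
+// dot left over from the loop.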
+func ipV4Builder[T constraints.Integer](bytes [4]T) string {
+ var builder strings.Builder
+ builder.Grow(16) // Maximum length of an IPv4 address
+
+ for _, b := range bytes {
+ builder.WriteString(strconv.FormatUint(uint64(b), 10))
+ builder.WriteRune('.')
+ }
+
+ return builder.String()[:builder.Len()-1]
+}
+
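+// RandIPV4Address returns a uniformly random IPv4 address.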
+func RandIPV4Address(rnd *rand.Rand) string {
+ return ipV4Builder([4]int{rnd.Intn(256), rnd.Intn(256), rnd.Intn(256), rnd.Intn(256)})
+}
+
+func RandIPV4AddressPositional(rnd *rand.Rand, v, pos int) string {
if pos < 0 || pos > 4 {
panic(fmt.Sprintf("invalid position for the desired value of the IP part %d, 0-3 supported", pos))
}
if v < 0 || v > 255 {
panic(fmt.Sprintf("invalid value for the desired position %d of the IP, 0-255 suppoerted", v))
}
- var blocks []string
- for i := 0; i < 4; i++ {
- if i == pos {
- blocks = append(blocks, strconv.Itoa(v))
- } else {
- blocks = append(blocks, strconv.Itoa(rnd.Intn(255)))
- }
- }
- return strings.Join(blocks, ".")
+
+ data := [4]int{rnd.Intn(256), rnd.Intn(256), rnd.Intn(256), rnd.Intn(256)}
+ data[pos] = v
+
+ return ipV4Builder(data)
}
func RandInt2(rnd *rand.Rand, min, max int) int {
@@ -107,3 +122,12 @@ func UUIDFromTime(rnd *rand.Rand) string {
}
return gocql.UUIDFromTime(RandDate(rnd)).String()
}
+
+func UnwrapErr(err error) error {
+ nextErr := err
+ for nextErr != nil {
+ err = nextErr
+ nextErr = errors.Unwrap(err)
+ }
+ return err
+}
diff --git a/results/.gitkeep b/results/.gitkeep
new file mode 100644
index 00000000..e69de29b