From 5c7adace018d6aa1871c2702f2f86d3e7bb1b40f Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sun, 21 Jul 2024 20:51:36 -0400 Subject: [PATCH 1/2] Add BlobCompressor to compress blob and text fields on fly --- internal/testutils/cluster.go | 133 ++++++++++++++++++ internal/testutils/flags.go | 114 ++++++++++++++++ lz4/blob_compression_integration_test.go | 136 +++++++++++++++++++ lz4/blob_compressor.go | 166 +++++++++++++++++++++++ lz4/blob_compressor_test.go | 1 + lz4/compressed_types.go | 113 +++++++++++++++ lz4/default_blob_compressor.go | 19 +++ lz4/rate_evaluator.go | 159 ++++++++++++++++++++++ lz4/string_support.go | 42 ++++++ lz4/string_support_test.go | 53 ++++++++ 10 files changed, 936 insertions(+) create mode 100644 internal/testutils/cluster.go create mode 100644 internal/testutils/flags.go create mode 100644 lz4/blob_compression_integration_test.go create mode 100644 lz4/blob_compressor.go create mode 100644 lz4/blob_compressor_test.go create mode 100644 lz4/compressed_types.go create mode 100644 lz4/default_blob_compressor.go create mode 100644 lz4/rate_evaluator.go create mode 100644 lz4/string_support.go create mode 100644 lz4/string_support_test.go diff --git a/internal/testutils/cluster.go b/internal/testutils/cluster.go new file mode 100644 index 000000000..431614be4 --- /dev/null +++ b/internal/testutils/cluster.go @@ -0,0 +1,133 @@ +package testutils + +import ( + "context" + "fmt" + "log" + "strings" + "sync" + "testing" + "time" + + "github.com/gocql/gocql" +) + +var initOnce sync.Once + +func CreateSession(tb testing.TB, opts ...func(config *gocql.ClusterConfig)) *gocql.Session { + cluster := CreateCluster(opts...) + return createSessionFromCluster(cluster, tb) +} + +func CreateCluster(opts ...func(*gocql.ClusterConfig)) *gocql.ClusterConfig { + clusterHosts := getClusterHosts() + cluster := gocql.NewCluster(clusterHosts...) + cluster.ProtoVersion = *flagProto + cluster.CQLVersion = *flagCQL + cluster.Timeout = *flagTimeout + cluster.Consistency = gocql.Quorum + cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow + if *flagRetry > 0 { + cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry} + } + + switch *flagCompressTest { + case "snappy": + cluster.Compressor = &gocql.SnappyCompressor{} + case "": + default: + panic("invalid compressor: " + *flagCompressTest) + } + + cluster = addSslOptions(cluster) + + for _, opt := range opts { + opt(cluster) + } + + return cluster +} + +func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocql.Session { + // Drop and re-create the keyspace once. Different tests should use their own + // individual tables, but can assume that the table does not exist before. 
+ initOnce.Do(func() { + createKeyspace(tb, cluster, "gocql_test") + }) + + cluster.Keyspace = "gocql_test" + session, err := cluster.CreateSession() + if err != nil { + tb.Fatal("createSession:", err) + } + + if err := session.AwaitSchemaAgreement(context.Background()); err != nil { + tb.Fatal(err) + } + + return session +} + +func getClusterHosts() []string { + return strings.Split(*flagCluster, ",") +} + +func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) { + // TODO: tb.Helper() + c := *cluster + c.Keyspace = "system" + c.Timeout = 30 * time.Second + session, err := c.CreateSession() + if err != nil { + panic(err) + } + defer session.Close() + + err = CreateTable(session, `DROP KEYSPACE IF EXISTS `+keyspace) + if err != nil { + panic(fmt.Sprintf("unable to drop keyspace: %v", err)) + } + + err = CreateTable(session, fmt.Sprintf(`CREATE KEYSPACE %s + WITH replication = { + 'class' : 'SimpleStrategy', + 'replication_factor' : %d + }`, keyspace, *flagRF)) + + if err != nil { + panic(fmt.Sprintf("unable to create keyspace: %v", err)) + } +} + +func CreateTable(s *gocql.Session, table string) error { + // lets just be really sure + if err := s.AwaitSchemaAgreement(context.Background()); err != nil { + log.Printf("error waiting for schema agreement pre create table=%q err=%v\n", table, err) + return err + } + + if err := s.Query(table).RetryPolicy(&gocql.SimpleRetryPolicy{}).Exec(); err != nil { + log.Printf("error creating table table=%q err=%v\n", table, err) + return err + } + + if err := s.AwaitSchemaAgreement(context.Background()); err != nil { + log.Printf("error waiting for schema agreement post create table=%q err=%v\n", table, err) + return err + } + + return nil +} + +func addSslOptions(cluster *gocql.ClusterConfig) *gocql.ClusterConfig { + if *flagRunSslTest { + cluster.Port = 9142 + cluster.SslOpts = &gocql.SslOptions{ + CertPath: "testdata/pki/gocql.crt", + KeyPath: "testdata/pki/gocql.key", + CaPath: "testdata/pki/ca.crt", + EnableHostVerification: false, + } + } + return cluster +} diff --git a/internal/testutils/flags.go b/internal/testutils/flags.go new file mode 100644 index 000000000..ce19c3080 --- /dev/null +++ b/internal/testutils/flags.go @@ -0,0 +1,114 @@ +package testutils + +import ( + "flag" + "fmt" + "log" + "strconv" + "strings" + "time" + + "github.com/gocql/gocql" +) + +var ( + flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") + flagProto = flag.Int("proto", 0, "protcol version") + flagCQL = flag.String("cql", "3.0.0", "CQL version") + flagRF = flag.Int("rf", 1, "replication factor for test keyspace") + clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") + flagRetry = flag.Int("retries", 5, "number of times to retry queries") + flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") + flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") + flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") + flagCompressTest = flag.String("compressor", "", "compressor to use") + flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") + + flagCassVersion cassVersion +) + +type cassVersion struct { + Major, Minor, Patch int +} + +func (c *cassVersion) Set(v string) error { + if v == "" { + 
return nil + } + + return c.UnmarshalCQL(nil, []byte(v)) +} + +func (c *cassVersion) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { + return c.unmarshal(data) +} + +func (c *cassVersion) unmarshal(data []byte) error { + version := strings.TrimSuffix(string(data), "-SNAPSHOT") + version = strings.TrimPrefix(version, "v") + v := strings.Split(version, ".") + + if len(v) < 2 { + return fmt.Errorf("invalid version string: %s", data) + } + + var err error + c.Major, err = strconv.Atoi(v[0]) + if err != nil { + return fmt.Errorf("invalid major version %v: %v", v[0], err) + } + + c.Minor, err = strconv.Atoi(v[1]) + if err != nil { + return fmt.Errorf("invalid minor version %v: %v", v[1], err) + } + + if len(v) > 2 { + c.Patch, err = strconv.Atoi(v[2]) + if err != nil { + return fmt.Errorf("invalid patch version %v: %v", v[2], err) + } + } + + return nil +} + +func (c cassVersion) Before(major, minor, patch int) bool { + // We're comparing us (cassVersion) with the provided version (major, minor, patch) + // We return true if our version is lower (comes before) than the provided one. + if c.Major < major { + return true + } else if c.Major == major { + if c.Minor < minor { + return true + } else if c.Minor == minor && c.Patch < patch { + return true + } + + } + return false +} + +func (c cassVersion) AtLeast(major, minor, patch int) bool { + return !c.Before(major, minor, patch) +} + +func (c cassVersion) String() string { + return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch) +} + +func (c cassVersion) nodeUpDelay() time.Duration { + if c.Major >= 2 && c.Minor >= 2 { + // CASSANDRA-8236 + return 0 + } + + return 10 * time.Second +} + +func init() { + flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") + + log.SetFlags(log.Lshortfile | log.LstdFlags) +} diff --git a/lz4/blob_compression_integration_test.go b/lz4/blob_compression_integration_test.go new file mode 100644 index 000000000..bd5dd71db --- /dev/null +++ b/lz4/blob_compression_integration_test.go @@ -0,0 +1,136 @@ +package lz4 + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/testutils" +) + +func TestBlobCompressor(t *testing.T) { + session := testutils.CreateSession(t) + defer session.Close() + + originalBlob := strings.Repeat("1234567890", 20) + + lz4Compressor, err := NewBlobCompressor([]byte("lz4:"), CompressorSizeLimit(1)) + if err != nil { + t.Fatal("create lz4 compressor") + } + + rtBlob := StatsBasedThreadSafeRateEvaluator{} + rtAscii := StatsBasedThreadSafeRateEvaluator{} + expectedUUID := gocql.TimeUUID() + expectedBlob := lz4Compressor.Blob([]byte(originalBlob)).SetRatioStats(&rtBlob) + expectedText := lz4Compressor.String(originalBlob).SetRatioStats(&rtAscii) + expectedASCII := lz4Compressor.String(originalBlob).SetRatioStats(&rtAscii) + expectedVarchar := lz4Compressor.String(originalBlob).SetRatioStats(&rtAscii) + + t.Run("prepare", func(t *testing.T) { + // TypeVarchar, TypeAscii, TypeBlob, TypeText + err := testutils.CreateTable(session, `CREATE TABLE gocql_test.test_blob_compressor ( + testuuid timeuuid PRIMARY KEY, + testblob blob, + testtext text, + testascii ascii, + testvarchar varchar, + )`) + if err != nil { + t.Fatal("create table:", err) + } + + err = session.Query( + `INSERT INTO gocql_test.test_blob_compressor (testuuid, testblob, testtext, testascii, testvarchar) VALUES (?,?,?,?,?)`, + expectedUUID, expectedBlob, expectedText, expectedASCII, expectedVarchar, + ).Exec() + if 
err != nil { + t.Fatal("insert:", err) + } + }) + if t.Failed() { + t.FailNow() + } + + t.Run("CheckIfCompressed", func(t *testing.T) { + testMap := make(map[string]interface{}) + if session.Query(`SELECT * FROM test_blob_compressor`).MapScan(testMap) != nil { + t.Fatal("MapScan failed to work with one row") + } + if !lz4Compressor.IsDataCompressed(testMap["testblob"]) { + t.Errorf("expected blob to be compressed, but it is not: %v", testMap["testblob"]) + } + if !lz4Compressor.IsDataCompressed(testMap["testtext"]) { + t.Errorf("expected text to be compressed, but it is not: %v", testMap["testtext"]) + } + if !lz4Compressor.IsDataCompressed(testMap["testascii"]) { + t.Errorf("expected text to be compressed, but it is not: %v", testMap["testascii"]) + } + if !lz4Compressor.IsDataCompressed(testMap["testvarchar"]) { + t.Errorf("expected text to be compressed, but it is not: %v", testMap["testvarchar"]) + } + }) + if t.Failed() { + t.FailNow() + } + + t.Run("MapScan", func(t *testing.T) { + uuid := &gocql.UUID{} + blob := lz4Compressor.Blob(nil).SetRatioStats(&rtBlob) + text := lz4Compressor.String("").SetRatioStats(&rtAscii) + ascii := lz4Compressor.String("").SetRatioStats(&rtAscii) + varchar := lz4Compressor.String("").SetRatioStats(&rtAscii) + iter := session.Query(`SELECT testuuid, testblob, testtext, testascii, testvarchar FROM test_blob_compressor`).Iter() + if !iter.MapScan(map[string]interface{}{ + "testuuid": uuid, + "testblob": blob, + "testtext": text, + "testascii": ascii, + "testvarchar": varchar, + }) { + t.Fatalf("MapScan failed to work with one row: %v", iter.Close()) + } + if diff := cmp.Diff([]byte(originalBlob), blob.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff(originalBlob, text.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff(originalBlob, ascii.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff(originalBlob, varchar.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + }) + + t.Run("Scan", func(t *testing.T) { + uuid := &gocql.UUID{} + blob := lz4Compressor.Blob(nil).SetRatioStats(&rtBlob) + text := lz4Compressor.String("").SetRatioStats(&rtAscii) + ascii := lz4Compressor.String("").SetRatioStats(&rtAscii) + varchar := lz4Compressor.String("").SetRatioStats(&rtAscii) + iter := session.Query(`SELECT testuuid, testblob, testtext, testascii, testvarchar FROM test_blob_compressor`).Iter() + if !iter.Scan(uuid, blob, text, ascii, varchar) { + t.Fatalf("MapScan failed to work with one row: %v", iter.Close()) + } + if diff := cmp.Diff(expectedUUID.String(), uuid.String()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff([]byte(originalBlob), blob.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff(originalBlob, text.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff(originalBlob, ascii.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + if diff := cmp.Diff(originalBlob, varchar.Value()); diff != "" { + t.Fatal("mismatch in returned map", diff) + } + + }) +} diff --git a/lz4/blob_compressor.go b/lz4/blob_compressor.go new file mode 100644 index 000000000..9681492f3 --- /dev/null +++ b/lz4/blob_compressor.go @@ -0,0 +1,166 @@ +package lz4 + +import ( + "bytes" + "encoding/binary" + "fmt" + "github.com/pierrec/lz4/v4" +) + +type BlobCompressor struct { + prefix []byte + prefixPlusLen int + 
	lowerSizeLimit            int
+	defaultBlobRateEvaluator  RateEvaluator
+	defaultASCIIRateEvaluator RateEvaluator
+}
+
+type Option func(*BlobCompressor) error
+
+func CompressorSizeLimit(limit int) Option {
+	return func(c *BlobCompressor) error {
+		c.lowerSizeLimit = limit
+		return nil
+	}
+}
+
+func CompressedPrefix(prefix []byte) Option {
+	return func(c *BlobCompressor) error {
+		if len(prefix) == 0 {
+			return fmt.Errorf("prefix should not be empty")
+		}
+		c.prefix = prefix
+		c.prefixPlusLen = len(prefix) + 4 // prefix plus the 4-byte uncompressed-length header
+		return nil
+	}
+}
+
+func DefaultASCIIRateEvaluator(re RateEvaluator) Option {
+	return func(c *BlobCompressor) error {
+		c.defaultASCIIRateEvaluator = re
+		return nil
+	}
+}
+
+func DefaultBlobRateEvaluator(re RateEvaluator) Option {
+	return func(c *BlobCompressor) error {
+		c.defaultBlobRateEvaluator = re
+		return nil
+	}
+}
+
+func NewBlobCompressor(prefix []byte, options ...Option) (*BlobCompressor, error) {
+	if len(prefix) == 0 {
+		prefix = []byte("lz4:")
+	}
+	res := &BlobCompressor{
+		prefix:         prefix,
+		prefixPlusLen:  len(prefix) + 4,
+		lowerSizeLimit: 1024,
+	}
+
+	err := res.ApplyOptions(options...)
+	if err != nil {
+		return nil, err
+	}
+	return res, nil
+}
+
+func NewBlobCompressorMust(prefix []byte, options ...Option) *BlobCompressor {
+	bc, err := NewBlobCompressor(prefix, options...)
+	if err != nil {
+		panic(err)
+	}
+	return bc
+}
+
+func (c *BlobCompressor) ApplyOptions(opts ...Option) error {
+	for _, opt := range opts {
+		if err := opt(c); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *BlobCompressor) CompressBinary(data []byte) ([]byte, error) {
+	if len(data) < c.lowerSizeLimit {
+		return data, nil
+	}
+
+	buf := make([]byte, c.prefixPlusLen+lz4.CompressBlockBound(len(data)))
+	copy(buf, c.prefix)
+
+	var compressor lz4.Compressor
+
+	n, err := compressor.CompressBlock(data, buf[c.prefixPlusLen:])
+	// According to lz4.CompressBlock doc, it doesn't fail as long as the dst
+	// buffer length is at least lz4.CompressBlockBound(len(data)) bytes, but
+	// we check for error anyway just to be thorough.
+	if err != nil {
+		return nil, err
+	}
+	binary.BigEndian.PutUint32(buf[len(c.prefix):], uint32(len(data)))
+
+	return buf[:c.prefixPlusLen+n], nil
+}
+
+func (c *BlobCompressor) DecompressBinary(data []byte) ([]byte, error) {
+	if !bytes.HasPrefix(data, c.prefix) {
+		return data, nil
+	}
+	uncompressedLength := binary.BigEndian.Uint32(data[len(c.prefix):])
+	if uncompressedLength == 0 {
+		return nil, nil
+	}
+	buf := make([]byte, uncompressedLength)
+	n, err := lz4.UncompressBlock(data[c.prefixPlusLen:], buf)
+	return buf[:n], err
+}
+
+// CompressASCII compresses the given string data into an ascii-compatible byte slice.
+func (c *BlobCompressor) CompressASCII(data string) ([]byte, error) {
+	if len(data) < c.lowerSizeLimit {
+		return []byte(data), nil
+	}
+	b, err := c.CompressBinary([]byte(data))
+	return convertBinToASCII(b), err
+}
+
+// DecompressASCII decompresses the given ascii-compatible byte slice to a string.
+func (c *BlobCompressor) DecompressASCII(data []byte) (string, error) { + b := convertASCIIToBin(data[:1+len(c.prefix)*8/7]) + if !bytes.HasPrefix(b, c.prefix) { + return string(data), nil + } + b, err := c.DecompressBinary(convertASCIIToBin(data)) + return string(b), err +} + +func (c *BlobCompressor) IsDataCompressed(i interface{}) bool { + switch data := i.(type) { + case string: + b := convertASCIIToBin([]byte(data[:1+len(c.prefix)*8/7])) + return bytes.HasPrefix(b, c.prefix) + case []byte: + return bytes.HasPrefix(data, c.prefix) + default: + return false + } +} + +func (c *BlobCompressor) Blob(val []byte) *CompressedBlob { + return &CompressedBlob{ + c: c, + value: val, + rationStats: c.defaultBlobRateEvaluator, + } +} + +func (c *BlobCompressor) String(val string) *CompressedString { + return &CompressedString{ + c: c, + value: val, + rationStats: c.defaultBlobRateEvaluator, + } +} diff --git a/lz4/blob_compressor_test.go b/lz4/blob_compressor_test.go new file mode 100644 index 000000000..763123404 --- /dev/null +++ b/lz4/blob_compressor_test.go @@ -0,0 +1 @@ +package lz4 diff --git a/lz4/compressed_types.go b/lz4/compressed_types.go new file mode 100644 index 000000000..cdbef289a --- /dev/null +++ b/lz4/compressed_types.go @@ -0,0 +1,113 @@ +package lz4 + +import ( + "github.com/gocql/gocql" +) + +type Decision int + +const ( + CompressionDecisionNone = Decision(iota) + CompressionDecisionCompress + CompressionDecisionDontCompress +) + +type RateEvaluator interface { + WorthOfCompressingRate(size int, compressed int) bool + WorthOfCompressing(size int) Decision +} + +// CompressedBlob is a compressed string type that can be used with gocql. +// Suitable for to accept data only from `blob` column type +// On other colum types, such as `ascii`, `text` and `varchar` generated data is going to hit server side utf-8/ascii validation +// and will be rejected. +// To accommodate server side validation, use CompressedString instead. +type CompressedBlob struct { + value []byte + c *BlobCompressor + rationStats RateEvaluator +} + +func (c *CompressedBlob) MarshalCQL(info gocql.TypeInfo) ([]byte, error) { + switch c.rationStats.WorthOfCompressing(len(c.value)) { + case CompressionDecisionCompress: + return c.c.CompressBinary(c.value) + case CompressionDecisionDontCompress: + return c.value, nil + default: + compressed, err := c.c.CompressBinary(c.value) + if err != nil { + return nil, err + } + if c.rationStats.WorthOfCompressingRate(len(c.value), len(compressed)) { + return compressed, nil + } + return c.value, nil + } +} + +func (c *CompressedBlob) UnmarshalCQL(info gocql.TypeInfo, data []byte) (err error) { + c.value, err = c.c.DecompressBinary(data) + return err +} + +func (c *CompressedBlob) Value() []byte { + return c.value +} + +func (c *CompressedBlob) SetValue(val []byte) *CompressedBlob { + c.value = val + return c +} + +func (c *CompressedBlob) SetRatioStats(val RateEvaluator) *CompressedBlob { + c.rationStats = val + return c +} + +// CompressedString is a compressed string type that can be used with gocql. +// Suitable for to accept data from `varchar`, `text` and `ascii` column types. 
+type CompressedString struct {
+	value       string
+	c           *BlobCompressor
+	rationStats RateEvaluator
+}
+
+func (c *CompressedString) MarshalCQL(info gocql.TypeInfo) ([]byte, error) {
+	if c.rationStats == nil {
+		return c.c.CompressASCII(c.value)
+	}
+	switch c.rationStats.WorthOfCompressing(len(c.value)) {
+	case CompressionDecisionCompress:
+		return c.c.CompressASCII(c.value)
+	case CompressionDecisionDontCompress:
+		return []byte(c.value), nil
+	default:
+		compressed, err := c.c.CompressASCII(c.value)
+		if err != nil {
+			return nil, err
+		}
+		if c.rationStats.WorthOfCompressingRate(len(c.value), len(compressed)) {
+			return compressed, nil
+		}
+		return []byte(c.value), nil
+	}
+}
+
+func (c *CompressedString) UnmarshalCQL(info gocql.TypeInfo, data []byte) (err error) {
+	c.value, err = c.c.DecompressASCII(data)
+	return err
+}
+
+func (c *CompressedString) Value() string {
+	return c.value
+}
+
+func (c *CompressedString) SetValue(val string) {
+	c.value = val
+}
+
+func (c *CompressedString) SetRatioStats(val RateEvaluator) *CompressedString {
+	c.rationStats = val
+	return c
+}
diff --git a/lz4/default_blob_compressor.go b/lz4/default_blob_compressor.go
new file mode 100644
index 000000000..b47bbb13a
--- /dev/null
+++ b/lz4/default_blob_compressor.go
@@ -0,0 +1,19 @@
+package lz4
+
+var DefaultCompressor = NewBlobCompressorMust([]byte("lz4:"))
+
+func SetDefaultCompressorLimit(limit int) {
+	DefaultCompressor.lowerSizeLimit = limit
+}
+
+func SetDefaultCompressorPrefix(prefix []byte) {
+	DefaultCompressor.prefix = prefix
+}
+
+func Blob(val []byte) *CompressedBlob {
+	return DefaultCompressor.Blob(val)
+}
+
+func String(val string) *CompressedString {
+	return DefaultCompressor.String(val)
+}
diff --git a/lz4/rate_evaluator.go b/lz4/rate_evaluator.go
new file mode 100644
index 000000000..2ae3aa4bb
--- /dev/null
+++ b/lz4/rate_evaluator.go
@@ -0,0 +1,159 @@
+package lz4
+
+import (
+	"math"
+	"sync/atomic"
+)
+
+const (
+	defaultRateCnt   = 10
+	defaultBucketCnt = 24
+	defaultMaxBucket = defaultBucketCnt - 1
+)
+
+func NewStatsBasedRateEvaluator(limit float32) *StatsBasedRateEvaluator {
+	return &StatsBasedRateEvaluator{
+		compressRationLimit: limit,
+	}
+}
+
+type StatsBasedRateEvaluator struct {
+	stats               [defaultBucketCnt][defaultRateCnt]float32
+	sizeIndex           [defaultBucketCnt]int32
+	compressRationLimit float32
+}
+
+var _ RateEvaluator = (*StatsBasedRateEvaluator)(nil)
+
+func (r *StatsBasedRateEvaluator) PushRate(size int, ratio float32) {
+	bct := getSizeBracket(size)
+	r.sizeIndex[bct] += 1
+
+	r.stats[bct][r.sizeIndex[bct]%defaultRateCnt] = ratio
+}
+
+func (r *StatsBasedRateEvaluator) estimateCompressionRate(size int) float32 {
+	bct := getSizeBracket(size)
+	idx := r.sizeIndex[bct]
+	total := float32(0)
+	for i := 0; i < defaultRateCnt; i++ {
+		val := r.stats[bct][(int(idx)+i)%defaultRateCnt]
+		if val == 0 {
+			return -1
+		}
+		total += val
+	}
+	return total / float32(defaultRateCnt)
+}
+
+func (r *StatsBasedRateEvaluator) WorthOfCompressingRate(original, compressed int) bool {
+	cr := float32(original) / float32(compressed)
+	r.PushRate(original, cr)
+	return cr > r.compressRationLimit
+}
+
+func (r *StatsBasedRateEvaluator) WorthOfCompressing(size int) Decision {
+	cr := r.estimateCompressionRate(size)
+	switch {
+	case cr == -1:
+		return CompressionDecisionNone
+	case cr <= r.compressRationLimit:
+		return CompressionDecisionDontCompress
+	default:
+		return CompressionDecisionCompress
+	}
+}
+
+type StatsBasedThreadSafeRateEvaluator struct {
+	stats               [defaultBucketCnt][defaultRateCnt]int32
+	sizeIndex           [defaultBucketCnt]int32
+	compressRationLimit float32
+}
+
+var _ RateEvaluator = (*StatsBasedThreadSafeRateEvaluator)(nil)
+
+func NewStatsBasedThreadSafeRateEvaluator(limit float32) *StatsBasedThreadSafeRateEvaluator {
+	return &StatsBasedThreadSafeRateEvaluator{
+		compressRationLimit: limit,
+	}
+}
+
+func getSizeBracket(size int) int {
+	b := 0
+	for size > 0 {
+		size >>= 2
+		b++
+	}
+	b -= 4
+	if b < 0 {
+		b = 0
+	}
+	// 0 - size < 256
+	// 1 - size < 1024
+	// 2 - size < 4096
+	if b > defaultMaxBucket {
+		return defaultMaxBucket
+	}
+	return b
+}
+
+func (r *StatsBasedThreadSafeRateEvaluator) PushRate(size int, ratio float32) {
+	bct := getSizeBracket(size)
+	idx := atomic.AddInt32(&r.sizeIndex[bct], 1)
+	atomic.StoreInt32(&r.stats[bct][idx%defaultRateCnt], int32(math.Float32bits(ratio))) // store the float32 bit pattern so the ratio survives the atomic int32 round-trip
+}
+
+func (r *StatsBasedThreadSafeRateEvaluator) estimateCompressionRate(size int) float32 {
+	bct := getSizeBracket(size)
+	idx := atomic.LoadInt32(&r.sizeIndex[bct])
+	total := float32(0)
+	for i := 0; i < defaultRateCnt; i++ {
+		val := atomic.LoadInt32(&r.stats[bct][(int(idx)+i)%defaultRateCnt])
+		if val == 0 {
+			return -1
+		}
+		total += math.Float32frombits(uint32(val))
+	}
+	return total / float32(defaultRateCnt)
+}
+
+func (r *StatsBasedThreadSafeRateEvaluator) WorthOfCompressingRate(original, compressed int) bool {
+	cr := float32(original) / float32(compressed)
+	r.PushRate(original, cr)
+	return cr > r.compressRationLimit
+}
+
+func (r *StatsBasedThreadSafeRateEvaluator) WorthOfCompressing(size int) Decision {
+	cr := r.estimateCompressionRate(size)
+	switch {
+	case cr == -1:
+		return CompressionDecisionNone
+	case cr <= r.compressRationLimit:
+		return CompressionDecisionDontCompress
+	default:
+		return CompressionDecisionCompress
+	}
+}
+
+type SimpleRateEvaluator struct {
+	compressRationLimit float32
+}
+
+var _ RateEvaluator = (*SimpleRateEvaluator)(nil)
+
+func NewSimpleRateEvaluator(limit float32) *SimpleRateEvaluator {
+	return &SimpleRateEvaluator{
+		compressRationLimit: limit,
+	}
+}
+
+func (r *SimpleRateEvaluator) WorthOfCompressingRate(original, compressed int) bool {
+	return float32(original)/float32(compressed) > r.compressRationLimit
+}
+
+func (r *SimpleRateEvaluator) WorthOfCompressing(size int) Decision {
+	if float32(size) > r.compressRationLimit {
+		return CompressionDecisionCompress
+	}
+	return CompressionDecisionDontCompress
+}
diff --git a/lz4/string_support.go b/lz4/string_support.go
new file mode 100644
index 000000000..13b3304f2
--- /dev/null
+++ b/lz4/string_support.go
@@ -0,0 +1,42 @@
+package lz4
+
+const only7Bits = uint16(0x7f)
+
+// convertBinToASCII converts a slice of bytes to a slice of ascii-compatible bytes with only 7 bits populated.
+func convertBinToASCII(p []byte) []byte {
+	b := make([]byte, 0, 1+len(p)*8/7)
+	buff := uint16(0)
+	k := byte(0)
+	for _, v := range p {
+		if k == 7 {
+			b = append(b, byte(buff))
+			k = 0
+			buff = 0
+		}
+		buff |= uint16(v) << k
+		b = append(b, byte(buff&only7Bits))
+		buff >>= 7
+		k++
+	}
+	if k > 0 {
+		b = append(b, byte(buff))
+	}
+	return b
+}
+
+// convertASCIIToBin converts a slice of ascii-compatible bytes with only 7 bits populated to regular bytes.
+func convertASCIIToBin(p []byte) []byte { + b := make([]byte, 0, len(p)*7/8) + buff := uint16(0) + k := byte(0) + for _, v := range p { + buff |= uint16(v) << k + k += 7 + if k >= 8 { + k -= 8 + b = append(b, byte(buff)) + buff >>= 8 + } + } + return b +} diff --git a/lz4/string_support_test.go b/lz4/string_support_test.go new file mode 100644 index 000000000..919cab767 --- /dev/null +++ b/lz4/string_support_test.go @@ -0,0 +1,53 @@ +package lz4 + +import ( + "bytes" + "strconv" + "testing" +) + +func TestBinaryToASII(t *testing.T) { + tcases := []struct { + in []byte + expected []byte + }{ + { + in: []byte{0xaa}, + expected: []byte{0x2a, 0x1}, + }, + { + in: []byte{0xaa, 0xaa, 0xaa}, + expected: []byte{0x2a, 0x55, 0x2a, 0x5}, + }, + { + in: []byte{0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa}, + expected: []byte{0x2a, 0x55, 0x2a, 0x55, 0x2a, 0x55, 0x2a, 0x55}, + }, + { + in: []byte{0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa}, + expected: []byte{0x2a, 0x55, 0x2a, 0x55, 0x2a, 0x55, 0x2a, 0x55, 0x2a, 0x55, 0x2a, 0x5}, + }, + } + + t.Run("convertBinToASCII", func(t *testing.T) { + for id, tcase := range tcases { + t.Run(strconv.Itoa(id), func(t *testing.T) { + got := convertBinToASCII(tcase.in) + if !bytes.Equal(tcase.expected, got) { + t.Errorf("expected %v, got %v", tcase.expected, got) + } + }) + } + }) + + t.Run("convertASCIIToBin", func(t *testing.T) { + for id, tcase := range tcases { + t.Run(strconv.Itoa(id), func(t *testing.T) { + got := convertASCIIToBin(tcase.expected) + if !bytes.Equal(tcase.in, got) { + t.Errorf("expected %v, got %v", tcase.expected, got) + } + }) + } + }) +} From 847b9e4178d5517d483f6e63a3b052da6da42b9a Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 23 Jul 2024 10:42:29 -0400 Subject: [PATCH 2/2] Separate test command line to a package --- batch_test.go | 4 +- cassandra_only_test.go | 16 +++--- cassandra_test.go | 12 ++-- cloud_cluster_test.go | 6 +- common_test.go | 58 +++++++------------- export_test.go | 1 - integration_test.go | 14 +++-- internal/{testutils => testcmdline}/flags.go | 42 ++++++-------- internal/testutils/cluster.go | 21 +++---- 9 files changed, 78 insertions(+), 96 deletions(-) rename internal/{testutils => testcmdline}/flags.go (53%) diff --git a/batch_test.go b/batch_test.go index 490ae796d..f5c40dde8 100644 --- a/batch_test.go +++ b/batch_test.go @@ -6,10 +6,12 @@ package gocql import ( "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) func TestBatch_Errors(t *testing.T) { - if *flagProto == 1 { + if *testcmdline.Proto == 1 { } session := createSession(t) diff --git a/cassandra_only_test.go b/cassandra_only_test.go index fd02d01b0..383326d6f 100644 --- a/cassandra_only_test.go +++ b/cassandra_only_test.go @@ -13,6 +13,8 @@ import ( "sync" "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) func TestDiscoverViaProxy(t *testing.T) { @@ -204,8 +206,8 @@ func TestGetKeyspaceMetadata(t *testing.T) { if err != nil { t.Fatalf("Error converting string to int with err: %v", err) } - if rfInt != *flagRF { - t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt) + if rfInt != *testcmdline.RF { + t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt) } } @@ -431,7 +433,7 @@ func TestViewMetadata(t *testing.T) { } textType := TypeText - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { textType = TypeVarchar } @@ -453,7 +455,7 @@ func TestViewMetadata(t *testing.T) { } func TestMaterializedViewMetadata(t 
*testing.T) { - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { return } session := createSession(t) @@ -552,7 +554,7 @@ func TestAggregateMetadata(t *testing.T) { } // In this case cassandra is returning a blob - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0}) } @@ -736,7 +738,7 @@ func TestKeyspaceMetadata(t *testing.T) { t.Fatal("failed to find the types in metadata") } textType := TypeText - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { textType = TypeVarchar } expectedType := UserTypeMetadata{ @@ -753,7 +755,7 @@ func TestKeyspaceMetadata(t *testing.T) { if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { t.Fatalf("type is %+v, but expected %+v", keyspaceMetadata.UserTypes["basicview"], expectedType) } - if flagCassVersion.Major >= 3 { + if testcmdline.CassVersion.Major >= 3 { materializedView, found := keyspaceMetadata.MaterializedViews["view_view"] if !found { t.Fatal("failed to find materialized view view_view in metadata") diff --git a/cassandra_test.go b/cassandra_test.go index f7539ded4..fa599ff95 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -20,6 +20,8 @@ import ( "unicode" inf "gopkg.in/inf.v0" + + "github.com/gocql/gocql/internal/testcmdline" ) func TestEmptyHosts(t *testing.T) { @@ -2126,8 +2128,8 @@ func TestGetKeyspaceMetadata(t *testing.T) { if err != nil { t.Fatalf("Error converting string to int with err: %v", err) } - if rfInt != *flagRF { - t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt) + if rfInt != *testcmdline.RF { + t.Errorf("Expected replication factor to be %d but was %d", *testcmdline.RF, rfInt) } } @@ -2494,8 +2496,8 @@ func TestUnmarshallNestedTypes(t *testing.T) { } func TestSchemaReset(t *testing.T) { - if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) { - t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", flagCassVersion) + if testcmdline.CassVersion.Major == 0 || testcmdline.CassVersion.Before(2, 1, 3) { + t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 version=%v", testcmdline.CassVersion) } cluster := createCluster() @@ -2560,7 +2562,7 @@ func TestCreateSession_DontSwallowError(t *testing.T) { t.Fatal("expected to get an error for unsupported protocol") } - if flagCassVersion.Major < 3 { + if testcmdline.CassVersion.Major < 3 { // TODO: we should get a distinct error type here which include the underlying // cassandra error about the protocol version, for now check this here. 
if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") { diff --git a/cloud_cluster_test.go b/cloud_cluster_test.go index 4133ac56e..ec67af949 100644 --- a/cloud_cluster_test.go +++ b/cloud_cluster_test.go @@ -16,13 +16,15 @@ import ( "testing" "time" + "sigs.k8s.io/yaml" + "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/testcmdline" "github.com/gocql/gocql/scyllacloud" - "sigs.k8s.io/yaml" ) func TestCloudConnection(t *testing.T) { - if !*gocql.FlagRunSslTest { + if !*testcmdline.RunSslTest { t.Skip("Skipping because SSL is not enabled on cluster") } diff --git a/common_test.go b/common_test.go index abbe91cce..462fb2b6d 100644 --- a/common_test.go +++ b/common_test.go @@ -1,7 +1,6 @@ package gocql import ( - "flag" "fmt" "log" "net" @@ -10,41 +9,24 @@ import ( "sync" "testing" "time" -) -var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") - flagProto = flag.Int("proto", 0, "protcol version") - flagCQL = flag.String("cql", "3.0.0", "CQL version") - flagRF = flag.Int("rf", 1, "replication factor for test keyspace") - clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") - flagRetry = flag.Int("retries", 5, "number of times to retry queries") - flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") - flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") - flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") - flagCompressTest = flag.String("compressor", "", "compressor to use") - flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") - - flagCassVersion cassVersion + "github.com/gocql/gocql/internal/testcmdline" ) func init() { - flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - log.SetFlags(log.Lshortfile | log.LstdFlags) } func getClusterHosts() []string { - return strings.Split(*flagCluster, ",") + return strings.Split(*testcmdline.Cluster, ",") } func getMultiNodeClusterHosts() []string { - return strings.Split(*flagMultiNodeCluster, ",") + return strings.Split(*testcmdline.MultiNodeCluster, ",") } func addSslOptions(cluster *ClusterConfig) *ClusterConfig { - if *flagRunSslTest { + if *testcmdline.RunSslTest { cluster.Port = 9142 cluster.SslOpts = &SslOptions{ CertPath: "testdata/pki/gocql.crt", @@ -81,21 +63,21 @@ func createTable(s *Session, table string) error { func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getClusterHosts() cluster := NewCluster(clusterHosts...) 
- cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -110,21 +92,21 @@ func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { func createMultiNodeCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getMultiNodeClusterHosts() cluster := NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -156,7 +138,7 @@ func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d - }`, keyspace, *flagRF)) + }`, keyspace, *testcmdline.RF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) @@ -232,7 +214,7 @@ func createViews(t *testing.T, session *Session) { } func createMaterializedViews(t *testing.T, session *Session) { - if flagCassVersion.Before(3, 0, 0) { + if testcmdline.CassVersion.Before(3, 0, 0) { return } if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table ( diff --git a/export_test.go b/export_test.go index 830436303..3295697db 100644 --- a/export_test.go +++ b/export_test.go @@ -3,7 +3,6 @@ package gocql -var FlagRunSslTest = flagRunSslTest var CreateCluster = createCluster var TestLogger = &testLogger{} var WaitUntilPoolsStopFilling = waitUntilPoolsStopFilling diff --git a/integration_test.go b/integration_test.go index f548a829f..11f22c445 100644 --- a/integration_test.go +++ b/integration_test.go @@ -9,16 +9,18 @@ import ( "reflect" "testing" "time" + + "github.com/gocql/gocql/internal/testcmdline" ) // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections func TestAuthentication(t *testing.T) { - if *flagProto < 2 { + if *testcmdline.Proto < 2 { t.Skip("Authentication is not supported with protocol < 2") } - if !*flagRunAuthTest { + if !*testcmdline.RunAuthTest { t.Skip("Authentication is not configured in the target cluster") } @@ -60,21 +62,21 @@ func TestRingDiscovery(t *testing.T) { session := createSessionFromCluster(cluster, t) defer 
session.Close() - if *clusterSize > 1 { + if *testcmdline.ClusterSize > 1 { // wait for autodiscovery to update the pool with the list of known hosts - time.Sleep(*flagAutoWait) + time.Sleep(*testcmdline.AutoWait) } session.pool.mu.RLock() defer session.pool.mu.RUnlock() size := len(session.pool.hostConnPools) - if *clusterSize != size { + if *testcmdline.ClusterSize != size { for p, pool := range session.pool.hostConnPools { t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String()) } - t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size) + t.Errorf("Expected a cluster size of %d, but actual size was %d", *testcmdline.ClusterSize, size) } } diff --git a/internal/testutils/flags.go b/internal/testcmdline/flags.go similarity index 53% rename from internal/testutils/flags.go rename to internal/testcmdline/flags.go index ce19c3080..938c346f5 100644 --- a/internal/testutils/flags.go +++ b/internal/testcmdline/flags.go @@ -1,31 +1,27 @@ -package testutils +package testcmdline import ( "flag" "fmt" - "log" "strconv" "strings" "time" - - "github.com/gocql/gocql" ) var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagMultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") - flagProto = flag.Int("proto", 0, "protcol version") - flagCQL = flag.String("cql", "3.0.0", "CQL version") - flagRF = flag.Int("rf", 1, "replication factor for test keyspace") - clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") - flagRetry = flag.Int("retries", 5, "number of times to retry queries") - flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") - flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") - flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") - flagCompressTest = flag.String("compressor", "", "compressor to use") - flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") - - flagCassVersion cassVersion + Cluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + MultiNodeCluster = flag.String("multiCluster", "127.0.0.2", "a comma-separated list of host:port tuples") + Proto = flag.Int("proto", 0, "protcol version") + CQL = flag.String("cql", "3.0.0", "CQL version") + RF = flag.Int("rf", 1, "replication factor for test keyspace") + ClusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") + Retry = flag.Int("retries", 5, "number of times to retry queries") + AutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") + RunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") + RunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") + CompressTest = flag.String("compressor", "", "compressor to use") + Timeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") + CassVersion cassVersion ) type cassVersion struct { @@ -37,11 +33,7 @@ func (c *cassVersion) Set(v string) error { return nil } - return c.UnmarshalCQL(nil, []byte(v)) -} - -func (c *cassVersion) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { - return c.unmarshal(data) + return c.unmarshal([]byte(v)) } func (c *cassVersion) unmarshal(data []byte) error { @@ -108,7 +100,5 @@ func (c 
cassVersion) nodeUpDelay() time.Duration { } func init() { - flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") - - log.SetFlags(log.Lshortfile | log.LstdFlags) + flag.Var(&CassVersion, "gocql.cversion", "the cassandra version being tested against") } diff --git a/internal/testutils/cluster.go b/internal/testutils/cluster.go index 431614be4..fb81715f4 100644 --- a/internal/testutils/cluster.go +++ b/internal/testutils/cluster.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gocql/gocql" + "github.com/gocql/gocql/internal/testcmdline" ) var initOnce sync.Once @@ -22,21 +23,21 @@ func CreateSession(tb testing.TB, opts ...func(config *gocql.ClusterConfig)) *go func CreateCluster(opts ...func(*gocql.ClusterConfig)) *gocql.ClusterConfig { clusterHosts := getClusterHosts() cluster := gocql.NewCluster(clusterHosts...) - cluster.ProtoVersion = *flagProto - cluster.CQLVersion = *flagCQL - cluster.Timeout = *flagTimeout + cluster.ProtoVersion = *testcmdline.Proto + cluster.CQLVersion = *testcmdline.CQL + cluster.Timeout = *testcmdline.Timeout cluster.Consistency = gocql.Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow - if *flagRetry > 0 { - cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *flagRetry} + if *testcmdline.Retry > 0 { + cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: *testcmdline.Retry} } - switch *flagCompressTest { + switch *testcmdline.CompressTest { case "snappy": cluster.Compressor = &gocql.SnappyCompressor{} case "": default: - panic("invalid compressor: " + *flagCompressTest) + panic("invalid compressor: " + *testcmdline.CompressTest) } cluster = addSslOptions(cluster) @@ -69,7 +70,7 @@ func createSessionFromCluster(cluster *gocql.ClusterConfig, tb testing.TB) *gocq } func getClusterHosts() []string { - return strings.Split(*flagCluster, ",") + return strings.Split(*testcmdline.Cluster, ",") } func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string) { @@ -92,7 +93,7 @@ func createKeyspace(tb testing.TB, cluster *gocql.ClusterConfig, keyspace string WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d - }`, keyspace, *flagRF)) + }`, keyspace, *testcmdline.RF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) @@ -120,7 +121,7 @@ func CreateTable(s *gocql.Session, table string) error { } func addSslOptions(cluster *gocql.ClusterConfig) *gocql.ClusterConfig { - if *flagRunSslTest { + if *testcmdline.RunSslTest { cluster.Port = 9142 cluster.SslOpts = &gocql.SslOptions{ CertPath: "testdata/pki/gocql.crt",