From bc226ce5a161c8958c5be2a6dd66be766e529852 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Tue, 18 Jun 2024 11:35:51 -0700 Subject: [PATCH] add tls tests and re-org flags --- .github/workflows/tests.yml | 1 + .gitignore | 2 +- cmd/constants.go | 31 -- cmd/createIndex.go | 115 ++++---- cmd/dropIndex.go | 67 ++--- cmd/flags/client.go | 34 +++ cmd/flags/constants.go | 31 ++ cmd/flags/tls.go | 6 +- cmd/listIndex.go | 56 ++-- cmd/root.go | 6 +- cmd/utils_test.go | 49 ++++ docker/tls/config/aerospike-proximus.yml | 86 ++++++ docker/tls/config/aerospike.conf | 82 ++++++ .../config/tls/connector.aerospike.com.crt | 22 ++ .../tls/connector.aerospike.com.keystore.jks | Bin 0 -> 2285 bytes docker/tls/config/tls/keypass | 1 + docker/tls/config/tls/storepass | 1 + docker/tls/docker-compose.yml | 23 ++ e2e_test.go | 269 ++++++------------ go.mod | 8 +- go.sum | 6 + test_utils.go | 160 +++++++++++ 22 files changed, 716 insertions(+), 340 deletions(-) delete mode 100644 cmd/constants.go create mode 100644 cmd/flags/client.go create mode 100644 cmd/flags/constants.go create mode 100644 cmd/utils_test.go create mode 100644 docker/tls/config/aerospike-proximus.yml create mode 100644 docker/tls/config/aerospike.conf create mode 100644 docker/tls/config/tls/connector.aerospike.com.crt create mode 100644 docker/tls/config/tls/connector.aerospike.com.keystore.jks create mode 100644 docker/tls/config/tls/keypass create mode 100644 docker/tls/config/tls/storepass create mode 100644 docker/tls/docker-compose.yml create mode 100644 test_utils.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a82f7a..6530556 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,6 +24,7 @@ jobs: echo "$FEATURES_CONF" > docker/config/features.conf - name: Run tests run: | + echo '127.0.0.1 connector.aerospike.com' >> /etc/hosts make coverage - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore index d87c693..00ec609 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -/docker/config/features.conf +features.conf /bin/* embed_*.go /tmp diff --git a/cmd/constants.go b/cmd/constants.go deleted file mode 100644 index f4cce6c..0000000 --- a/cmd/constants.go +++ /dev/null @@ -1,31 +0,0 @@ -package cmd - -const ( - logLevelFlagName = "log-level" - flagNameSeeds = "seeds" - flagNameHost = "host" - flagNameListenerName = "listener-name" - flagNameNamespace = "namespace" - flagNameSets = "sets" - flagNameIndexName = "index-name" - flagNameVectorField = "vector-field" - flagNameDimension = "dimension" - flagNameDistanceMetric = "distance-metric" - flagNameIndexMeta = "index-meta" - flagNameTimeout = "timeout" - flagNameVerbose = "verbose" - flagNameStorageNamespace = "storage-namespace" - flagNameStorageSet = "storage-set" - flagNameMaxEdges = "hnsw-max-edges" - flagNameConstructionEf = "hnsw-ef-construction" - flagNameEf = "hnsw-ef" - flagNameBatchMaxRecords = "hnsw-batch-max-records" - flagNameBatchInterval = "hnsw-batch-interval" - flagNameBatchEnabled = "hnsw-batch-enabled" - flagNameTLSProtocols = "tls-protocols" - flagNameTLSCaFile = "tls-cafile" - flagNameTLSCaPath = "tls-capath" - flagNameTLSCertFile = "tls-certfile" - flagNameTLSKeyFile = "tls-keyfile" - flagNameTLSKeyFilePass = "tls-keyfile-password" -) diff --git a/cmd/createIndex.go b/cmd/createIndex.go index 81b4fb7..cb4a381 100644 --- a/cmd/createIndex.go +++ b/cmd/createIndex.go @@ -21,9 +21,7 @@ import ( //nolint:govet // Padding not a concern for a CLI var createIndexFlags = &struct { - host *flags.HostPortFlag - seeds *flags.SeedsSliceFlag - listenerName flags.StringOptionalFlag + flags.ClientFlags namespace string sets []string indexName string @@ -40,10 +38,8 @@ var createIndexFlags = &struct { hnswBatchInterval flags.Uint32OptionalFlag hnswBatchEnabled flags.BoolOptionalFlag timeout time.Duration - tls *flags.TLSFlags }{ - host: flags.NewDefaultHostPortFlag(), - seeds: &flags.SeedsSliceFlag{}, + ClientFlags: *flags.NewClientFlags(), storageNamespace: flags.StringOptionalFlag{}, storageSet: flags.StringOptionalFlag{}, hnswMaxEdges: flags.Uint32OptionalFlag{}, @@ -52,41 +48,37 @@ var createIndexFlags = &struct { hnswBatchMaxRecords: flags.Uint32OptionalFlag{}, hnswBatchInterval: flags.Uint32OptionalFlag{}, hnswBatchEnabled: flags.BoolOptionalFlag{}, - tls: &flags.TLSFlags{}, } func newCreateIndexFlagSet() *pflag.FlagSet { - flagSet := &pflag.FlagSet{} - flagSet.VarP(createIndexFlags.host, flagNameHost, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flagNameSeeds))) //nolint:lll // For readability - flagSet.Var(createIndexFlags.seeds, flagNameSeeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flagNameHost))) //nolint:lll // For readability - flagSet.VarP(&createIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability - flagSet.StringVarP(&createIndexFlags.namespace, flagNameNamespace, "n", "", commonFlags.DefaultWrapHelpString("The namespace for the index.")) //nolint:lll // For readability - flagSet.StringArrayVarP(&createIndexFlags.sets, flagNameSets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability - flagSet.StringVarP(&createIndexFlags.indexName, flagNameIndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability - flagSet.StringVarP(&createIndexFlags.vectorField, flagNameVectorField, "f", "", commonFlags.DefaultWrapHelpString("The name of the vector field.")) //nolint:lll // For readability - flagSet.Uint32VarP(&createIndexFlags.dimensions, flagNameDimension, "d", 0, commonFlags.DefaultWrapHelpString("The dimension of the vector field.")) //nolint:lll // For readability - flagSet.VarP(&createIndexFlags.distanceMetric, flagNameDistanceMetric, "m", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The distance metric for the index. Valid values: %s", strings.Join(flags.DistanceMetricEnum(), ", ")))) //nolint:lll // For readability - flagSet.StringToStringVar(&createIndexFlags.indexMeta, flagNameIndexMeta, nil, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability - flagSet.DurationVar(&createIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability - flagSet.Var(&createIndexFlags.storageNamespace, flagNameStorageNamespace, commonFlags.DefaultWrapHelpString("Optional storage namespace where the index is stored. Defaults to the index namespace.")) //nolint:lll // For readability //nolint:lll // For readability - flagSet.Var(&createIndexFlags.storageSet, flagNameStorageSet, commonFlags.DefaultWrapHelpString("Optional storage set where the index is stored. Defaults to the index name.")) //nolint:lll // For readability //nolint:lll // For readability - flagSet.Var(&createIndexFlags.hnswMaxEdges, flagNameMaxEdges, commonFlags.DefaultWrapHelpString("Maximum number bi-directional links per HNSW vertex. Greater values of 'm' in general provide better recall for data with high dimensionality, while lower values work well for data with lower dimensionality. The storage space required for the index increases proportionally with 'm'. The default value is 16.")) //nolint:lll // For readability - flagSet.Var(&createIndexFlags.hnswConstructionEf, flagNameConstructionEf, commonFlags.DefaultWrapHelpString("The number of candidate nearest neighbors shortlisted during index creation. Larger values provide better recall at the cost of longer index update times. The default is 100.")) //nolint:lll // For readability - flagSet.Var(&createIndexFlags.hnswEf, flagNameEf, commonFlags.DefaultWrapHelpString("The default number of candidate nearest neighbors shortlisted during search. Larger values provide better recall at the cost of longer search times. The default is 100.")) //nolint:lll // For readability - flagSet.Var(&createIndexFlags.hnswBatchMaxRecords, flagNameBatchMaxRecords, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000.")) //nolint:lll // For readability - flagSet.Var(&createIndexFlags.hnswBatchInterval, flagNameBatchInterval, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000.")) //nolint:lll // For readability - flagSet.Var(&createIndexFlags.hnswBatchEnabled, flagNameBatchEnabled, commonFlags.DefaultWrapHelpString("Enables batching for index updates. Default is true meaning batching is enabled.")) //nolint:lll // For readability - flagSet.AddFlagSet(createIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString)) + flagSet := &pflag.FlagSet{} //nolint:lll // For readability + flagSet.StringVarP(&createIndexFlags.namespace, flags.Namespace, "n", "", commonFlags.DefaultWrapHelpString("The namespace for the index.")) //nolint:lll // For readability + flagSet.StringArrayVarP(&createIndexFlags.sets, flags.Sets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability + flagSet.StringVarP(&createIndexFlags.indexName, flags.IndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability + flagSet.StringVarP(&createIndexFlags.vectorField, flags.VectorField, "f", "", commonFlags.DefaultWrapHelpString("The name of the vector field.")) //nolint:lll // For readability + flagSet.Uint32VarP(&createIndexFlags.dimensions, flags.Dimension, "d", 0, commonFlags.DefaultWrapHelpString("The dimension of the vector field.")) //nolint:lll // For readability + flagSet.VarP(&createIndexFlags.distanceMetric, flags.DistanceMetric, "m", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The distance metric for the index. Valid values: %s", strings.Join(flags.DistanceMetricEnum(), ", ")))) //nolint:lll // For readability + flagSet.StringToStringVar(&createIndexFlags.indexMeta, flags.IndexMeta, nil, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability + flagSet.DurationVar(&createIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability + flagSet.Var(&createIndexFlags.storageNamespace, flags.StorageNamespace, commonFlags.DefaultWrapHelpString("Optional storage namespace where the index is stored. Defaults to the index namespace.")) //nolint:lll // For readability //nolint:lll // For readability + flagSet.Var(&createIndexFlags.storageSet, flags.StorageSet, commonFlags.DefaultWrapHelpString("Optional storage set where the index is stored. Defaults to the index name.")) //nolint:lll // For readability //nolint:lll // For readability + flagSet.Var(&createIndexFlags.hnswMaxEdges, flags.MaxEdges, commonFlags.DefaultWrapHelpString("Maximum number bi-directional links per HNSW vertex. Greater values of 'm' in general provide better recall for data with high dimensionality, while lower values work well for data with lower dimensionality. The storage space required for the index increases proportionally with 'm'. The default value is 16.")) //nolint:lll // For readability + flagSet.Var(&createIndexFlags.hnswConstructionEf, flags.ConstructionEf, commonFlags.DefaultWrapHelpString("The number of candidate nearest neighbors shortlisted during index creation. Larger values provide better recall at the cost of longer index update times. The default is 100.")) //nolint:lll // For readability + flagSet.Var(&createIndexFlags.hnswEf, flags.Ef, commonFlags.DefaultWrapHelpString("The default number of candidate nearest neighbors shortlisted during search. Larger values provide better recall at the cost of longer search times. The default is 100.")) //nolint:lll // For readability + flagSet.Var(&createIndexFlags.hnswBatchMaxRecords, flags.BatchMaxRecords, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000.")) //nolint:lll // For readability + flagSet.Var(&createIndexFlags.hnswBatchInterval, flags.BatchInterval, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000.")) //nolint:lll // For readability + flagSet.Var(&createIndexFlags.hnswBatchEnabled, flags.BatchEnabled, commonFlags.DefaultWrapHelpString("Enables batching for index updates. Default is true meaning batching is enabled.")) //nolint:lll // For readability + flagSet.AddFlagSet(createIndexFlags.NewClientFlagSet()) return flagSet } var createIndexRequiredFlags = []string{ - flagNameNamespace, - flagNameIndexName, - flagNameVectorField, - flagNameDimension, - flagNameDistanceMetric, + flags.Namespace, + flags.IndexName, + flags.VectorField, + flags.Dimension, + flags.DistanceMetric, } // createIndexCmd represents the createIndex command @@ -106,42 +98,53 @@ func newCreateIndexCmd() *cobra.Command { --storage-namespace test --hnsw-batch-enabled false `, PreRunE: func(_ *cobra.Command, _ []string) error { - if viper.IsSet(flagNameSeeds) && viper.IsSet(flagNameHost) { - return fmt.Errorf("only --%s or --%s allowed", flagNameSeeds, flagNameHost) + if viper.IsSet(flags.Seeds) && viper.IsSet(flags.Host) { + return fmt.Errorf("only --%s or --%s allowed", flags.Seeds, flags.Host) } return nil }, RunE: func(_ *cobra.Command, _ []string) error { - hosts, isLoadBalancer := parseBothHostSeedsFlag(createIndexFlags.seeds, createIndexFlags.host) + hosts, isLoadBalancer := parseBothHostSeedsFlag(createIndexFlags.Seeds, createIndexFlags.Host) logger.Debug("parsed flags", - slog.String(flagNameHost, createIndexFlags.host.String()), - slog.String(flagNameSeeds, createIndexFlags.seeds.String()), - slog.String(flagNameListenerName, createIndexFlags.listenerName.String()), - slog.String(flagNameNamespace, createIndexFlags.namespace), - slog.Any(flagNameSets, createIndexFlags.sets), - slog.String(flagNameIndexName, createIndexFlags.indexName), - slog.String(flagNameVectorField, createIndexFlags.vectorField), - slog.Uint64(flagNameDimension, uint64(createIndexFlags.dimensions)), - slog.Any(flagNameIndexMeta, createIndexFlags.indexMeta), - slog.String(flagNameDistanceMetric, createIndexFlags.distanceMetric.String()), - slog.Duration(flagNameTimeout, createIndexFlags.timeout), - slog.Any(flagNameStorageNamespace, createIndexFlags.storageNamespace.String()), - slog.Any(flagNameStorageSet, createIndexFlags.storageSet.String()), - slog.Any(flagNameMaxEdges, createIndexFlags.hnswMaxEdges.String()), - slog.Any(flagNameEf, createIndexFlags.hnswEf), - slog.Any(flagNameConstructionEf, createIndexFlags.hnswConstructionEf.String()), - slog.Any(flagNameBatchMaxRecords, createIndexFlags.hnswBatchMaxRecords.String()), - slog.Any(flagNameBatchInterval, createIndexFlags.hnswBatchInterval.String()), - slog.Any(flagNameBatchEnabled, createIndexFlags.hnswBatchEnabled.String()), + slog.String(flags.Host, createIndexFlags.Host.String()), + slog.String(flags.Seeds, createIndexFlags.Seeds.String()), + slog.String(flags.ListenerName, createIndexFlags.ListenerName.String()), + slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil), + slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil), + slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil), + slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil), + slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil), + slog.String(flags.Namespace, createIndexFlags.namespace), + slog.Any(flags.Sets, createIndexFlags.sets), + slog.String(flags.IndexName, createIndexFlags.indexName), + slog.String(flags.VectorField, createIndexFlags.vectorField), + slog.Uint64(flags.Dimension, uint64(createIndexFlags.dimensions)), + slog.Any(flags.IndexMeta, createIndexFlags.indexMeta), + slog.String(flags.DistanceMetric, createIndexFlags.distanceMetric.String()), + slog.Duration(flags.Timeout, createIndexFlags.timeout), + slog.Any(flags.StorageNamespace, createIndexFlags.storageNamespace.String()), + slog.Any(flags.StorageSet, createIndexFlags.storageSet.String()), + slog.Any(flags.MaxEdges, createIndexFlags.hnswMaxEdges.String()), + slog.Any(flags.Ef, createIndexFlags.hnswEf), + slog.Any(flags.ConstructionEf, createIndexFlags.hnswConstructionEf.String()), + slog.Any(flags.BatchMaxRecords, createIndexFlags.hnswBatchMaxRecords.String()), + slog.Any(flags.BatchInterval, createIndexFlags.hnswBatchInterval.String()), + slog.Any(flags.BatchEnabled, createIndexFlags.hnswBatchEnabled.String()), ) ctx, cancel := context.WithTimeout(context.Background(), createIndexFlags.timeout) defer cancel() + tlsConfig, err := createIndexFlags.NewTLSConfig() + if err != nil { + logger.Error("failed to create TLS config", slog.Any("error", err)) + return err + } + adminClient, err := avs.NewAdminClient( - ctx, hosts, createIndexFlags.listenerName.Val, isLoadBalancer, nil, logger, + ctx, hosts, createIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger, ) if err != nil { logger.Error("failed to create AVS client", slog.Any("error", err)) diff --git a/cmd/dropIndex.go b/cmd/dropIndex.go index d8deb5f..6ddfcd1 100644 --- a/cmd/dropIndex.go +++ b/cmd/dropIndex.go @@ -19,37 +19,29 @@ import ( //nolint:govet // Padding not a concern for a CLI var dropIndexFlags = &struct { - host *flags.HostPortFlag - seeds *flags.SeedsSliceFlag - listenerName flags.StringOptionalFlag - namespace string - sets []string - indexName string - timeout time.Duration - tls *flags.TLSFlags + flags.ClientFlags + namespace string + sets []string + indexName string + timeout time.Duration }{ - host: flags.NewDefaultHostPortFlag(), - seeds: &flags.SeedsSliceFlag{}, - tls: &flags.TLSFlags{}, + ClientFlags: *flags.NewClientFlags(), } func newDropIndexFlagSet() *pflag.FlagSet { flagSet := &pflag.FlagSet{} - flagSet.VarP(dropIndexFlags.host, flagNameHost, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flagNameSeeds))) //nolint:lll // For readability - flagSet.Var(dropIndexFlags.seeds, flagNameSeeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flagNameHost))) //nolint:lll // For readability - flagSet.VarP(&dropIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability - flagSet.StringVarP(&dropIndexFlags.namespace, flagNameNamespace, "n", "", commonFlags.DefaultWrapHelpString("The namespace for the index.")) //nolint:lll // For readability - flagSet.StringArrayVarP(&dropIndexFlags.sets, flagNameSets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability - flagSet.StringVarP(&dropIndexFlags.indexName, flagNameIndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability - flagSet.DurationVar(&dropIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability - flagSet.AddFlagSet(dropIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString)) + flagSet.StringVarP(&dropIndexFlags.namespace, flags.Namespace, "n", "", commonFlags.DefaultWrapHelpString("The namespace for the index.")) //nolint:lll // For readability + flagSet.StringArrayVarP(&dropIndexFlags.sets, flags.Sets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability + flagSet.StringVarP(&dropIndexFlags.indexName, flags.IndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability + flagSet.DurationVar(&dropIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability + flagSet.AddFlagSet(dropIndexFlags.NewClientFlagSet()) return flagSet } var dropIndexRequiredFlags = []string{ - flagNameNamespace, - flagNameIndexName, + flags.Namespace, + flags.IndexName, } // dropIndexCmd represents the dropIndex command @@ -65,30 +57,41 @@ func newDropIndexCommand() *cobra.Command { asvec drop index -i myindex -n test `, PreRunE: func(_ *cobra.Command, _ []string) error { - if viper.IsSet(flagNameSeeds) && viper.IsSet(flagNameHost) { - return fmt.Errorf("only --%s or --%s allowed", flagNameSeeds, flagNameHost) + if viper.IsSet(flags.Seeds) && viper.IsSet(flags.Host) { + return fmt.Errorf("only --%s or --%s allowed", flags.Seeds, flags.Host) } return nil }, RunE: func(_ *cobra.Command, _ []string) error { logger.Debug("parsed flags", - slog.String(flagNameHost, dropIndexFlags.host.String()), - slog.String(flagNameSeeds, dropIndexFlags.seeds.String()), - slog.String(flagNameListenerName, dropIndexFlags.listenerName.String()), - slog.String(flagNameNamespace, dropIndexFlags.namespace), - slog.Any(flagNameSets, dropIndexFlags.sets), - slog.String(flagNameIndexName, dropIndexFlags.indexName), - slog.Duration(flagNameTimeout, dropIndexFlags.timeout), + slog.String(flags.Host, dropIndexFlags.Host.String()), + slog.String(flags.Seeds, dropIndexFlags.Seeds.String()), + slog.String(flags.ListenerName, dropIndexFlags.ListenerName.String()), + slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil), + slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil), + slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil), + slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil), + slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil), + slog.String(flags.Namespace, dropIndexFlags.namespace), + slog.Any(flags.Sets, dropIndexFlags.sets), + slog.String(flags.IndexName, dropIndexFlags.indexName), + slog.Duration(flags.Timeout, dropIndexFlags.timeout), ) - hosts, isLoadBalancer := parseBothHostSeedsFlag(dropIndexFlags.seeds, dropIndexFlags.host) + hosts, isLoadBalancer := parseBothHostSeedsFlag(dropIndexFlags.Seeds, dropIndexFlags.Host) ctx, cancel := context.WithTimeout(context.Background(), dropIndexFlags.timeout) defer cancel() + tlsConfig, err := dropIndexFlags.NewTLSConfig() + if err != nil { + logger.Error("failed to create TLS config", slog.Any("error", err)) + return err + } + adminClient, err := avs.NewAdminClient( - ctx, hosts, nil, isLoadBalancer, nil, logger, + ctx, hosts, dropIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger, ) if err != nil { logger.Error("failed to create AVS client", slog.Any("error", err)) diff --git a/cmd/flags/client.go b/cmd/flags/client.go new file mode 100644 index 0000000..95dc670 --- /dev/null +++ b/cmd/flags/client.go @@ -0,0 +1,34 @@ +package flags + +import ( + "fmt" + + commonFlags "github.com/aerospike/tools-common-go/flags" + "github.com/spf13/pflag" +) + +type ClientFlags struct { + Host *HostPortFlag + Seeds *SeedsSliceFlag + ListenerName StringOptionalFlag + TLSFlags +} + +func NewClientFlags() *ClientFlags { + return &ClientFlags{ + Host: NewDefaultHostPortFlag(), + Seeds: &SeedsSliceFlag{}, + TLSFlags: *NewTLSFlags(), + } +} + +func (cf *ClientFlags) NewClientFlagSet() *pflag.FlagSet { + flagSet := &pflag.FlagSet{} + flagSet.VarP(cf.Host, Host, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", Seeds))) //nolint:lll // For readability + flagSet.Var(cf.Seeds, Seeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", Host))) //nolint:lll // For readability + flagSet.VarP(&cf.ListenerName, ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) + + flagSet.AddFlagSet(cf.NewTLSFlagSet(commonFlags.DefaultWrapHelpString)) + + return flagSet +} diff --git a/cmd/flags/constants.go b/cmd/flags/constants.go new file mode 100644 index 0000000..1e06a87 --- /dev/null +++ b/cmd/flags/constants.go @@ -0,0 +1,31 @@ +package flags + +const ( + LogLevel = "log-level" + Seeds = "seeds" + Host = "host" + ListenerName = "listener-name" + Namespace = "namespace" + Sets = "sets" + IndexName = "index-name" + VectorField = "vector-field" + Dimension = "dimension" + DistanceMetric = "distance-metric" + IndexMeta = "index-meta" + Timeout = "timeout" + Verbose = "verbose" + StorageNamespace = "storage-namespace" + StorageSet = "storage-set" + MaxEdges = "hnsw-max-edges" + ConstructionEf = "hnsw-ef-construction" + Ef = "hnsw-ef" + BatchMaxRecords = "hnsw-batch-max-records" + BatchInterval = "hnsw-batch-interval" + BatchEnabled = "hnsw-batch-enabled" + TLSProtocols = "tls-protocols" + TLSCaFile = "tls-cafile" + TLSCaPath = "tls-capath" + TLSCertFile = "tls-certfile" + TLSKeyFile = "tls-keyfile" + TLSKeyFilePass = "tls-keyfile-password" +) diff --git a/cmd/flags/tls.go b/cmd/flags/tls.go index 1dd6f91..30927c4 100644 --- a/cmd/flags/tls.go +++ b/cmd/flags/tls.go @@ -25,7 +25,7 @@ func NewTLSFlags() *TLSFlags { // NewAerospikeFlagSet returns a new pflag.FlagSet with Aerospike flags defined. // Values set in the returned FlagSet will be stored in the AerospikeFlags argument. -func (tf *TLSFlags) NewFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet { +func (tf *TLSFlags) NewTLSFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet { f := &pflag.FlagSet{} f.Var(&tf.TLSRootCAFile, "tls-cafile", fmtUsage("The CA used when connecting to AVS.")) @@ -55,7 +55,7 @@ func (tf *TLSFlags) NewTLSConfig() (*tls.Config, error) { tf.TLSCertFile, tf.TLSKeyFile, tf.TLSKeyFilePass, - tf.TLSProtocols.Min, - tf.TLSProtocols.Max, + 0, + 0, ).NewGoTLSConfig() } diff --git a/cmd/listIndex.go b/cmd/listIndex.go index 0ab7e6a..d294caf 100644 --- a/cmd/listIndex.go +++ b/cmd/listIndex.go @@ -20,26 +20,21 @@ import ( ) var listIndexFlags = &struct { - host *flags.HostPortFlag - seeds *flags.SeedsSliceFlag - listenerName flags.StringOptionalFlag - verbose bool - timeout time.Duration - tls *flags.TLSFlags + flags.ClientFlags + verbose bool + timeout time.Duration }{ - host: flags.NewDefaultHostPortFlag(), - seeds: &flags.SeedsSliceFlag{}, - tls: &flags.TLSFlags{}, + ClientFlags: *flags.NewClientFlags(), } func newListIndexFlagSet() *pflag.FlagSet { flagSet := &pflag.FlagSet{} - flagSet.VarP(listIndexFlags.host, flagNameHost, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flagNameSeeds))) //nolint:lll // For readability - flagSet.Var(listIndexFlags.seeds, flagNameSeeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flagNameHost))) //nolint:lll // For readability - flagSet.VarP(&listIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability - flagSet.BoolVarP(&listIndexFlags.verbose, flagNameVerbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability - flagSet.DurationVar(&listIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability - flagSet.AddFlagSet(listIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString)) + flagSet.VarP(listIndexFlags.Host, flags.Host, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flags.Seeds))) //nolint:lll // For readability + flagSet.Var(listIndexFlags.Seeds, flags.Seeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flags.Host))) //nolint:lll // For readability + flagSet.VarP(&listIndexFlags.ListenerName, flags.ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability + flagSet.BoolVarP(&listIndexFlags.verbose, flags.Verbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability + flagSet.DurationVar(&listIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability + flagSet.AddFlagSet(listIndexFlags.NewClientFlagSet()) return flagSet } @@ -57,30 +52,41 @@ func newListIndexCmd() *cobra.Command { For example: export ASVEC_HOST=:5000 asvec list index - `, flagNameVerbose), + `, flags.Verbose), PreRunE: func(_ *cobra.Command, _ []string) error { - if viper.IsSet(flagNameSeeds) && viper.IsSet(flagNameHost) { - return fmt.Errorf("only --%s or --%s allowed", flagNameSeeds, flagNameHost) + if viper.IsSet(flags.Seeds) && viper.IsSet(flags.Host) { + return fmt.Errorf("only --%s or --%s allowed", flags.Seeds, flags.Host) } return nil }, RunE: func(_ *cobra.Command, _ []string) error { logger.Debug("parsed flags", - slog.String(flagNameHost, listIndexFlags.host.String()), - slog.String(flagNameSeeds, listIndexFlags.seeds.String()), - slog.String(flagNameListenerName, listIndexFlags.listenerName.String()), - slog.Bool(flagNameVerbose, listIndexFlags.verbose), - slog.Duration(flagNameTimeout, listIndexFlags.timeout), + slog.String(flags.Host, listIndexFlags.Host.String()), + slog.String(flags.Seeds, listIndexFlags.Seeds.String()), + slog.String(flags.ListenerName, listIndexFlags.ListenerName.String()), + slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil), + slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil), + slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil), + slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil), + slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil), + slog.Bool(flags.Verbose, listIndexFlags.verbose), + slog.Duration(flags.Timeout, listIndexFlags.timeout), ) - hosts, isLoadBalancer := parseBothHostSeedsFlag(listIndexFlags.seeds, listIndexFlags.host) + hosts, isLoadBalancer := parseBothHostSeedsFlag(listIndexFlags.Seeds, listIndexFlags.Host) ctx, cancel := context.WithTimeout(context.Background(), listIndexFlags.timeout) defer cancel() + tlsConfig, err := listIndexFlags.NewTLSConfig() + if err != nil { + logger.Error("failed to create TLS config", slog.Any("error", err)) + return err + } + adminClient, err := avs.NewAdminClient( - ctx, hosts, listIndexFlags.listenerName.Val, isLoadBalancer, nil, logger, + ctx, hosts, listIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger, ) if err != nil { logger.Error("failed to create AVS client", slog.Any("error", err)) diff --git a/cmd/root.go b/cmd/root.go index 017228e..46f099b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -86,15 +86,15 @@ func Execute() { } func init() { - rootCmd.PersistentFlags().Var(&rootFlags.logLevel, logLevelFlagName, common.DefaultWrapHelpString(fmt.Sprintf("Log level for additional details and debugging. Valid values: %s", strings.Join(flags.LogLevelEnum(), ", ")))) + rootCmd.PersistentFlags().Var(&rootFlags.logLevel, flags.LogLevel, common.DefaultWrapHelpString(fmt.Sprintf("Log level for additional details and debugging. Valid values: %s", strings.Join(flags.LogLevelEnum(), ", ")))) common.SetupRoot(rootCmd, "aerospike-vector-search", "0.0.0") // TODO: Handle version viper.SetEnvPrefix("ASVEC") - if err := viper.BindEnv(flagNameHost); err != nil { + if err := viper.BindEnv(flags.Host); err != nil { logger.Error("failed to bind environment variable", slog.Any("error", err)) } - if err := viper.BindEnv(flagNameSeeds); err != nil { + if err := viper.BindEnv(flags.Seeds); err != nil { logger.Error("failed to bind environment variable", slog.Any("error", err)) } } diff --git a/cmd/utils_test.go b/cmd/utils_test.go new file mode 100644 index 0000000..53c8a84 --- /dev/null +++ b/cmd/utils_test.go @@ -0,0 +1,49 @@ +//go:build unit + +package cmd + +import ( + "asvec/cmd/flags" + "testing" + + avs "github.com/aerospike/aerospike-proximus-client-go" + "github.com/stretchr/testify/assert" +) + +func TestParseBothHostSeedsFlag(t *testing.T) { + testCases := []struct { + seeds *flags.SeedsSliceFlag + host *flags.HostPortFlag + expectedSlice avs.HostPortSlice + expectedIsLoadBalancer bool + }{ + { + &flags.SeedsSliceFlag{ + Seeds: avs.HostPortSlice{ + avs.NewHostPort("1.1.1.1", 5000, false), + }, + }, + flags.NewDefaultHostPortFlag(), + avs.HostPortSlice{ + avs.NewHostPort("1.1.1.1", 5000, false), + }, + false, + }, + { + &flags.SeedsSliceFlag{ + Seeds: avs.HostPortSlice{}, + }, + flags.NewDefaultHostPortFlag(), + avs.HostPortSlice{ + &flags.NewDefaultHostPortFlag().HostPort, + }, + true, + }, + } + + for _, tc := range testCases { + actualSlice, actualBool := parseBothHostSeedsFlag(tc.seeds, tc.host) + assert.Equal(t, tc.expectedSlice, actualSlice) + assert.Equal(t, tc.expectedIsLoadBalancer, actualBool) + } +} diff --git a/docker/tls/config/aerospike-proximus.yml b/docker/tls/config/aerospike-proximus.yml new file mode 100644 index 0000000..740d40f --- /dev/null +++ b/docker/tls/config/aerospike-proximus.yml @@ -0,0 +1,86 @@ +# Change the configuration for your use case. +cluster: + # Custom node-id. It will be auto-generated if not specified. + # node-id: a1 + + # Unique identifier for this cluster. + cluster-name: prism-image-search + +tls: + service-tls: + trust-store: + store-file: /etc/aerospike-proximus/tls/ca.aerospike.com.truststore.jks + store-password-file: /etc/aerospike-proximus/tls/storepass + key-store: + store-file: /etc/aerospike-proximus/tls/localhost.keystore.jks + store-password-file: /etc/aerospike-proximus/tls/storepass + key-password-file: /etc/aerospike-proximus/tls/keypass + +# The Proximus service listening ports, TLS and network interface. +service: + ports: + 10000: + # If TLS needs to be enabled, tls configuration id. + tls-id: service-tls + advertised-listeners: + default: + address: 127.0.0.1 + port: 10000 + +# Management API listening ports, TLS and network interface. +manage: + ports: + 5040: + tls-id: service-tls + +# Intra cluster interconnect listening ports, TLS and network interface. +interconnect: + ports: + 5001: {} + +#heartbeat: +# seeds: +# - address: localhost +# port: 6001 + +# Target Aerospike cluster +aerospike: + seeds: + - aerospike: + port: 3000 + +# File based credentials store only if security should be enabled. +#security: +# credentials-store: +# type: file +# credentials-file: samples/credentials.yml +# auth-token: +# private-key: samples/auth/private_key.pem +# public-key: samples/auth/public_key.pem + +# Vault based credentials store only if security should be enabled. +#security: +# credentials-store: +# type: vault +# url: https://vault:8200 +# secrets-path: /secret/aerospike/aerodb1 +# tls: +# key-store: +# store-type: PEM +# store-file: key.pem +# store-password-file: keypass.txt # Password protecting key.pem. +# certificate-chain-files: certchain.pem +# trust-store: +# store-type: PEM +# certificate-files: cacert.pem +# auth-token: +# private-key: samples/auth/private_key.pem +# public-key: samples/auth/public_key.pem + +# The logging properties. +logging: + #format: json + #file: /var/log/aerospike-proximus/aerospike-proximus.log + enable-console-logging: true + levels: + metrics-ticker: off diff --git a/docker/tls/config/aerospike.conf b/docker/tls/config/aerospike.conf new file mode 100644 index 0000000..a23c052 --- /dev/null +++ b/docker/tls/config/aerospike.conf @@ -0,0 +1,82 @@ +# Aerospike database configuration file for use with systemd. + +service { + cluster-name prism-demo + proto-fd-max 15000 +} + + +logging { + file /var/log/aerospike/aerospike.log { + context any info + } + + # Send log messages to stdout + console { + context any info + context query critical + } +} + +network { + service { + address any + port 3000 + } + + heartbeat { + mode multicast + multicast-group 239.1.99.222 + port 9918 + + # To use unicast-mesh heartbeats, remove the 3 lines above, and see + # aerospike_mesh.conf for alternative. + + interval 150 + timeout 10 + } + + fabric { + port 3001 + } + + info { + port 3003 + } +} + +namespace test { + replication-factor 1 + nsup-period 60 + + storage-engine memory { + data-size 1G + } +} + +namespace bar { + replication-factor 1 + nsup-period 60 + + storage-engine memory { + data-size 1G + } +} + +namespace proximus-meta { + replication-factor 1 + nsup-period 100 + + storage-engine memory { + data-size 1G + } + + # To use file storage backing, comment out the line above and use the + # following lines instead. +# storage-engine device { +# file /opt/aerospike/data/bar.dat +# filesize 16G +# data-in-memory true # Store data in memory in addition to file. +# } +} + diff --git a/docker/tls/config/tls/connector.aerospike.com.crt b/docker/tls/config/tls/connector.aerospike.com.crt new file mode 100644 index 0000000..3e53361 --- /dev/null +++ b/docker/tls/config/tls/connector.aerospike.com.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDpjCCAo6gAwIBAgIJAJuWztiFvTcCMA0GCSqGSIb3DQEBCwUAMIGLMQswCQYD +VQQGEwJJTjELMAkGA1UECAwCS0ExCzAJBgNVBAcMAkJOMRIwEAYDVQQKDAlhZXJv +c3Bpa2UxEjAQBgNVBAsMCWVjb3N5c3RlbTEZMBcGA1UEAwwQY2EuYWVyb3NwaWtl +LmNvbTEfMB0GCSqGSIb3DQEJARYQY2FAYWVyb3NwaWtlLmNvbTAeFw0xOTA3MDgx +MDMyMjFaFw0yOTA3MDUxMDMyMjFaMFYxCzAJBgNVBAYTAklOMQswCQYDVQQIDAJL +QTEYMBYGA1UECgwPQWVyb3NwaWtlLCBJbmMuMSAwHgYDVQQDDBdjb25uZWN0b3Iu +YWVyb3NwaWtlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAO70 +L+G2TP6s1jPPMvIWB2LSdiVKkfKAuihTf5lJEuF1jyCt6c8q7j4gqAEA5PoAaTpM +ukKt3EvmCyRZTjBr1WqSOuwU4mw+rsQ+gWdFiaEOLXS6kHcpunb5/wLeYAC6rGLk +Di9FBSwX3sNke5M1v/P+Vcd3ozA3bsv5JTFCbZiwarG7i8ZatRXFDHYwS0t36pLA +JcyVUyehT7SSctkQQym0goMjLpnEKgimC0RKgw9yhmPmK5oasIzythr6MSv2SkiO +BJfJDBJZcqvNFn66ghuofJ7vyxxW59DSwrRAlXPfAaxQ4mcgECuoe1xMUtrMya64 +4Hz7ZijnbMrFKTqgy3sCAwEAAaNBMD8wPQYDVR0RBDYwNIIXY29ubmVjdG9yLmFl +cm9zcGlrZS5jb22CGSouY29ubmVjdG9yLmFlcm9zcGlrZS5jb20wDQYJKoZIhvcN +AQELBQADggEBAFTbTRmftwDZlC/60I9w6u5/vRRJRCQvp5mBhrbsfA1OcY7HOBOT +G3hFtDHl14s5cpZAgecdgACtvGvnex6q/gun2JzdAmJjslpRAQgZA0L182X+Sc5z +5QIhMhS6iMCae9KV4GaEq0VMfZ/KlKPfESj1DZE3VQNWUPTwBbamnzpEqSzoFdrp +duFah12yIwnUoTi+fNfhhm76Y5GkG5Cx4HJvF6qiJ1uOIieLYZ07e1nFxIVqnEz/ +zZCkDkKqXWb+ii8hlBl8kiu4vVwKEjrlruSaKb+3UTJxB16zDESjn5i58A9TgtsZ +VvECl552NbEBT+VOgRPsNJRBpPbPfYqUVfk= +-----END CERTIFICATE----- diff --git a/docker/tls/config/tls/connector.aerospike.com.keystore.jks b/docker/tls/config/tls/connector.aerospike.com.keystore.jks new file mode 100644 index 0000000000000000000000000000000000000000..29490f53bdc89036e1e5bb37638e5e1c43e28f7e GIT binary patch literal 2285 zcmZ{lc{J4f8^`B6gQ1yWFi1(3v5xr82u0aSg-KZp=~t$)??jl$Oo(KO#H4I<#e|Rs zlOjtlLP)ajS<(Nc+K{dLc;bMEht=bY#LocB4O*LiO8x#2I~+AF91?`zJr2%+i-jU51Qe?r#}k8X>9HovSD09VZ3uKML#UW)wM{t82Ju;{U|7!OAA~l+(GS=rBb(p+a(e zcjIx!sN$>gs(~)mNWN5Dld@RU?dP8IqZSrL-n-6ijzLO8iBo~q_;^d|>vE^~i1ZH4 z4%oVk+tn61_oB##nL+DDrF%95I!iQkXSmWs*2&f9&Fv{ieMSts<<4J>cC`OusKMeE zFwUR}1s+|DPV1brB>_(qtp?|^7rS0k7gff!bQyCUw-wl3RkJETZogo<`$`LGh5o*Bp-vV{eO6`tq*0Yve^gl@=3yaym1#s&y9$)q zWZ(0jJ>75IP0R4NIv@-+dQ* zS*Cpf?6Up91aHuSr}Vr?!OITKYHk;qQ(3J~H4`G>y;pZs=Nda_Vw)m8t?;gQa{MCQ zlr+c4ek;S93VlJ51_2W?Ef@vhduHx?)CK?Qp%?6HrnHe0VbfY|a@YyA7b6KDUF!7g zlJDdNH)ty);Amme`;_1CBHAVlwqyOt`nmW5JD&Dh*~0Xu()N4g(Ho-iQTK)uHVvz5 zm4qT65v;o0S*wP~Ce;f^{dp%T4{R$5kGuacIB0o>GJi=IHRAA=KP4N%AKm6~5-ugY ze4zY=31KGEFX9n(0QF-4szUuZhafGdiv8l(LdDbfaE1vuNUtRPWr$=KBpL9x3_ zhbDARlhf+>Gb0UiBl!dR2vZ>g`f5c@zYRXfMYPmVz3rVB^)I{hQsF+I2~g{@PScv< z6^)VD&n&^GPM_-6q#{C#*lrRb4e*ts>3N

6n7JQhQ&j)8?-K<%@7sFx1q&~KB!swwn>+`idzbyAc0fp#|S)ovMo=* zG=CJ1hwX^k)q#R2J^|I3V=!RgUH@^Nl@a44~9gO&a3}p zRVVp+<9CC*{yAL$5Z^Kw{$c>7B~bYNpboeX4{8z!cn1j-;fK`vuLSLWBK-^Oc2N9p zJ|Ux@d}=bLJ|uO#3@H6Kkp#)t$A?S`@(uVE45R{b{~OHK-vOfnLXfTbHYgPUAe>Jc z6D%UP_>E>4VMQG7+!L%|ma=k%t!y2cX(}>tF-fLmzDt#(FH;CWroTemb%|`_l40|A z2&}yY=-KCSQ+E+F>7`%Vq#xyK5|=BuCy1RKqQVZ|*o2OrgRqO8rv)`kc-16E8!v>X zYSpcBZCgU}K^>p=4F$ZhcSe~<`IGA}93G1`BZEP6^N@v`^$IU$d+r=wDyoR>%OvX6U$Nb`@kGP2=mUB)4s6WyRJY=YCUJerI}pB{kK^oP-jwwgC)Q( z2{T?>C38(X181PJ1PuFHJw3eV*0{^n2PVYJ*==e0W5UX7s1zMrnBB=w%e<_DY~8~} zYI9;YcETTjg_of*!*1Impr`c{idl;1>m6CutVd%0>%U}09tvOWvC+T$BVE_OC literal 0 HcmV?d00001 diff --git a/docker/tls/config/tls/keypass b/docker/tls/config/tls/keypass new file mode 100644 index 0000000..b1f833e --- /dev/null +++ b/docker/tls/config/tls/keypass @@ -0,0 +1 @@ +citrusstore diff --git a/docker/tls/config/tls/storepass b/docker/tls/config/tls/storepass new file mode 100644 index 0000000..b1f833e --- /dev/null +++ b/docker/tls/config/tls/storepass @@ -0,0 +1 @@ +citrusstore diff --git a/docker/tls/docker-compose.yml b/docker/tls/docker-compose.yml new file mode 100644 index 0000000..30a3af1 --- /dev/null +++ b/docker/tls/docker-compose.yml @@ -0,0 +1,23 @@ +services: + aerospike: + image: aerospike/aerospike-server-enterprise:7.0.0.2 + ports: + - "3000:3000" + networks: + - avs-demo + volumes: + - ./config:/opt/aerospike/etc/aerospike + command: + - "--config-file" + - "/opt/aerospike/etc/aerospike/aerospike.conf" + avs: + image: aerospike.jfrog.io/docker/aerospike/aerospike-proximus-private:0.5.0-SNAPSHOT + ports: + - "10000:10000" + networks: + - avs-demo + volumes: + - ./config:/etc/aerospike-proximus + +networks: + avs-demo: {} diff --git a/e2e_test.go b/e2e_test.go index 9dadc2f..7a8f725 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -1,10 +1,15 @@ //go:build integration -package main_test +package main import ( + "asvec/cmd/flags" "context" + "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" + "log" "log/slog" "os" "os/exec" @@ -28,31 +33,68 @@ var ( barNamespace = "bar" ) +func GetCACert(cert string) (*x509.CertPool, error) { + // read in file + certBytes, err := ioutil.ReadFile(cert) + if err != nil { + log.Fatalf("unable to read cert file %v", err) + return nil, err + } + + certificates := x509.NewCertPool() + certificates.AppendCertsFromPEM(certBytes) + + return certificates, nil +} + type CmdTestSuite struct { suite.Suite - app string - coverFile string - coverFileCounter int - avsIP string - avsPort int - avsHostPort *avs.HostPort - avsClient *avs.AdminClient + app string + composeFile string + suiteName string + suiteArgs string + coverFile string + avsIP string + avsPort int + avsHostPort *avs.HostPort + avsTLSConfig *tls.Config + avsClient *avs.AdminClient } -func TestDistanceMetricFlagSuite(t *testing.T) { - suite.Run(t, new(CmdTestSuite)) +func TestCmdSuite(t *testing.T) { + logger = logger.WithGroup("test-logger") + rootCA, err := GetCACert("docker/tls/config/tls/ca.aerospike.com.crt") + if err != nil { + t.Fatalf("unable to read root ca %v", err) + t.FailNow() + logger.Error("Failed to read cert") + } + + logger.Info("%v", slog.Any("cert", rootCA)) + + suite.Run(t, &CmdTestSuite{ + composeFile: "docker/docker-compose.yml", + suiteArgs: "--log-level debug", + avsIP: "localhost", + }) + suite.Run(t, &CmdTestSuite{ + composeFile: "docker/tls/docker-compose.yml", + suiteArgs: fmt.Sprintf("--%s %s --log-level debug", flags.TLSCaFile, "docker/tls/config/tls/connector.aerospike.com.crt"), + avsTLSConfig: &tls.Config{ + Certificates: nil, + RootCAs: rootCA, + }, + avsIP: "connector.aerospike.com", + }) } func (suite *CmdTestSuite) SetupSuite() { suite.app = path.Join(wd, "app.test") suite.coverFile = path.Join(wd, "../coverage/cmd-coverage.cov") - suite.coverFileCounter = 0 - suite.avsIP = "127.0.0.1" suite.avsPort = 10000 suite.avsHostPort = avs.NewHostPort(suite.avsIP, suite.avsPort, false) - // var err error - err := docker_compose_up() + err := docker_compose_up(suite.composeFile) if err != nil { suite.FailNowf("unable to start docker compose up", "%v", err) } @@ -75,7 +117,7 @@ func (suite *CmdTestSuite) SetupSuite() { for { suite.avsClient, err = avs.NewAdminClient( - ctx, avs.HostPortSlice{suite.avsHostPort}, nil, true, nil, logger, + ctx, avs.HostPortSlice{suite.avsHostPort}, nil, true, suite.avsTLSConfig, logger, ) if err != nil { @@ -100,17 +142,18 @@ func (suite *CmdTestSuite) TearDownSuite() { suite.Assert().NoError(err) suite.avsClient.Close() - err = docker_compose_down() + err = docker_compose_down(suite.composeFile) if err != nil { fmt.Println("unable to stop docker compose down") } } func (suite *CmdTestSuite) runCmd(asvecCmd ...string) ([]string, error) { + asvecCmd = append(strings.Split(suite.suiteArgs, " "), asvecCmd...) + logger.Info("running command", slog.String("cmd", strings.Join(asvecCmd, " "))) cmd := exec.Command(suite.app, asvecCmd...) cmd.Env = []string{"GOCOVERDIR=" + os.Getenv("COVERAGE_DIR")} stdout, err := cmd.Output() - // fmt.Printf("stdout: %v", string(stdout)) if err != nil { if ee, ok := err.(*exec.ExitError); ok { @@ -124,161 +167,6 @@ func (suite *CmdTestSuite) runCmd(asvecCmd ...string) ([]string, error) { return lines, nil } -func getStrPtr(str string) *string { - ptr := str - return &ptr -} - -func getUint32Ptr(i int) *uint32 { - ptr := uint32(i) - return &ptr -} - -func getBoolPtr(b bool) *bool { - ptr := b - return &ptr -} - -type IndexDefinitionBuilder struct { - indexName string - namespace string - set *string - dimension int - vectorDistanceMetric protos.VectorDistanceMetric - vectorField string - storageNamespace *string - storageSet *string - hnsfM *uint32 - hnsfEfC *uint32 - hnsfEf *uint32 - hnsfBatchingMaxRecord *uint32 - hnsfBatchingInterval *uint32 - hnsfBatchingDisabled *bool -} - -func NewIndexDefinitionBuilder( - indexName, - namespace string, - dimension int, - distanceMetric protos.VectorDistanceMetric, - vectorField string, -) *IndexDefinitionBuilder { - return &IndexDefinitionBuilder{ - indexName: indexName, - namespace: namespace, - dimension: dimension, - vectorDistanceMetric: distanceMetric, - vectorField: vectorField, - } -} - -func (idb *IndexDefinitionBuilder) WithSet(set string) *IndexDefinitionBuilder { - idb.set = &set - return idb -} - -func (idb *IndexDefinitionBuilder) WithStorageNamespace(storageNamespace string) *IndexDefinitionBuilder { - idb.storageNamespace = &storageNamespace - return idb -} - -func (idb *IndexDefinitionBuilder) WithStorageSet(storageSet string) *IndexDefinitionBuilder { - idb.storageSet = &storageSet - return idb -} - -func (idb *IndexDefinitionBuilder) WithHnswM(m uint32) *IndexDefinitionBuilder { - idb.hnsfM = &m - return idb -} - -func (idb *IndexDefinitionBuilder) WithHnswEf(ef uint32) *IndexDefinitionBuilder { - idb.hnsfEf = &ef - return idb -} - -func (idb *IndexDefinitionBuilder) WithHnswEfConstruction(efConstruction uint32) *IndexDefinitionBuilder { - idb.hnsfEfC = &efConstruction - return idb -} - -func (idb *IndexDefinitionBuilder) WithHnswBatchingMaxRecord(maxRecord uint32) *IndexDefinitionBuilder { - idb.hnsfBatchingMaxRecord = &maxRecord - return idb -} - -func (idb *IndexDefinitionBuilder) WithHnswBatchingInterval(interval uint32) *IndexDefinitionBuilder { - idb.hnsfBatchingInterval = &interval - return idb -} - -func (idb *IndexDefinitionBuilder) WithHnswBatchingDisabled(disabled bool) *IndexDefinitionBuilder { - idb.hnsfBatchingDisabled = &disabled - return idb -} - -func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition { - indexDef := &protos.IndexDefinition{ - Id: &protos.IndexId{ - Name: idb.indexName, - Namespace: idb.namespace, - }, - Dimensions: uint32(idb.dimension), - VectorDistanceMetric: idb.vectorDistanceMetric, - Field: idb.vectorField, - Type: protos.IndexType_HNSW, - Storage: &protos.IndexStorage{ - Namespace: &idb.namespace, - Set: &idb.indexName, - }, - Params: &protos.IndexDefinition_HnswParams{ - HnswParams: &protos.HnswParams{ - M: getUint32Ptr(16), - EfConstruction: getUint32Ptr(100), - Ef: getUint32Ptr(100), - BatchingParams: &protos.HnswBatchingParams{ - MaxRecords: getUint32Ptr(100000), - Interval: getUint32Ptr(30000), - Disabled: getBoolPtr(false), - }, - }, - }, - } - - if idb.set != nil { - indexDef.SetFilter = idb.set - } - - if idb.storageNamespace != nil { - indexDef.Storage.Namespace = idb.storageNamespace - } - - if idb.storageSet != nil { - indexDef.Storage.Set = idb.storageSet - } - - if idb.hnsfM != nil { - indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.M = idb.hnsfM - } - if idb.hnsfEf != nil { - indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.Ef = idb.hnsfEf - } - if idb.hnsfEfC != nil { - indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.EfConstruction = idb.hnsfEfC - } - if idb.hnsfBatchingMaxRecord != nil { - indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.BatchingParams.MaxRecords = idb.hnsfBatchingMaxRecord - } - if idb.hnsfBatchingInterval != nil { - indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.BatchingParams.Interval = idb.hnsfBatchingInterval - } - if idb.hnsfBatchingDisabled != nil { - indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.BatchingParams.Disabled = idb.hnsfBatchingDisabled - } - - return indexDef -} - func (suite *CmdTestSuite) TestSuccessfulCreateIndexCmd() { testCases := []struct { name string @@ -291,14 +179,14 @@ func (suite *CmdTestSuite) TestSuccessfulCreateIndexCmd() { "test with storage config", "index1", "test", - fmt.Sprintf("create index --seeds %s -n test -i index1 -d 256 -m SQUARED_EUCLIDEAN --vector-field vector1 --storage-namespace bar --storage-set testbar --timeout 10s", suite.avsHostPort.String()), + fmt.Sprintf("create index --host %s -n test -i index1 -d 256 -m SQUARED_EUCLIDEAN --vector-field vector1 --storage-namespace bar --storage-set testbar --timeout 10s", suite.avsHostPort.String()), NewIndexDefinitionBuilder("index1", "test", 256, protos.VectorDistanceMetric_SQUARED_EUCLIDEAN, "vector1"). WithStorageNamespace("bar"). WithStorageSet("testbar"). Build(), }, { - "test with hnsw params", + "test with hnsw params and seeds", "index2", "test", fmt.Sprintf("create index --timeout 10s --seeds %s -n test -i index2 -d 256 -m HAMMING --vector-field vector2 --hnsw-max-edges 10 --hnsw-ef 11 --hnsw-ef-construction 12", suite.avsHostPort.String()), @@ -312,7 +200,7 @@ func (suite *CmdTestSuite) TestSuccessfulCreateIndexCmd() { "test with hnsw batch params", "index3", "test", - fmt.Sprintf("create index --timeout 10s --seeds %s -n test -i index3 -d 256 -m COSINE --vector-field vector3 --hnsw-batch-enabled false --hnsw-batch-interval 50 --hnsw-batch-max-records 100", suite.avsHostPort.String()), + fmt.Sprintf("create index --timeout 10s --host %s -n test -i index3 -d 256 -m COSINE --vector-field vector3 --hnsw-batch-enabled false --hnsw-batch-interval 50 --hnsw-batch-max-records 100", suite.avsHostPort.String()), NewIndexDefinitionBuilder("index3", "test", 256, protos.VectorDistanceMetric_COSINE, "vector3"). WithHnswBatchingMaxRecord(100). WithHnswBatchingInterval(50). @@ -324,7 +212,11 @@ func (suite *CmdTestSuite) TestSuccessfulCreateIndexCmd() { for _, tc := range testCases { suite.Run(tc.name, func() { lines, err := suite.runCmd(strings.Split(tc.cmd, " ")...) - suite.Assert().NoError(err, "error: %s, stdout/err: %s", err, lines) + + if err != nil { + suite.Assert().NoError(err, "error: %s, stdout/err: %s", err, lines) + suite.FailNow("unable to create index") + } actual, err := suite.avsClient.IndexGet(context.Background(), tc.indexNamespace, tc.indexName) @@ -340,10 +232,10 @@ func (suite *CmdTestSuite) TestSuccessfulCreateIndexCmd() { } func (suite *CmdTestSuite) TestCreateIndexFailsAlreadyExistsCmd() { - lines, err := suite.runCmd(strings.Split(fmt.Sprintf("create index --seeds %s -n test -i exists -d 256 -m SQUARED_EUCLIDEAN --vector-field vector1 --storage-namespace bar --storage-set testbar --timeout 10s", suite.avsHostPort.String()), " ")...) + lines, err := suite.runCmd(strings.Split(fmt.Sprintf("create index --host %s -n test -i exists -d 256 -m SQUARED_EUCLIDEAN --vector-field vector1 --storage-namespace bar --storage-set testbar --timeout 10s", suite.avsHostPort.String()), " ")...) suite.Assert().NoError(err, "index should have NOT existed on first call. error: %s, stdout/err: %s", err, lines) - lines, err = suite.runCmd(strings.Split(fmt.Sprintf("create index --seeds %s -n test -i exists -d 256 -m SQUARED_EUCLIDEAN --vector-field vector1 --storage-namespace bar --storage-set testbar --timeout 10s", suite.avsHostPort.String()), " ")...) + lines, err = suite.runCmd(strings.Split(fmt.Sprintf("create index --host %s -n test -i exists -d 256 -m SQUARED_EUCLIDEAN --vector-field vector1 --storage-namespace bar --storage-set testbar --timeout 10s", suite.avsHostPort.String()), " ")...) suite.Assert().Error(err, "index should HAVE existed on first call. error: %s, stdout/err: %s", err, lines) suite.Assert().Contains(lines[0], "AlreadyExists") @@ -358,7 +250,7 @@ func (suite *CmdTestSuite) TestSuccessfulDropIndexCmd() { cmd string }{ { - "test with just namespace", + "test with just namespace and seeds", "indexdrop1", "test", nil, @@ -371,7 +263,7 @@ func (suite *CmdTestSuite) TestSuccessfulDropIndexCmd() { []string{ "testset", }, - fmt.Sprintf("drop index --seeds %s -n test -s testset -i indexdrop2 --timeout 10s", suite.avsHostPort.String()), + fmt.Sprintf("drop index --host %s -n test -s testset -i indexdrop2 --timeout 10s", suite.avsHostPort.String()), }, } @@ -382,12 +274,18 @@ func (suite *CmdTestSuite) TestSuccessfulDropIndexCmd() { suite.FailNowf("unable to create index", "%v", err) } + time.Sleep(time.Second * 3) + lines, err := suite.runCmd(strings.Split(tc.cmd, " ")...) suite.Assert().NoError(err, "error: %s, stdout/err: %s", err, lines) + if err != nil { + suite.FailNow("unable to drop index") + } + _, err = suite.avsClient.IndexGet(context.Background(), tc.indexNamespace, tc.indexName) - time.Sleep(time.Second) + time.Sleep(time.Second * 3) if err == nil { suite.FailNow("err is nil, that means the index still exists") @@ -628,13 +526,12 @@ func (suite *CmdTestSuite) TestFailInvalidArg() { } } -func docker_compose_up() error { +func docker_compose_up(composeFile string) error { fmt.Println("Starting docker containers") ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - // docker/docker-compose.yml - cmd := exec.CommandContext(ctx, "docker", "compose", fmt.Sprintf("-fdocker/docker-compose.yml"), "up", "-d") + cmd := exec.CommandContext(ctx, "docker", "compose", fmt.Sprintf("-f%s", composeFile), "up", "-d") output, err := cmd.CombinedOutput() fmt.Printf("docker compose up output: %s\n", string(output)) @@ -649,11 +546,11 @@ func docker_compose_up() error { return nil } -func docker_compose_down() error { +func docker_compose_down(composeFile string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - cmd := exec.CommandContext(ctx, "docker", "compose", fmt.Sprintf("-fdocker/docker-compose.yml"), "down") + cmd := exec.CommandContext(ctx, "docker", "compose", fmt.Sprintf("-f%s", composeFile), "down") _, err := cmd.Output() if err != nil { diff --git a/go.mod b/go.mod index 76d251b..a9f8493 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,13 @@ module asvec go 1.21.7 -replace github.com/aerospike/aerospike-proximus-client-go => /Users/jesseschmidt/Developer/aerospike-proximus-client-go +// replace github.com/aerospike/aerospike-proximus-client-go => /Users/jesseschmidt/Developer/aerospike-proximus-client-go + +// replace github.com/aerospike/tools-common-go => /Users/jesseschmidt/Developer/tools-common-go require ( - github.com/aerospike/aerospike-proximus-client-go v0.0.0-20240603230632-86a0ebaa8aa9 - github.com/aerospike/tools-common-go v0.0.0-20240425222921-596724ec5926 + github.com/aerospike/aerospike-proximus-client-go v0.0.0-20240618165139-d1f0bb1968a5 + github.com/aerospike/tools-common-go v0.0.0-20240618165632-595098741f89 github.com/jedib0t/go-pretty/v6 v6.5.9 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 diff --git a/go.sum b/go.sum index 16c4edc..a0cec2d 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,14 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/aerospike/aerospike-client-go/v7 v7.4.0 h1:g8/7v8RHhQhTArhW3C7Au7o+u8j8x5eySZL6MXfpHKU= github.com/aerospike/aerospike-client-go/v7 v7.4.0/go.mod h1:pPKnWiS8VDJcH4IeB1b8SA2TWnkjcVLHwAAJ+BHfGK8= +github.com/aerospike/aerospike-proximus-client-go v0.0.0-20240603230632-86a0ebaa8aa9 h1:qVpPCrbp0pNNmP1CPqln6HkzhVmFmOOVZYLq4IDlidI= +github.com/aerospike/aerospike-proximus-client-go v0.0.0-20240603230632-86a0ebaa8aa9/go.mod h1:N0kxd4FoYDbLOEwm8vWH6wKUkoR5v0Wp/v0+tUqoUMg= +github.com/aerospike/aerospike-proximus-client-go v0.0.0-20240618165139-d1f0bb1968a5 h1:OfqJsUs8T8DaYqDLyZwXDY55FvWtyVybbm2mc3Pi+1s= +github.com/aerospike/aerospike-proximus-client-go v0.0.0-20240618165139-d1f0bb1968a5/go.mod h1:N0kxd4FoYDbLOEwm8vWH6wKUkoR5v0Wp/v0+tUqoUMg= github.com/aerospike/tools-common-go v0.0.0-20240425222921-596724ec5926 h1:CqkNasGC/7x5JvYjCSuAVX/rG+nUgRQtXfxIURXo5OE= github.com/aerospike/tools-common-go v0.0.0-20240425222921-596724ec5926/go.mod h1:Ig1lRynXx0tXNOY3MdtanTsKz1ifG/2AyDFMXn3RMTc= +github.com/aerospike/tools-common-go v0.0.0-20240618165632-595098741f89 h1:5rYc5QsaQeAnSzUm30gOUANEIEsMS8knbnjouenRV7E= +github.com/aerospike/tools-common-go v0.0.0-20240618165632-595098741f89/go.mod h1:Ig1lRynXx0tXNOY3MdtanTsKz1ifG/2AyDFMXn3RMTc= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/test_utils.go b/test_utils.go new file mode 100644 index 0000000..9bcb1cc --- /dev/null +++ b/test_utils.go @@ -0,0 +1,160 @@ +//go:build integration + +package main + +import "github.com/aerospike/aerospike-proximus-client-go/protos" + +func getStrPtr(str string) *string { + ptr := str + return &ptr +} + +func getUint32Ptr(i int) *uint32 { + ptr := uint32(i) + return &ptr +} + +func getBoolPtr(b bool) *bool { + ptr := b + return &ptr +} + +type IndexDefinitionBuilder struct { + indexName string + namespace string + set *string + dimension int + vectorDistanceMetric protos.VectorDistanceMetric + vectorField string + storageNamespace *string + storageSet *string + hnsfM *uint32 + hnsfEfC *uint32 + hnsfEf *uint32 + hnsfBatchingMaxRecord *uint32 + hnsfBatchingInterval *uint32 + hnsfBatchingDisabled *bool +} + +func NewIndexDefinitionBuilder( + indexName, + namespace string, + dimension int, + distanceMetric protos.VectorDistanceMetric, + vectorField string, +) *IndexDefinitionBuilder { + return &IndexDefinitionBuilder{ + indexName: indexName, + namespace: namespace, + dimension: dimension, + vectorDistanceMetric: distanceMetric, + vectorField: vectorField, + } +} + +func (idb *IndexDefinitionBuilder) WithSet(set string) *IndexDefinitionBuilder { + idb.set = &set + return idb +} + +func (idb *IndexDefinitionBuilder) WithStorageNamespace(storageNamespace string) *IndexDefinitionBuilder { + idb.storageNamespace = &storageNamespace + return idb +} + +func (idb *IndexDefinitionBuilder) WithStorageSet(storageSet string) *IndexDefinitionBuilder { + idb.storageSet = &storageSet + return idb +} + +func (idb *IndexDefinitionBuilder) WithHnswM(m uint32) *IndexDefinitionBuilder { + idb.hnsfM = &m + return idb +} + +func (idb *IndexDefinitionBuilder) WithHnswEf(ef uint32) *IndexDefinitionBuilder { + idb.hnsfEf = &ef + return idb +} + +func (idb *IndexDefinitionBuilder) WithHnswEfConstruction(efConstruction uint32) *IndexDefinitionBuilder { + idb.hnsfEfC = &efConstruction + return idb +} + +func (idb *IndexDefinitionBuilder) WithHnswBatchingMaxRecord(maxRecord uint32) *IndexDefinitionBuilder { + idb.hnsfBatchingMaxRecord = &maxRecord + return idb +} + +func (idb *IndexDefinitionBuilder) WithHnswBatchingInterval(interval uint32) *IndexDefinitionBuilder { + idb.hnsfBatchingInterval = &interval + return idb +} + +func (idb *IndexDefinitionBuilder) WithHnswBatchingDisabled(disabled bool) *IndexDefinitionBuilder { + idb.hnsfBatchingDisabled = &disabled + return idb +} + +func (idb *IndexDefinitionBuilder) Build() *protos.IndexDefinition { + indexDef := &protos.IndexDefinition{ + Id: &protos.IndexId{ + Name: idb.indexName, + Namespace: idb.namespace, + }, + Dimensions: uint32(idb.dimension), + VectorDistanceMetric: idb.vectorDistanceMetric, + Field: idb.vectorField, + Type: protos.IndexType_HNSW, + Storage: &protos.IndexStorage{ + Namespace: &idb.namespace, + Set: &idb.indexName, + }, + Params: &protos.IndexDefinition_HnswParams{ + HnswParams: &protos.HnswParams{ + M: getUint32Ptr(16), + EfConstruction: getUint32Ptr(100), + Ef: getUint32Ptr(100), + BatchingParams: &protos.HnswBatchingParams{ + MaxRecords: getUint32Ptr(100000), + Interval: getUint32Ptr(30000), + Disabled: getBoolPtr(false), + }, + }, + }, + } + + if idb.set != nil { + indexDef.SetFilter = idb.set + } + + if idb.storageNamespace != nil { + indexDef.Storage.Namespace = idb.storageNamespace + } + + if idb.storageSet != nil { + indexDef.Storage.Set = idb.storageSet + } + + if idb.hnsfM != nil { + indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.M = idb.hnsfM + } + if idb.hnsfEf != nil { + indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.Ef = idb.hnsfEf + } + if idb.hnsfEfC != nil { + indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.EfConstruction = idb.hnsfEfC + } + if idb.hnsfBatchingMaxRecord != nil { + indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.BatchingParams.MaxRecords = idb.hnsfBatchingMaxRecord + } + if idb.hnsfBatchingInterval != nil { + indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.BatchingParams.Interval = idb.hnsfBatchingInterval + } + if idb.hnsfBatchingDisabled != nil { + indexDef.Params.(*protos.IndexDefinition_HnswParams).HnswParams.BatchingParams.Disabled = idb.hnsfBatchingDisabled + } + + return indexDef +}