Skip to content

Commit

Permalink
add tls tests and re-org flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Schmidt committed Jun 18, 2024
1 parent 3880ef4 commit bc226ce
Show file tree
Hide file tree
Showing 22 changed files with 716 additions and 340 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
echo "$FEATURES_CONF" > docker/config/features.conf
- name: Run tests
run: |
echo '127.0.0.1 connector.aerospike.com' >> /etc/hosts
make coverage
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/docker/config/features.conf
features.conf
/bin/*
embed_*.go
/tmp
Expand Down
31 changes: 0 additions & 31 deletions cmd/constants.go

This file was deleted.

115 changes: 59 additions & 56 deletions cmd/createIndex.go

Large diffs are not rendered by default.

67 changes: 35 additions & 32 deletions cmd/dropIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,29 @@ import (

//nolint:govet // Padding not a concern for a CLI
var dropIndexFlags = &struct {
host *flags.HostPortFlag
seeds *flags.SeedsSliceFlag
listenerName flags.StringOptionalFlag
namespace string
sets []string
indexName string
timeout time.Duration
tls *flags.TLSFlags
flags.ClientFlags
namespace string
sets []string
indexName string
timeout time.Duration
}{
host: flags.NewDefaultHostPortFlag(),
seeds: &flags.SeedsSliceFlag{},
tls: &flags.TLSFlags{},
ClientFlags: *flags.NewClientFlags(),
}

func newDropIndexFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.VarP(dropIndexFlags.host, flagNameHost, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flagNameSeeds))) //nolint:lll // For readability
flagSet.Var(dropIndexFlags.seeds, flagNameSeeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flagNameHost))) //nolint:lll // For readability
flagSet.VarP(&dropIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability
flagSet.StringVarP(&dropIndexFlags.namespace, flagNameNamespace, "n", "", commonFlags.DefaultWrapHelpString("The namespace for the index.")) //nolint:lll // For readability
flagSet.StringArrayVarP(&dropIndexFlags.sets, flagNameSets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability
flagSet.StringVarP(&dropIndexFlags.indexName, flagNameIndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability
flagSet.DurationVar(&dropIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(dropIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString))
flagSet.StringVarP(&dropIndexFlags.namespace, flags.Namespace, "n", "", commonFlags.DefaultWrapHelpString("The namespace for the index.")) //nolint:lll // For readability
flagSet.StringArrayVarP(&dropIndexFlags.sets, flags.Sets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability
flagSet.StringVarP(&dropIndexFlags.indexName, flags.IndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability
flagSet.DurationVar(&dropIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(dropIndexFlags.NewClientFlagSet())

return flagSet
}

var dropIndexRequiredFlags = []string{
flagNameNamespace,
flagNameIndexName,
flags.Namespace,
flags.IndexName,
}

// dropIndexCmd represents the dropIndex command
Expand All @@ -65,30 +57,41 @@ func newDropIndexCommand() *cobra.Command {
asvec drop index -i myindex -n test
`,
PreRunE: func(_ *cobra.Command, _ []string) error {
if viper.IsSet(flagNameSeeds) && viper.IsSet(flagNameHost) {
return fmt.Errorf("only --%s or --%s allowed", flagNameSeeds, flagNameHost)
if viper.IsSet(flags.Seeds) && viper.IsSet(flags.Host) {
return fmt.Errorf("only --%s or --%s allowed", flags.Seeds, flags.Host)
}

return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
logger.Debug("parsed flags",
slog.String(flagNameHost, dropIndexFlags.host.String()),
slog.String(flagNameSeeds, dropIndexFlags.seeds.String()),
slog.String(flagNameListenerName, dropIndexFlags.listenerName.String()),
slog.String(flagNameNamespace, dropIndexFlags.namespace),
slog.Any(flagNameSets, dropIndexFlags.sets),
slog.String(flagNameIndexName, dropIndexFlags.indexName),
slog.Duration(flagNameTimeout, dropIndexFlags.timeout),
slog.String(flags.Host, dropIndexFlags.Host.String()),
slog.String(flags.Seeds, dropIndexFlags.Seeds.String()),
slog.String(flags.ListenerName, dropIndexFlags.ListenerName.String()),
slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil),
slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil),
slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil),
slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil),
slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil),
slog.String(flags.Namespace, dropIndexFlags.namespace),
slog.Any(flags.Sets, dropIndexFlags.sets),
slog.String(flags.IndexName, dropIndexFlags.indexName),
slog.Duration(flags.Timeout, dropIndexFlags.timeout),
)

hosts, isLoadBalancer := parseBothHostSeedsFlag(dropIndexFlags.seeds, dropIndexFlags.host)
hosts, isLoadBalancer := parseBothHostSeedsFlag(dropIndexFlags.Seeds, dropIndexFlags.Host)

ctx, cancel := context.WithTimeout(context.Background(), dropIndexFlags.timeout)
defer cancel()

tlsConfig, err := dropIndexFlags.NewTLSConfig()
if err != nil {
logger.Error("failed to create TLS config", slog.Any("error", err))
return err
}

adminClient, err := avs.NewAdminClient(
ctx, hosts, nil, isLoadBalancer, nil, logger,
ctx, hosts, dropIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger,
)
if err != nil {
logger.Error("failed to create AVS client", slog.Any("error", err))
Expand Down
34 changes: 34 additions & 0 deletions cmd/flags/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package flags

import (
"fmt"

commonFlags "github.com/aerospike/tools-common-go/flags"
"github.com/spf13/pflag"
)

type ClientFlags struct {
Host *HostPortFlag
Seeds *SeedsSliceFlag
ListenerName StringOptionalFlag
TLSFlags
}

func NewClientFlags() *ClientFlags {
return &ClientFlags{
Host: NewDefaultHostPortFlag(),
Seeds: &SeedsSliceFlag{},
TLSFlags: *NewTLSFlags(),
}
}

func (cf *ClientFlags) NewClientFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.VarP(cf.Host, Host, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", Seeds))) //nolint:lll // For readability
flagSet.Var(cf.Seeds, Seeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", Host))) //nolint:lll // For readability
flagSet.VarP(&cf.ListenerName, ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments."))

flagSet.AddFlagSet(cf.NewTLSFlagSet(commonFlags.DefaultWrapHelpString))

return flagSet
}
31 changes: 31 additions & 0 deletions cmd/flags/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package flags

const (
LogLevel = "log-level"
Seeds = "seeds"
Host = "host"
ListenerName = "listener-name"
Namespace = "namespace"
Sets = "sets"
IndexName = "index-name"
VectorField = "vector-field"
Dimension = "dimension"
DistanceMetric = "distance-metric"
IndexMeta = "index-meta"
Timeout = "timeout"
Verbose = "verbose"
StorageNamespace = "storage-namespace"
StorageSet = "storage-set"
MaxEdges = "hnsw-max-edges"
ConstructionEf = "hnsw-ef-construction"
Ef = "hnsw-ef"
BatchMaxRecords = "hnsw-batch-max-records"
BatchInterval = "hnsw-batch-interval"
BatchEnabled = "hnsw-batch-enabled"
TLSProtocols = "tls-protocols"
TLSCaFile = "tls-cafile"
TLSCaPath = "tls-capath"
TLSCertFile = "tls-certfile"
TLSKeyFile = "tls-keyfile"
TLSKeyFilePass = "tls-keyfile-password"
)
6 changes: 3 additions & 3 deletions cmd/flags/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func NewTLSFlags() *TLSFlags {

// NewAerospikeFlagSet returns a new pflag.FlagSet with Aerospike flags defined.
// Values set in the returned FlagSet will be stored in the AerospikeFlags argument.
func (tf *TLSFlags) NewFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet {
func (tf *TLSFlags) NewTLSFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet {
f := &pflag.FlagSet{}

f.Var(&tf.TLSRootCAFile, "tls-cafile", fmtUsage("The CA used when connecting to AVS."))
Expand Down Expand Up @@ -55,7 +55,7 @@ func (tf *TLSFlags) NewTLSConfig() (*tls.Config, error) {
tf.TLSCertFile,
tf.TLSKeyFile,
tf.TLSKeyFilePass,
tf.TLSProtocols.Min,
tf.TLSProtocols.Max,
0,
0,
).NewGoTLSConfig()
}
56 changes: 31 additions & 25 deletions cmd/listIndex.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,21 @@ import (
)

var listIndexFlags = &struct {
host *flags.HostPortFlag
seeds *flags.SeedsSliceFlag
listenerName flags.StringOptionalFlag
verbose bool
timeout time.Duration
tls *flags.TLSFlags
flags.ClientFlags
verbose bool
timeout time.Duration
}{
host: flags.NewDefaultHostPortFlag(),
seeds: &flags.SeedsSliceFlag{},
tls: &flags.TLSFlags{},
ClientFlags: *flags.NewClientFlags(),
}

func newListIndexFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.VarP(listIndexFlags.host, flagNameHost, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flagNameSeeds))) //nolint:lll // For readability
flagSet.Var(listIndexFlags.seeds, flagNameSeeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flagNameHost))) //nolint:lll // For readability
flagSet.VarP(&listIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability
flagSet.BoolVarP(&listIndexFlags.verbose, flagNameVerbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability
flagSet.DurationVar(&listIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(listIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString))
flagSet.VarP(listIndexFlags.Host, flags.Host, "h", commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS host to connect to. If cluster discovery is needed use --%s", flags.Seeds))) //nolint:lll // For readability
flagSet.Var(listIndexFlags.Seeds, flags.Seeds, commonFlags.DefaultWrapHelpString(fmt.Sprintf("The AVS seeds to use for cluster discovery. If no cluster discovery is needed (i.e. load-balancer) then use --%s", flags.Host))) //nolint:lll // For readability
flagSet.VarP(&listIndexFlags.ListenerName, flags.ListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability
flagSet.BoolVarP(&listIndexFlags.verbose, flags.Verbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability
flagSet.DurationVar(&listIndexFlags.timeout, flags.Timeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability
flagSet.AddFlagSet(listIndexFlags.NewClientFlagSet())

return flagSet
}
Expand All @@ -57,30 +52,41 @@ func newListIndexCmd() *cobra.Command {
For example:
export ASVEC_HOST=<avs-ip>:5000
asvec list index
`, flagNameVerbose),
`, flags.Verbose),
PreRunE: func(_ *cobra.Command, _ []string) error {
if viper.IsSet(flagNameSeeds) && viper.IsSet(flagNameHost) {
return fmt.Errorf("only --%s or --%s allowed", flagNameSeeds, flagNameHost)
if viper.IsSet(flags.Seeds) && viper.IsSet(flags.Host) {
return fmt.Errorf("only --%s or --%s allowed", flags.Seeds, flags.Host)
}

return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
logger.Debug("parsed flags",
slog.String(flagNameHost, listIndexFlags.host.String()),
slog.String(flagNameSeeds, listIndexFlags.seeds.String()),
slog.String(flagNameListenerName, listIndexFlags.listenerName.String()),
slog.Bool(flagNameVerbose, listIndexFlags.verbose),
slog.Duration(flagNameTimeout, listIndexFlags.timeout),
slog.String(flags.Host, listIndexFlags.Host.String()),
slog.String(flags.Seeds, listIndexFlags.Seeds.String()),
slog.String(flags.ListenerName, listIndexFlags.ListenerName.String()),
slog.Bool(flags.TLSCaFile, createIndexFlags.TLSRootCAFile != nil),
slog.Bool(flags.TLSCaPath, createIndexFlags.TLSRootCAPath != nil),
slog.Bool(flags.TLSCertFile, createIndexFlags.TLSCertFile != nil),
slog.Bool(flags.TLSKeyFile, createIndexFlags.TLSKeyFile != nil),
slog.Bool(flags.TLSKeyFilePass, createIndexFlags.TLSKeyFilePass != nil),
slog.Bool(flags.Verbose, listIndexFlags.verbose),
slog.Duration(flags.Timeout, listIndexFlags.timeout),
)

hosts, isLoadBalancer := parseBothHostSeedsFlag(listIndexFlags.seeds, listIndexFlags.host)
hosts, isLoadBalancer := parseBothHostSeedsFlag(listIndexFlags.Seeds, listIndexFlags.Host)

ctx, cancel := context.WithTimeout(context.Background(), listIndexFlags.timeout)
defer cancel()

tlsConfig, err := listIndexFlags.NewTLSConfig()
if err != nil {
logger.Error("failed to create TLS config", slog.Any("error", err))
return err
}

adminClient, err := avs.NewAdminClient(
ctx, hosts, listIndexFlags.listenerName.Val, isLoadBalancer, nil, logger,
ctx, hosts, listIndexFlags.ListenerName.Val, isLoadBalancer, tlsConfig, logger,
)
if err != nil {
logger.Error("failed to create AVS client", slog.Any("error", err))
Expand Down
6 changes: 3 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ func Execute() {
}

func init() {
rootCmd.PersistentFlags().Var(&rootFlags.logLevel, logLevelFlagName, common.DefaultWrapHelpString(fmt.Sprintf("Log level for additional details and debugging. Valid values: %s", strings.Join(flags.LogLevelEnum(), ", "))))
rootCmd.PersistentFlags().Var(&rootFlags.logLevel, flags.LogLevel, common.DefaultWrapHelpString(fmt.Sprintf("Log level for additional details and debugging. Valid values: %s", strings.Join(flags.LogLevelEnum(), ", "))))
common.SetupRoot(rootCmd, "aerospike-vector-search", "0.0.0") // TODO: Handle version
viper.SetEnvPrefix("ASVEC")

if err := viper.BindEnv(flagNameHost); err != nil {
if err := viper.BindEnv(flags.Host); err != nil {
logger.Error("failed to bind environment variable", slog.Any("error", err))
}

if err := viper.BindEnv(flagNameSeeds); err != nil {
if err := viper.BindEnv(flags.Seeds); err != nil {
logger.Error("failed to bind environment variable", slog.Any("error", err))
}
}
49 changes: 49 additions & 0 deletions cmd/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//go:build unit

package cmd

import (
"asvec/cmd/flags"
"testing"

avs "github.com/aerospike/aerospike-proximus-client-go"
"github.com/stretchr/testify/assert"
)

func TestParseBothHostSeedsFlag(t *testing.T) {
testCases := []struct {
seeds *flags.SeedsSliceFlag
host *flags.HostPortFlag
expectedSlice avs.HostPortSlice
expectedIsLoadBalancer bool
}{
{
&flags.SeedsSliceFlag{
Seeds: avs.HostPortSlice{
avs.NewHostPort("1.1.1.1", 5000, false),
},
},
flags.NewDefaultHostPortFlag(),
avs.HostPortSlice{
avs.NewHostPort("1.1.1.1", 5000, false),
},
false,
},
{
&flags.SeedsSliceFlag{
Seeds: avs.HostPortSlice{},
},
flags.NewDefaultHostPortFlag(),
avs.HostPortSlice{
&flags.NewDefaultHostPortFlag().HostPort,
},
true,
},
}

for _, tc := range testCases {
actualSlice, actualBool := parseBothHostSeedsFlag(tc.seeds, tc.host)
assert.Equal(t, tc.expectedSlice, actualSlice)
assert.Equal(t, tc.expectedIsLoadBalancer, actualBool)
}
}
Loading

0 comments on commit bc226ce

Please sign in to comment.