From 3880ef425efb2f528aff233a2f510f628f284ff4 Mon Sep 17 00:00:00 2001 From: Jesse Schmidt Date: Thu, 13 Jun 2024 15:39:11 -0700 Subject: [PATCH] add tls flags --- cmd/constants.go | 6 +++++ cmd/createIndex.go | 3 +++ cmd/dropIndex.go | 3 +++ cmd/flags/tls.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++ cmd/listIndex.go | 3 +++ 5 files changed, 76 insertions(+) create mode 100644 cmd/flags/tls.go diff --git a/cmd/constants.go b/cmd/constants.go index 0b6ee36..f4cce6c 100644 --- a/cmd/constants.go +++ b/cmd/constants.go @@ -22,4 +22,10 @@ const ( flagNameBatchMaxRecords = "hnsw-batch-max-records" flagNameBatchInterval = "hnsw-batch-interval" flagNameBatchEnabled = "hnsw-batch-enabled" + flagNameTLSProtocols = "tls-protocols" + flagNameTLSCaFile = "tls-cafile" + flagNameTLSCaPath = "tls-capath" + flagNameTLSCertFile = "tls-certfile" + flagNameTLSKeyFile = "tls-keyfile" + flagNameTLSKeyFilePass = "tls-keyfile-password" ) diff --git a/cmd/createIndex.go b/cmd/createIndex.go index 6ce6b41..81b4fb7 100644 --- a/cmd/createIndex.go +++ b/cmd/createIndex.go @@ -40,6 +40,7 @@ var createIndexFlags = &struct { hnswBatchInterval flags.Uint32OptionalFlag hnswBatchEnabled flags.BoolOptionalFlag timeout time.Duration + tls *flags.TLSFlags }{ host: flags.NewDefaultHostPortFlag(), seeds: &flags.SeedsSliceFlag{}, @@ -51,6 +52,7 @@ var createIndexFlags = &struct { hnswBatchMaxRecords: flags.Uint32OptionalFlag{}, hnswBatchInterval: flags.Uint32OptionalFlag{}, hnswBatchEnabled: flags.BoolOptionalFlag{}, + tls: &flags.TLSFlags{}, } func newCreateIndexFlagSet() *pflag.FlagSet { @@ -74,6 +76,7 @@ func newCreateIndexFlagSet() *pflag.FlagSet { flagSet.Var(&createIndexFlags.hnswBatchMaxRecords, flagNameBatchMaxRecords, commonFlags.DefaultWrapHelpString("Maximum number of records to fit in a batch. The default value is 10000.")) //nolint:lll // For readability flagSet.Var(&createIndexFlags.hnswBatchInterval, flagNameBatchInterval, commonFlags.DefaultWrapHelpString("The maximum amount of time in milliseconds to wait before finalizing a batch. The default value is 10000.")) //nolint:lll // For readability flagSet.Var(&createIndexFlags.hnswBatchEnabled, flagNameBatchEnabled, commonFlags.DefaultWrapHelpString("Enables batching for index updates. Default is true meaning batching is enabled.")) //nolint:lll // For readability + flagSet.AddFlagSet(createIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString)) return flagSet } diff --git a/cmd/dropIndex.go b/cmd/dropIndex.go index 28fc395..d8deb5f 100644 --- a/cmd/dropIndex.go +++ b/cmd/dropIndex.go @@ -26,9 +26,11 @@ var dropIndexFlags = &struct { sets []string indexName string timeout time.Duration + tls *flags.TLSFlags }{ host: flags.NewDefaultHostPortFlag(), seeds: &flags.SeedsSliceFlag{}, + tls: &flags.TLSFlags{}, } func newDropIndexFlagSet() *pflag.FlagSet { @@ -40,6 +42,7 @@ func newDropIndexFlagSet() *pflag.FlagSet { flagSet.StringArrayVarP(&dropIndexFlags.sets, flagNameSets, "s", nil, commonFlags.DefaultWrapHelpString("The sets for the index.")) //nolint:lll // For readability flagSet.StringVarP(&dropIndexFlags.indexName, flagNameIndexName, "i", "", commonFlags.DefaultWrapHelpString("The name of the index.")) //nolint:lll // For readability flagSet.DurationVar(&dropIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability + flagSet.AddFlagSet(dropIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString)) return flagSet } diff --git a/cmd/flags/tls.go b/cmd/flags/tls.go new file mode 100644 index 0000000..1dd6f91 --- /dev/null +++ b/cmd/flags/tls.go @@ -0,0 +1,61 @@ +package flags + +import ( + "crypto/tls" + + commonClient "github.com/aerospike/tools-common-go/client" + commonFlags "github.com/aerospike/tools-common-go/flags" + "github.com/spf13/pflag" +) + +type TLSFlags struct { + TLSProtocols commonFlags.TLSProtocolsFlag + TLSRootCAFile commonFlags.CertFlag + TLSRootCAPath commonFlags.CertPathFlag + TLSCertFile commonFlags.CertFlag + TLSKeyFile commonFlags.CertFlag + TLSKeyFilePass commonFlags.PasswordFlag +} + +func NewTLSFlags() *TLSFlags { + return &TLSFlags{ + TLSProtocols: commonFlags.NewDefaultTLSProtocolsFlag(), + } +} + +// NewAerospikeFlagSet returns a new pflag.FlagSet with Aerospike flags defined. +// Values set in the returned FlagSet will be stored in the AerospikeFlags argument. +func (tf *TLSFlags) NewFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet { + f := &pflag.FlagSet{} + + f.Var(&tf.TLSRootCAFile, "tls-cafile", fmtUsage("The CA used when connecting to AVS.")) + f.Var(&tf.TLSRootCAPath, "tls-capath", fmtUsage("A path containing CAs for connecting to AVS.")) + f.Var(&tf.TLSCertFile, "tls-certfile", fmtUsage("The certificate file for mutual TLS authentication with AVS.")) + f.Var(&tf.TLSKeyFile, "tls-keyfile", fmtUsage("The key file used for mutual TLS authentication with AVS.")) + f.Var(&tf.TLSKeyFilePass, "tls-keyfile-password", fmtUsage("The password used to decrypt the key-file if encrypted.")) + f.Var(&tf.TLSProtocols, "tls-protocols", fmtUsage( + "Set the TLS protocol selection criteria. This format is the same as"+ + " Apache's SSLProtocol documented at https://httpd.apache.org/docs/current/mod/mod_ssl.html#ssl protocol.", + )) + + return f +} + +func (tf *TLSFlags) NewTLSConfig() (*tls.Config, error) { + rootCA := [][]byte{} + + if len(tf.TLSRootCAFile) != 0 { + rootCA = append(rootCA, tf.TLSRootCAFile) + } + + rootCA = append(rootCA, tf.TLSRootCAPath...) + + return commonClient.NewTLSConfig( + rootCA, + tf.TLSCertFile, + tf.TLSKeyFile, + tf.TLSKeyFilePass, + tf.TLSProtocols.Min, + tf.TLSProtocols.Max, + ).NewGoTLSConfig() +} diff --git a/cmd/listIndex.go b/cmd/listIndex.go index d9f4008..0ab7e6a 100644 --- a/cmd/listIndex.go +++ b/cmd/listIndex.go @@ -25,9 +25,11 @@ var listIndexFlags = &struct { listenerName flags.StringOptionalFlag verbose bool timeout time.Duration + tls *flags.TLSFlags }{ host: flags.NewDefaultHostPortFlag(), seeds: &flags.SeedsSliceFlag{}, + tls: &flags.TLSFlags{}, } func newListIndexFlagSet() *pflag.FlagSet { @@ -37,6 +39,7 @@ func newListIndexFlagSet() *pflag.FlagSet { flagSet.VarP(&listIndexFlags.listenerName, flagNameListenerName, "l", commonFlags.DefaultWrapHelpString("The listener to ask the AVS server for as configured in the AVS server. Likely required for cloud deployments.")) //nolint:lll // For readability flagSet.BoolVarP(&listIndexFlags.verbose, flagNameVerbose, "v", false, commonFlags.DefaultWrapHelpString("Print detailed index information.")) //nolint:lll // For readability flagSet.DurationVar(&listIndexFlags.timeout, flagNameTimeout, time.Second*5, commonFlags.DefaultWrapHelpString("The distance metric for the index.")) //nolint:lll // For readability + flagSet.AddFlagSet(listIndexFlags.tls.NewFlagSet(commonFlags.DefaultWrapHelpString)) return flagSet }