diff --git a/admin.go b/admin.go index 2329b26..b1a1e98 100644 --- a/admin.go +++ b/admin.go @@ -5,9 +5,12 @@ import ( "context" "crypto/tls" "log/slog" + "strconv" + "strings" "time" "github.com/aerospike/avs-client-go/protos" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" @@ -44,7 +47,7 @@ func NewAdminClient( logger *slog.Logger, ) (*AdminClient, error) { logger = logger.WithGroup("avs.admin") - logger.Debug("creating new client") + logger.Info("creating new client") channelProvider, err := newChannelProvider( ctx, @@ -160,7 +163,7 @@ func (c *AdminClient) IndexCreateFromIndexDef( logger := c.logger.With(slog.Any("definition", indexDef)) logger.InfoContext(ctx, "creating index from definition") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to create index from definition" logger.Error(msg, slog.Any("error", err)) @@ -197,7 +200,7 @@ func (c *AdminClient) IndexUpdate( logger.InfoContext(ctx, "updating index") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to update index" logger.Error(msg, slog.Any("error", err)) @@ -234,7 +237,7 @@ func (c *AdminClient) IndexDrop(ctx context.Context, namespace, name string) err logger := c.logger.With(slog.String("namespace", namespace), slog.String("name", name)) logger.InfoContext(ctx, "dropping index") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to drop index" logger.Error(msg, slog.Any("error", err)) @@ -269,7 +272,7 @@ func (c *AdminClient) IndexDrop(ctx context.Context, namespace, name string) err func (c *AdminClient) IndexList(ctx context.Context) (*protos.IndexDefinitionList, error) { c.logger.InfoContext(ctx, "listing indexes") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to get indexes" @@ -298,7 +301,7 @@ func (c *AdminClient) IndexGet(ctx context.Context, namespace, name string) (*pr logger := c.logger.With(slog.String("namespace", namespace), slog.String("name", name)) logger.InfoContext(ctx, "getting index") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to get index" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -328,7 +331,7 @@ func (c *AdminClient) IndexGetStatus(ctx context.Context, namespace, name string logger := c.logger.With(slog.String("namespace", namespace), slog.String("name", name)) logger.InfoContext(ctx, "getting index status") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to get index status" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -363,7 +366,7 @@ func (c *AdminClient) GcInvalidVertices(ctx context.Context, namespace, name str logger.InfoContext(ctx, "garbage collection invalid vertices") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to garbage collect invalid vertices" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -396,7 +399,7 @@ func (c *AdminClient) CreateUser(ctx context.Context, username, password string, logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.InfoContext(ctx, "creating user") - conn, err := 
c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to create user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -427,7 +430,7 @@ func (c *AdminClient) UpdateCredentials(ctx context.Context, username, password logger := c.logger.With(slog.String("username", username)) logger.InfoContext(ctx, "updating user credentials") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to update user credentials" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -457,7 +460,7 @@ func (c *AdminClient) DropUser(ctx context.Context, username string) error { logger := c.logger.With(slog.String("username", username)) logger.InfoContext(ctx, "dropping user") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to drop user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -487,7 +490,7 @@ func (c *AdminClient) GetUser(ctx context.Context, username string) (*protos.Use logger := c.logger.With(slog.String("username", username)) logger.InfoContext(ctx, "getting user") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to get user" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -516,7 +519,7 @@ func (c *AdminClient) GetUser(ctx context.Context, username string) (*protos.Use func (c *AdminClient) ListUsers(ctx context.Context) (*protos.ListUsersResponse, error) { c.logger.InfoContext(ctx, "listing users") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to list users" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -542,7 +545,7 @@ func (c *AdminClient) GrantRoles(ctx context.Context, username string, roles []s logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.InfoContext(ctx, "granting user roles") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to grant user roles" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -573,7 +576,7 @@ func (c *AdminClient) RevokeRoles(ctx context.Context, username string, roles [] logger := c.logger.With(slog.String("username", username), slog.Any("roles", roles)) logger.InfoContext(ctx, "revoking user roles") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to revoke user roles" logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -603,7 +606,7 @@ func (c *AdminClient) RevokeRoles(ctx context.Context, username string, roles [] func (c *AdminClient) ListRoles(ctx context.Context) (*protos.ListRolesResponse, error) { c.logger.InfoContext(ctx, "listing roles") - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to list roles" c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) @@ -624,6 +627,188 @@ func (c *AdminClient) ListRoles(ctx context.Context) (*protos.ListRolesResponse, return rolesResp, nil } +// NodeIds returns a list of all the node ids that the client is connected to. +// If a node is accessible but not a part of the cluster it will not be returned. 
+func (c *AdminClient) NodeIDs(ctx context.Context) []*protos.NodeId {
+	c.logger.InfoContext(ctx, "getting cluster info")
+
+	ids := c.channelProvider.GetNodeIDs()
+	nodeIDs := make([]*protos.NodeId, len(ids))
+
+	for i, id := range ids {
+		nodeIDs[i] = &protos.NodeId{Id: id}
+	}
+
+	c.logger.Debug("got node ids", slog.Any("nodeIDs", nodeIDs))
+
+	return nodeIDs
+}
+
+// ConnectedNodeEndpoint returns the endpoint used to connect to a node. If
+// nodeID is nil then an endpoint used to connect to your seed (or
+// load-balancer) is used.
+func (c *AdminClient) ConnectedNodeEndpoint(
+	ctx context.Context,
+	nodeID *protos.NodeId,
+) (*protos.ServerEndpoint, error) {
+	c.logger.InfoContext(ctx, "getting connected endpoint for node", slog.Any("nodeID", nodeID))
+
+	var (
+		conn *grpc.ClientConn
+		err  error
+	)
+
+	if nodeID == nil {
+		conn, err = c.channelProvider.GetSeedConn()
+	} else {
+		conn, err = c.channelProvider.GetNodeConn(nodeID.Id)
+	}
+
+	if err != nil {
+		msg := "failed to get connected endpoint"
+		c.logger.ErrorContext(ctx, msg, slog.Any("error", err))
+
+		return nil, NewAVSError(msg)
+	}
+
+	splitEndpoint := strings.Split(conn.Target(), ":")
+
+	resp := protos.ServerEndpoint{
+		Address: splitEndpoint[0],
+	}
+
+	if len(splitEndpoint) > 1 {
+		port, err := strconv.ParseUint(splitEndpoint[1], 10, 32)
+		if err != nil {
+			msg := "failed to parse port"
+			c.logger.ErrorContext(ctx, msg, slog.Any("error", err))
+
+			return nil, NewAVSErrorFromGrpc(msg, err)
+		}
+
+		resp.Port = uint32(port)
+	}
+
+	return &resp, nil
+}
+
+// ClusteringState returns the state of the cluster according to the
+// given node. If nodeID is nil then the seed node is used.
+func (c *AdminClient) ClusteringState(ctx context.Context, nodeID *protos.NodeId) (*protos.ClusteringState, error) {
+	c.logger.InfoContext(ctx, "getting clustering state for node", slog.Any("nodeID", nodeID))
+
+	var (
+		conn *grpc.ClientConn
+		err  error
+	)
+
+	if nodeID == nil {
+		conn, err = c.channelProvider.GetSeedConn()
+	} else {
+		conn, err = c.channelProvider.GetNodeConn(nodeID.GetId())
+	}
+
+	if err != nil {
+		msg := "failed to get clustering state"
+		c.logger.ErrorContext(ctx, msg, slog.Any("error", err))
+
+		return nil, NewAVSErrorFromGrpc(msg, err)
+	}
+
+	client := protos.NewClusterInfoServiceClient(conn)
+
+	state, err := client.GetClusteringState(ctx, &emptypb.Empty{})
+	if err != nil {
+		msg := "failed to get clustering state"
+		c.logger.ErrorContext(ctx, msg, slog.Any("error", err))
+
+		return nil, NewAVSErrorFromGrpc(msg, err)
+	}
+
+	return state, nil
+}
+
+// ClusterEndpoints returns the endpoints of all the nodes in the cluster
+// according to the specified node. If nodeID is nil then the seed node is used.
+// If listenerName is nil then the default listener name is used.
+func (c *AdminClient) ClusterEndpoints( + ctx context.Context, + nodeID *protos.NodeId, + listenerName *string, +) (*protos.ClusterNodeEndpoints, error) { + c.logger.InfoContext(ctx, "getting cluster endpoints for node", slog.Any("nodeID", nodeID)) + + var ( + conn *grpc.ClientConn + err error + ) + + if nodeID == nil { + conn, err = c.channelProvider.GetSeedConn() + } else { + conn, err = c.channelProvider.GetNodeConn(nodeID.GetId()) + } + + if err != nil { + msg := "failed to get cluster endpoints" + c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) + + return nil, NewAVSErrorFromGrpc(msg, err) + } + + client := protos.NewClusterInfoServiceClient(conn) + + endpoints, err := client.GetClusterEndpoints(ctx, + &protos.ClusterNodeEndpointsRequest{ + ListenerName: listenerName, + }, + ) + if err != nil { + msg := "failed to get cluster endpoints" + c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) + + return nil, NewAVSErrorFromGrpc(msg, err) + } + + return endpoints, nil +} + +// About returns information about the provided node. If nodeID is nil +// then the seed node is used. +func (c *AdminClient) About(ctx context.Context, nodeID *protos.NodeId) (*protos.AboutResponse, error) { + c.logger.InfoContext(ctx, "getting \"about\" info from nodes") + + var ( + conn *grpc.ClientConn + err error + ) + + if nodeID == nil { + conn, err = c.channelProvider.GetSeedConn() + } else { + conn, err = c.channelProvider.GetNodeConn(nodeID.GetId()) + } + + if err != nil { + msg := "failed to make about request" + c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) + + return nil, NewAVSErrorFromGrpc(msg, err) + } + + client := protos.NewAboutServiceClient(conn) + + resp, err := client.Get(ctx, &protos.AboutRequest{}) + if err != nil { + msg := "failed to make about request" + c.logger.ErrorContext(ctx, msg, slog.Any("error", err)) + + return nil, NewAVSErrorFromGrpc(msg, err) + } + + return resp, nil +} + // waitForIndexCreation waits for an index to be created and blocks until it is. // The amount of time to wait between each call is defined by waitInterval. func (c *AdminClient) waitForIndexCreation(ctx context.Context, @@ -633,7 +818,7 @@ func (c *AdminClient) waitForIndexCreation(ctx context.Context, ) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("name", name)) - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to wait for index creation" logger.Error(msg, slog.Any("error", err)) @@ -689,7 +874,7 @@ func (c *AdminClient) waitForIndexCreation(ctx context.Context, func (c *AdminClient) waitForIndexDrop(ctx context.Context, namespace, name string, waitInterval time.Duration) error { logger := c.logger.With(slog.String("namespace", namespace), slog.String("name", name)) - conn, err := c.channelProvider.GetConn() + conn, err := c.channelProvider.GetRandomConn() if err != nil { msg := "failed to wait for index deletion" logger.Error(msg, slog.Any("error", err)) diff --git a/channel_provider.go b/channel_provider.go index 3527cef..25aed6c 100644 --- a/channel_provider.go +++ b/channel_provider.go @@ -10,6 +10,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "time" "github.com/aerospike/avs-client-go/protos" @@ -20,14 +21,16 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) +var errChannelProviderClosed = errors.New("channel provider is closed") + // channelAndEndpoints represents a combination of a gRPC client connection and server endpoints. 
type channelAndEndpoints struct { Channel *grpc.ClientConn Endpoints *protos.ServerEndpointList } -// newChannelAndEndpoints creates a new channelAndEndpoints instance. -func newChannelAndEndpoints(channel *grpc.ClientConn, endpoints *protos.ServerEndpointList) *channelAndEndpoints { +// newConnAndEndpoints creates a new channelAndEndpoints instance. +func newConnAndEndpoints(channel *grpc.ClientConn, endpoints *protos.ServerEndpointList) *channelAndEndpoints { return &channelAndEndpoints{ Channel: channel, Endpoints: endpoints, @@ -51,7 +54,7 @@ type channelProvider struct { isLoadBalancer bool token *tokenManager stopTendChan chan struct{} - closed bool + closed atomic.Bool } // newChannelProvider creates a new channelProvider instance. @@ -101,6 +104,7 @@ func newChannelProvider( nodeConnsLock: &sync.RWMutex{}, stopTendChan: make(chan struct{}), logger: logger, + closed: atomic.Bool{}, } // Connect to the seed nodes. @@ -112,14 +116,14 @@ func newChannelProvider( // Schedule token refresh if token manager is present. if token != nil { - cp.token.ScheduleRefresh(cp.GetConn) + cp.token.ScheduleRefresh(cp.GetRandomConn) } // Start the tend routine if load balancing is disabled. if !isLoadBalancer { - cp.logger.Debug("starting tend routine") cp.updateClusterChannels(ctx) // We want at least one tend to occur before we return + cp.logger.Debug("starting tend routine") go cp.tend(context.Background()) // Might add a tend specific timeout in the future? } else { cp.logger.Debug("load balancer is enabled, not starting tend routine") @@ -170,21 +174,41 @@ func (cp *channelProvider) Close() error { } cp.logger.Debug("closed") - cp.closed = true + cp.closed.Store(true) return firstErr } -// GetConn returns a gRPC client connection to an Aerospike server. -func (cp *channelProvider) GetConn() (*grpc.ClientConn, error) { - if cp.closed { +// GetSeedConn returns a gRPC client connection to a seed node. +func (cp *channelProvider) GetSeedConn() (*grpc.ClientConn, error) { + if cp.closed.Load() { + cp.logger.Warn("ChannelProvider is closed, cannot get channel") + return nil, errChannelProviderClosed + } + + if len(cp.seedConns) == 0 { + msg := "no seed channels found" + cp.logger.Warn(msg) + + return nil, errors.New(msg) + } + + idx := rand.Intn(len(cp.seedConns)) //nolint:gosec // Security is not an issue here + + return cp.seedConns[idx], nil +} + +// GetRandomConn returns a gRPC client connection to an Aerospike server. If +// isLoadBalancer is enabled, it will return the seed connection. 
+func (cp *channelProvider) GetRandomConn() (*grpc.ClientConn, error) { + if cp.closed.Load() { cp.logger.Warn("ChannelProvider is closed, cannot get channel") return nil, errors.New("ChannelProvider is closed") } if cp.isLoadBalancer { cp.logger.Debug("load balancer is enabled, using seed channel") - return cp.seedConns[0], nil + return cp.GetSeedConn() } cp.nodeConnsLock.RLock() @@ -192,13 +216,16 @@ func (cp *channelProvider) GetConn() (*grpc.ClientConn, error) { discoverdChannels := make([]*channelAndEndpoints, len(cp.nodeConns)) - for i, channel := range cp.nodeConns { + i := 0 + + for _, channel := range cp.nodeConns { discoverdChannels[i] = channel + i++ } if len(discoverdChannels) == 0 { cp.logger.Warn("no node channels found, using seed channel") - return cp.seedConns[0], nil + return cp.GetSeedConn() } idx := rand.Intn(len(discoverdChannels)) //nolint:gosec // Security is not an issue here @@ -206,6 +233,48 @@ func (cp *channelProvider) GetConn() (*grpc.ClientConn, error) { return discoverdChannels[idx].Channel, nil } +// GetNodeConn returns a gRPC client connection to a specific node. If the node +// ID cannot be found an error is returned. +func (cp *channelProvider) GetNodeConn(nodeID uint64) (*grpc.ClientConn, error) { + if cp.closed.Load() { + cp.logger.Warn("ChannelProvider is closed, cannot get channel") + return nil, errors.New("ChannelProvider is closed") + } + + if cp.isLoadBalancer { + cp.logger.Error("load balancer is enabled, using seed channel") + return nil, errors.New("load balancer is enabled, cannot get specific node channel") + } + + cp.nodeConnsLock.RLock() + defer cp.nodeConnsLock.RUnlock() + + channel, ok := cp.nodeConns[nodeID] + if !ok { + msg := "channel not found for specified node id" + cp.logger.Error(msg, slog.Uint64("node", nodeID)) + + return nil, errors.New(msg) + } + + return channel.Channel, nil +} + +// GetNodeIDs returns the node IDs of all nodes discovered during cluster +// tending. If tending is disabled (LB true) then no node IDs are returned. +func (cp *channelProvider) GetNodeIDs() []uint64 { + cp.nodeConnsLock.RLock() + defer cp.nodeConnsLock.RUnlock() + + nodeIDs := make([]uint64, 0, len(cp.nodeConns)) + + for node := range cp.nodeConns { + nodeIDs = append(nodeIDs, node) + } + + return nodeIDs +} + // connectToSeeds connects to the seed nodes and creates gRPC client connections. func (cp *channelProvider) connectToSeeds(ctx context.Context) error { if len(cp.seedConns) != 0 { @@ -260,15 +329,18 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { tokenLock.Unlock() } - // TODO: Check compatible client/server version here if extraCheck { - client := protos.NewClusterInfoServiceClient(conn) + client := protos.NewAboutServiceClient(conn) - _, err = client.GetClusterId(ctx, &emptypb.Empty{}) + about, err := client.Get(ctx, &protos.AboutRequest{}) if err != nil { logger.WarnContext(ctx, "failed to connect to seed", slog.Any("error", err)) return } + + if newVersion(about.Version).lt(minimumSupportedAVSVersion) { + logger.WarnContext(ctx, "incompatible server version", slog.String("version", about.Version)) + } } seedCons <- conn @@ -303,16 +375,24 @@ func (cp *channelProvider) connectToSeeds(ctx context.Context) error { // updateNodeConns updates the gRPC client connection for a specific node. 
func (cp *channelProvider) updateNodeConns( + ctx context.Context, node uint64, endpoints *protos.ServerEndpointList, ) error { - newChannel, err := cp.createChannelFromEndpoints(endpoints) + newConn, err := cp.createConnFromEndpoints(endpoints) + if err != nil { + return err + } + + client := protos.NewAboutServiceClient(newConn) + _, err = client.Get(ctx, &protos.AboutRequest{}) + if err != nil { return err } cp.nodeConnsLock.Lock() - cp.nodeConns[node] = newChannelAndEndpoints(newChannel, endpoints) + cp.nodeConns[node] = newConnAndEndpoints(newConn, endpoints) cp.nodeConnsLock.Unlock() return nil @@ -398,13 +478,17 @@ func (cp *channelProvider) getUpdatedEndpoints(ctx context.Context) map[uint64]* } } + cp.logger.Debug("found new cluster ID", slog.Any("endpoints", maxTempEndpoints)) + return maxTempEndpoints } -// checkAndSetNodeConns checks if the node connections need to be updated and updates them if necessary. -func (cp *channelProvider) checkAndSetNodeConns(newNodeEndpoints map[uint64]*protos.ServerEndpointList) { +// Checks if the node connections need to be updated and updates them if necessary. +func (cp *channelProvider) checkAndSetNodeConns( + ctx context.Context, + newNodeEndpoints map[uint64]*protos.ServerEndpointList, +) { wg := sync.WaitGroup{} - // Find which nodes have a different endpoint list and update their channel for node, newEndpoints := range newNodeEndpoints { wg.Add(1) @@ -428,13 +512,20 @@ func (cp *channelProvider) checkAndSetNodeConns(newNodeEndpoints map[uint64]*pro } // Either this is a new node or its endpoints have changed - err = cp.updateNodeConns(node, newEndpoints) + err = cp.updateNodeConns(ctx, node, newEndpoints) if err != nil { logger.Error("failed to create new channel", slog.Any("error", err)) } + } else { + cp.logger.Debug("endpoints for node unchanged") } } else { - cp.logger.Debug("endpoints for node unchanged", slog.Uint64("node", node)) + logger.Debug("new node found, creating new channel") + + err := cp.updateNodeConns(ctx, node, newEndpoints) + if err != nil { + logger.Error("failed to create new channel", slog.Any("error", err)) + } } }(node, newEndpoints) } @@ -470,9 +561,9 @@ func (cp *channelProvider) updateClusterChannels(ctx context.Context) { return } - cp.logger.Debug("new endpoints found, updating channels", slog.Any("endpoints", updatedEndpoints)) + cp.logger.Debug("new cluster id found, updating channels") - cp.checkAndSetNodeConns(updatedEndpoints) + cp.checkAndSetNodeConns(ctx, updatedEndpoints) cp.removeDownNodes(updatedEndpoints) } @@ -549,9 +640,9 @@ func endpointToHostPort(endpoint *protos.ServerEndpoint) *HostPort { return NewHostPort(endpoint.Address, int(endpoint.Port)) } -// createChannelFromEndpoints creates a gRPC client connection from the first +// createConnFromEndpoints creates a gRPC client connection from the first // successful endpoint in endpoints. 
-func (cp *channelProvider) createChannelFromEndpoints( +func (cp *channelProvider) createConnFromEndpoints( endpoints *protos.ServerEndpointList, ) (*grpc.ClientConn, error) { for _, endpoint := range endpoints.Endpoints { @@ -575,11 +666,11 @@ func (cp *channelProvider) createChannel(hostPort *HostPort) (*grpc.ClientConn, opts := []grpc.DialOption{} if cp.tlsConfig == nil { - cp.logger.Debug("using insecure connection to host", slog.String("host", hostPort.String())) + cp.logger.Info("using insecure connection to host", slog.String("host", hostPort.String())) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { - cp.logger.Debug("using secure tls connection to host", slog.String("host", hostPort.String())) + cp.logger.Info("using secure tls connection to host", slog.String("host", hostPort.String())) opts = append(opts, grpc.WithTransportCredentials(grpcCreds.NewTLS(cp.tlsConfig))) } diff --git a/utils.go b/utils.go index daef20c..3b6f9a4 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,11 @@ package avs -import "github.com/aerospike/avs-client-go/protos" +import ( + "strconv" + "strings" + + "github.com/aerospike/avs-client-go/protos" +) func createUserPassCredential(username, password string) *protos.Credentials { return &protos.Credentials{ @@ -12,3 +17,96 @@ func createUserPassCredential(username, password string) *protos.Credentials { }, } } + +var minimumSupportedAVSVersion = newVersion("0.9.0") + +type version []any + +func newVersion(s string) version { + split := strings.Split(s, ".") + v := version{} + + for _, token := range split { + if intVal, err := strconv.ParseUint(token, 10, 64); err == nil { + v = append(v, intVal) + } else { + v = append(v, token) + } + } + + return v +} + +func (v version) String() string { + s := "" + + for i, token := range v { + if i > 0 { + s += "." 
+		}
+
+		switch val := token.(type) {
+		case uint64:
+			s += strconv.FormatUint(val, 10)
+		case string:
+			s += val
+		}
+	}
+
+	return s
+}
+
+func (v version) lt(b version) bool {
+	strFunc := func(x, y string) bool {
+		return x < y
+	}
+	intFunc := func(x, y int) bool {
+		return x < y
+	}
+
+	return compare(v, b, strFunc, intFunc)
+}
+
+func (v version) gt(b version) bool {
+	strFunc := func(x, y string) bool {
+		return x > y
+	}
+	intFunc := func(x, y int) bool {
+		return x > y
+	}
+
+	return compare(v, b, strFunc, intFunc)
+}
+
+type compareFunc[T comparable] func(x, y T) bool
+
+func compare(a, b version, strFunc compareFunc[string], intFunc compareFunc[int]) bool {
+	sharedLen := min(len(a), len(b))
+
+	for i := 0; i < sharedLen; i++ {
+		switch aVal := a[i].(type) {
+		case uint64:
+			switch bVal := b[i].(type) {
+			case uint64:
+				if aVal != bVal {
+					return intFunc(int(aVal), int(bVal))
+				}
+			default:
+				return false
+			}
+		case string:
+			switch bVal := b[i].(type) {
+			case string:
+				if aVal != bVal {
+					return strFunc(aVal, bVal)
+				}
+			default:
+				return false
+			}
+		default:
+			return false
+		}
+	}
+
+	return intFunc(len(a), len(b))
+}
diff --git a/utils_test.go b/utils_test.go
new file mode 100644
index 0000000..2fc6528
--- /dev/null
+++ b/utils_test.go
@@ -0,0 +1,77 @@
+package avs
+
+import "testing"
+
+func TestVersionLTGT(t *testing.T) {
+	testCases := []struct {
+		name   string
+		v1     version
+		v2     version
+		wantLT bool
+		wantGT bool
+	}{
+		{
+			name:   "v1 is less than v2",
+			v1:     newVersion("1.0.0"),
+			v2:     newVersion("2.0.0"),
+			wantLT: true,
+			wantGT: false,
+		},
+		{
+			name:   "v1 is greater than v2",
+			v1:     newVersion("2.0.0"),
+			v2:     newVersion("1.0.0"),
+			wantLT: false,
+			wantGT: true,
+		},
+		{
+			name:   "v1 is less than v2",
+			v1:     newVersion("1.0.0"),
+			v2:     newVersion("1.0.1"),
+			wantLT: true,
+			wantGT: false,
+		},
+		{
+			name:   "v1 is less than v2",
+			v1:     newVersion("1.0.0"),
+			v2:     newVersion("1.0.0.dev0"),
+			wantLT: true,
+			wantGT: false,
+		},
+		{
+			name:   "v1 is greater than v2",
+			v1:     newVersion("1.0.0.dev0"),
+			v2:     newVersion("1.0.0"),
+			wantLT: false,
+			wantGT: true,
+		},
+		{
+			name:   "v1 is less than v2",
+			v1:     newVersion("1.0.0.dev0"),
+			v2:     newVersion("1.0.0.dev1"),
+			wantLT: true,
+			wantGT: false,
+		},
+		{
+			name:   "v1 is equal to v2",
+			v1:     newVersion("1.0.0"),
+			v2:     newVersion("1.0.0"),
+			wantLT: false,
+			wantGT: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := tc.v1.lt(tc.v2)
+			if got != tc.wantLT {
+				t.Errorf("expected %v, got %v", tc.wantLT, got)
+			}
+
+			got = tc.v1.gt(tc.v2)
+			if got != tc.wantGT {
+				t.Errorf("expected %v, got %v", tc.wantGT, got)
+			}
+		})
+	}
+}
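
Usage note (not part of the patch): the sketch below shows how a caller might exercise the node-addressed admin APIs added in this diff -- NodeIDs, ConnectedNodeEndpoint, and About. It assumes an already-constructed *avs.AdminClient (built with NewAdminClient as before) and touches only fields that appear in the diff; the example package name and the printClusterInfo helper are illustrative, not part of the client library.

package avsexample

import (
	"context"
	"fmt"

	avs "github.com/aerospike/avs-client-go"
)

// printClusterInfo walks every node the client discovered during cluster
// tending and prints the endpoint it connected to along with the server
// version that node reports. NodeIDs returns an empty slice when the client
// runs in load-balancer mode, since tending (and the per-node connections it
// maintains) is disabled there.
func printClusterInfo(ctx context.Context, client *avs.AdminClient) error {
	for _, id := range client.NodeIDs(ctx) {
		endpoint, err := client.ConnectedNodeEndpoint(ctx, id)
		if err != nil {
			return err
		}

		about, err := client.About(ctx, id)
		if err != nil {
			return err
		}

		fmt.Printf("node %d at %s:%d runs AVS %s\n", id.Id, endpoint.Address, endpoint.Port, about.Version)
	}

	return nil
}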