diff --git a/cmd/osd/main.go b/cmd/osd/main.go index cf070b46f..d842a9eaf 100644 --- a/cmd/osd/main.go +++ b/cmd/osd/main.go @@ -7,18 +7,18 @@ // This document represents the API documentaton of Openstorage, for the GO client please visit: // https://github.com/libopenstorage/openstorage // -// Schemes: http, https -// Host: localhost -// BasePath: /v1 -// Version: 2.0.0 -// License: APACHE2 https://opensource.org/licenses/Apache-2.0 -// Contact: https://github.com/libopenstorage/openstorage +// Schemes: http, https +// Host: localhost +// BasePath: /v1 +// Version: 2.0.0 +// License: APACHE2 https://opensource.org/licenses/Apache-2.0 +// Contact: https://github.com/libopenstorage/openstorage // -// Consumes: -// - application/json +// Consumes: +// - application/json // -// Produces: -// - application/json +// Produces: +// - application/json // // swagger:meta package main @@ -33,6 +33,7 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -54,7 +55,9 @@ import ( "github.com/libopenstorage/openstorage/objectstore" "github.com/libopenstorage/openstorage/pkg/auth" "github.com/libopenstorage/openstorage/pkg/auth/systemtoken" + "github.com/libopenstorage/openstorage/pkg/loadbalancer" "github.com/libopenstorage/openstorage/pkg/role" + "github.com/libopenstorage/openstorage/pkg/sched" policy "github.com/libopenstorage/openstorage/pkg/storagepolicy" "github.com/libopenstorage/openstorage/schedpolicy" "github.com/libopenstorage/openstorage/volume" @@ -488,14 +491,23 @@ func start(c *cli.Context) error { return fmt.Errorf("Unable to find cluster instance: %v", err) } sdkPort := c.String("sdkport") + // Initialize the scheduler + sched.Init(time.Second) + + roundRobinBalancer, err := loadbalancer.NewRoundRobinBalancer(cm, sdkPort) + if err != nil { + return fmt.Errorf("Unable to get round robin balancer: %v", err) + } + csiServer, err := csi.NewOsdCsiServer(&csi.OsdCsiServerConfig{ - Net: "unix", 
- Address: csisock, - DriverName: d, - Cluster: cm, - SdkUds: sdksocket, - SdkPort: sdkPort, - CsiDriverName: c.String("csidrivername"), + Net: "unix", + Address: csisock, + DriverName: d, + Cluster: cm, + SdkUds: sdksocket, + SdkPort: sdkPort, + CsiDriverName: c.String("csidrivername"), + RoundRobinBalancer: roundRobinBalancer, }) if err != nil { return fmt.Errorf("Failed to create CSI server for driver %s: %v", d, err) diff --git a/csi/csi.go b/csi/csi.go index 6dac836aa..8c2a47170 100644 --- a/csi/csi.go +++ b/csi/csi.go @@ -17,16 +17,14 @@ limitations under the License. package csi import ( - "errors" "fmt" - "sort" "sync" - "time" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/libopenstorage/openstorage/api" "github.com/libopenstorage/openstorage/csi/sched/k8s" "github.com/libopenstorage/openstorage/pkg/correlation" + "github.com/libopenstorage/openstorage/pkg/loadbalancer" "github.com/libopenstorage/openstorage/pkg/options" "github.com/portworx/kvdb" "github.com/sirupsen/logrus" @@ -49,11 +47,6 @@ var ( clogger *logrus.Logger ) -const ( - connCleanupInterval = 15 * time.Minute - connIdleConnLength = 30 * time.Minute -) - func init() { clogger = correlation.NewPackageLogger(correlation.ComponentCSIDriver) } @@ -61,13 +54,14 @@ func init() { // OsdCsiServerConfig provides the configuration to the // the gRPC CSI server created by NewOsdCsiServer() type OsdCsiServerConfig struct { - Net string - Address string - DriverName string - Cluster cluster.Cluster - SdkUds string - SdkPort string - SchedulerName string + Net string + Address string + DriverName string + Cluster cluster.Cluster + RoundRobinBalancer loadbalancer.Balancer + SdkUds string + SdkPort string + SchedulerName string // Name to be reported back to the CO. 
If not provided, // the name will be in the format of .openstorage.org @@ -78,12 +72,6 @@ type OsdCsiServerConfig struct { EnableInlineVolumes bool } -// TimedSDKConn represents a gRPC connection and the last time it was used -type TimedSDKConn struct { - Conn *grpc.ClientConn - LastUsage time.Time -} - // OsdCsiServer is a OSD CSI compliant server which // proxies CSI requests for a single specific driver type OsdCsiServer struct { @@ -92,18 +80,16 @@ type OsdCsiServer struct { csi.IdentityServer *grpcserver.GrpcServer - specHandler spec.SpecHandler - driver volume.VolumeDriver - cluster cluster.Cluster - sdkUds string - sdkPort string - conn *grpc.ClientConn - connMap map[string]*TimedSDKConn - nextCreateNodeNumber int - mu sync.Mutex - csiDriverName string - allowInlineVolumes bool - stopCleanupCh chan bool + specHandler spec.SpecHandler + driver volume.VolumeDriver + cluster cluster.Cluster + sdkUds string + sdkPort string + conn *grpc.ClientConn + mu sync.Mutex + csiDriverName string + allowInlineVolumes bool + roundRobinBalancer loadbalancer.Balancer } // NewOsdCsiServer creates a gRPC CSI complient server on the @@ -169,6 +155,7 @@ func NewOsdCsiServer(config *OsdCsiServerConfig) (grpcserver.Server, error) { sdkPort: config.SdkPort, csiDriverName: config.CsiDriverName, allowInlineVolumes: config.EnableInlineVolumes, + roundRobinBalancer: config.RoundRobinBalancer, }, nil } @@ -193,59 +180,8 @@ func (s *OsdCsiServer) getConn() (*grpc.ClientConn, error) { } func (s *OsdCsiServer) getRemoteConn(ctx context.Context) (*grpc.ClientConn, error) { - s.mu.Lock() - defer s.mu.Unlock() - - // Get all nodes and sort them - nodesResp, err := s.cluster.Enumerate() - if err != nil { - return nil, err - } - if len(nodesResp.Nodes) < 1 { - return nil, errors.New("cluster nodes for remote connection not found") - } - sort.Slice(nodesResp.Nodes, func(i, j int) bool { - return nodesResp.Nodes[i].Id < nodesResp.Nodes[j].Id - }) - - // Clean up connections for missing nodes - 
s.cleanupMissingNodeConnections(ctx, nodesResp.Nodes) - - // Get target node info and set next round robbin node. - // nextNode is always lastNode + 1 mod (numOfNodes), to loop back to zero - var targetNodeNumber int - if s.nextCreateNodeNumber != 0 { - targetNodeNumber = s.nextCreateNodeNumber - } - targetNodeEndpoint := nodesResp.Nodes[targetNodeNumber].MgmtIp - s.nextCreateNodeNumber = (targetNodeNumber + 1) % len(nodesResp.Nodes) - - // Get conn for this node, otherwise create new conn - if len(s.connMap) == 0 { - s.connMap = make(map[string]*TimedSDKConn) - } - if s.connMap[targetNodeEndpoint] == nil { - var err error - clogger.WithContext(ctx).Infof("Round-robin connecting to node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, s.sdkPort) - remoteConn, err := grpcserver.ConnectWithTimeout( - fmt.Sprintf("%s:%s", targetNodeEndpoint, s.sdkPort), - []grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor), - }, 10*time.Second) - if err != nil { - return nil, err - } - - s.connMap[targetNodeEndpoint] = &TimedSDKConn{ - Conn: remoteConn, - } - } - - // Keep track of when this conn was last accessed - clogger.WithContext(ctx).Infof("Using remote connection to SDK node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, s.sdkPort) - s.connMap[targetNodeEndpoint].LastUsage = time.Now() - return s.connMap[targetNodeEndpoint].Conn, nil + remoteConn, _, err := s.roundRobinBalancer.GetRemoteNodeConnection(ctx) + return remoteConn, err } // driverGetVolume returns a volume for a given ID. This function skips @@ -315,8 +251,6 @@ func (s *OsdCsiServer) addEncryptionInfoToLabels(labels, csiSecrets map[string]s // It will return an error if the server is already running. 
func (s *OsdCsiServer) Start() error { return s.GrpcServer.Start(func(grpcServer *grpc.Server) { - go s.cleanupConnections() - csi.RegisterIdentityServer(grpcServer, s) csi.RegisterControllerServer(grpcServer, s) csi.RegisterNodeServer(grpcServer, s) @@ -325,85 +259,9 @@ func (s *OsdCsiServer) Start() error { // Start is used to stop the server. func (s *OsdCsiServer) Stop() { - if s.stopCleanupCh != nil { - close(s.stopCleanupCh) - } s.GrpcServer.Stop() } -func (s *OsdCsiServer) cleanupConnections() { - s.stopCleanupCh = make(chan bool) - ticker := time.NewTicker(connCleanupInterval) - - // Check every so often and delete/close connections - for { - select { - case <-s.stopCleanupCh: - ticker.Stop() - return - - case _ = <-ticker.C: - ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentCSIDriver) - - // Anonymous function for using defer to unlock mutex - func() { - s.mu.Lock() - defer s.mu.Unlock() - clogger.Tracef("Cleaning up open gRPC connections for CSI distributed provisioning") - - // Clean all expired connections - numConnsClosed := 0 - for ip, timedConn := range s.connMap { - expiryTime := timedConn.LastUsage.Add(connIdleConnLength) - - // Connection has expired after 1hr of no usage. - // Close connection and remove from connMap - if expiryTime.Before(time.Now()) { - clogger.Infof("SDK gRPC connection to %s is has expired after %v minutes of no usage. 
Closing this connection", ip, connIdleConnLength.Minutes()) - if err := timedConn.Conn.Close(); err != nil { - clogger.Errorf("failed to close connection to %s: %v", ip, timedConn.Conn) - } - delete(s.connMap, ip) - numConnsClosed++ - } - } - - // Get all nodes and cleanup conns for missing/deprovisioned nodes - nodesResp, err := s.cluster.Enumerate() - if err != nil { - clogger.Errorf("failed to get all nodes for connection cleanup: %v", err) - return - } - if len(nodesResp.Nodes) < 1 { - clogger.Errorf("no nodes available to cleanup: %v", err) - return - } - s.cleanupMissingNodeConnections(ctx, nodesResp.Nodes) - - if numConnsClosed > 0 { - clogger.Infof("Cleaned up %v connections for CSI distributed provisioning. %v connections remaining", numConnsClosed, len(s.connMap)) - } - }() - } - } -} - -func (s *OsdCsiServer) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) { - nodesMap := make(map[string]bool) - for _, node := range nodes { - nodesMap[node.MgmtIp] = true - } - for ip, timedConn := range s.connMap { - if ok := nodesMap[ip]; !ok { - // If key in connmap is not in current nodes, close and remove it - if err := timedConn.Conn.Close(); err != nil { - clogger.WithContext(ctx).Errorf("failed to close conn to %s: %v", ip, err) - } - delete(s.connMap, ip) - } - } -} - // adjustFinalErrors adjusts certain gRPC status to make CSI callers // (csi-provisioner, kubelet, etc) retry instead of being marked as a failure. 
// See https://github.com/kubernetes/kubernetes/blob/64ed9145452d2d1d324d2437566f1ea1ce76f226/pkg/volume/csi/csi_client.go#L718-L724 diff --git a/csi/csi_test.go b/csi/csi_test.go index fe366db56..e0ecb8563 100644 --- a/csi/csi_test.go +++ b/csi/csi_test.go @@ -37,6 +37,7 @@ import ( "github.com/libopenstorage/openstorage/config" "github.com/libopenstorage/openstorage/pkg/auth" "github.com/libopenstorage/openstorage/pkg/grpcserver" + "github.com/libopenstorage/openstorage/pkg/loadbalancer" "github.com/libopenstorage/openstorage/pkg/options" "github.com/libopenstorage/openstorage/pkg/role" "github.com/libopenstorage/openstorage/pkg/storagepolicy" @@ -152,6 +153,7 @@ func newTestServerWithConfig(t *testing.T, config *OsdCsiServerConfig) *testServ if config.Cluster == nil { config.Cluster = tester.c } + config.RoundRobinBalancer = loadbalancer.NewNullBalancer() setupMockDriver(tester, t) diff --git a/pkg/correlation/context.go b/pkg/correlation/context.go index 04e94825e..c3fee7b95 100644 --- a/pkg/correlation/context.go +++ b/pkg/correlation/context.go @@ -41,10 +41,11 @@ const ( // ContextOriginKey represents the key for the correlation origin ContextOriginKey = "correlation-context-origin" - ComponentUnknown = Component("unknown") - ComponentCSIDriver = Component("csi-driver") - ComponentSDK = Component("sdk-server") - ComponentAuth = Component("openstorage/pkg/auth") + ComponentUnknown = Component("unknown") + ComponentCSIDriver = Component("csi-driver") + ComponentSDK = Component("sdk-server") + ComponentRoundRobinBalancer = Component("round-robin-balancer") + ComponentAuth = Component("openstorage/pkg/auth") ) // RequestContext represents the context for a given a request. 
diff --git a/pkg/loadbalancer/balancer.go b/pkg/loadbalancer/balancer.go
new file mode 100644
index 000000000..8e2b2f464
--- /dev/null
+++ b/pkg/loadbalancer/balancer.go
@@ -0,0 +1,33 @@
+package loadbalancer
+
+import (
+	"context"
+	"fmt"
+
+	"google.golang.org/grpc"
+)
+
+// Balancer provides APIs to load balance a gRPC connection over a given
+// cluster.
+type Balancer interface {
+	// GetRemoteNodeConnection returns a gRPC client connection to a node
+	// in the cluster using a round-robin algorithm. The API will return
+	// an error if it fails to create a connection to a node in the cluster.
+	// The boolean return argument is set to false if the connection is created
+	// to the local node.
+	GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error)
+}
+
+// nullBalancer is a Balancer that never provides a connection.
+type nullBalancer struct{}
+
+// NewNullBalancer returns the no-op implementation of the Balancer interface.
+func NewNullBalancer() Balancer {
+	return &nullBalancer{}
+}
+
+// GetRemoteNodeConnection always fails: the null balancer does not support
+// remote connections.
+func (n *nullBalancer) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) {
+	return nil, false, fmt.Errorf("remote connections not supported")
+}
diff --git a/pkg/loadbalancer/roundrobin.go b/pkg/loadbalancer/roundrobin.go
new file mode 100644
index 000000000..03885db60
--- /dev/null
+++ b/pkg/loadbalancer/roundrobin.go
@@ -0,0 +1,277 @@
+package loadbalancer
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sort"
+	"sync"
+	"time"
+
+	"github.com/libopenstorage/openstorage/api"
+	"github.com/libopenstorage/openstorage/cluster"
+	"github.com/libopenstorage/openstorage/pkg/correlation"
+	"github.com/libopenstorage/openstorage/pkg/grpcserver"
+	"github.com/libopenstorage/openstorage/pkg/sched"
+	"github.com/sirupsen/logrus"
+	"google.golang.org/grpc"
+)
+
+// TimedSDKConn represents a gRPC connection and the last time it was used
+type TimedSDKConn struct {
+	Conn      *grpc.ClientConn
+	LastUsage time.Time
+}
+
+// roundRobin is a Balancer that picks cluster nodes in round-robin order and
+// caches one gRPC connection per node endpoint.
+type roundRobin struct {
+	cluster cluster.Cluster
+	// connMap caches connections keyed by node management IP. Guarded by mu.
+	connMap map[string]*TimedSDKConn
+	// nextCreateNodeNumber is the index, into the ID-sorted node list, of the
+	// node the next request is directed to. Guarded by mu.
+	nextCreateNodeNumber int
+	mu                   sync.RWMutex
+	grpcServerPort       string
+}
+
+var (
+	rrlogger *logrus.Logger
+)
+
+const (
+	// connCleanupInterval is the period of the scheduled cleanup routine.
+	connCleanupInterval = 15 * time.Minute
+	// connIdleConnLength is how long a cached connection may go unused before
+	// the cleanup routine closes it.
+	connIdleConnLength = 30 * time.Minute
+)
+
+// NewRoundRobinBalancer returns an implementation of the Balancer interface
+// for getting a remote grpc client connection to one of the nodes in the
+// cluster. It fails if cluster is nil or if the sched package has not been
+// initialized, because it schedules a periodic routine that closes idle
+// connections and connections to nodes that are no longer in the cluster.
+func NewRoundRobinBalancer(
+	cluster cluster.Cluster,
+	grpcServerPort string,
+) (Balancer, error) {
+	if cluster == nil {
+		return nil, fmt.Errorf("cluster cannot be nil")
+	}
+	if sched.Instance() == nil {
+		return nil, fmt.Errorf("sched instance is not initialized")
+	}
+	rr := &roundRobin{
+		cluster:        cluster,
+		grpcServerPort: grpcServerPort,
+		connMap:        make(map[string]*TimedSDKConn),
+	}
+	if _, err := sched.Instance().Schedule(
+		func(interval sched.Interval) { rr.cleanupConnections() },
+		sched.Periodic(connCleanupInterval),
+		time.Now().Add(connCleanupInterval),
+		false,
+	); err != nil {
+		return nil, fmt.Errorf("failed to schedule round robin cleanup routine: %v", err)
+	}
+	return rr, nil
+}
+
+// GetRemoteNodeConnection returns a cached or newly dialed gRPC connection to
+// the next node in round-robin order. Nodes are ordered by ID so that every
+// call sees the same sequence. The boolean result is true when the chosen
+// node is not the local node.
+func (rr *roundRobin) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) {
+	// Get all nodes and sort them
+	clusterInfo, err := rr.cluster.Enumerate()
+	if err != nil {
+		return nil, false, err
+	}
+	if len(clusterInfo.Nodes) < 1 {
+		return nil, false, errors.New("cluster nodes for remote connection not found")
+	}
+	sort.Slice(clusterInfo.Nodes, func(i, j int) bool {
+		return clusterInfo.Nodes[i].Id < clusterInfo.Nodes[j].Id
+	})
+
+	// Get target node info and set next round robin node.
+	// nextNode is always lastNode + 1 mod (numOfNodes), to loop back to zero
+	targetNodeEndpoint, isRemoteConn := rr.getTargetAndIncrement(&clusterInfo)
+
+	// Get conn for this node, otherwise create new conn
+	timedSDKConn, ok := rr.getNodeConnection(targetNodeEndpoint)
+	if !ok {
+		rrlogger.WithContext(ctx).Infof("Round-robin connecting to node %s:%s", targetNodeEndpoint, rr.grpcServerPort)
+		remoteConn, err := grpcserver.ConnectWithTimeout(
+			fmt.Sprintf("%s:%s", targetNodeEndpoint, rr.grpcServerPort),
+			[]grpc.DialOption{
+				grpc.WithInsecure(),
+				grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor),
+			}, 10*time.Second)
+		if err != nil {
+			return nil, isRemoteConn, err
+		}
+
+		// Dialing happens without the lock held, so a concurrent caller may
+		// have cached a connection for this endpoint in the meantime;
+		// setNodeConnection returns whichever connection ended up cached.
+		timedSDKConn = rr.setNodeConnection(targetNodeEndpoint, &TimedSDKConn{Conn: remoteConn})
+	}
+
+	rrlogger.WithContext(ctx).Infof("Using remote connection to SDK node %s:%s", targetNodeEndpoint, rr.grpcServerPort)
+	return timedSDKConn.Conn, isRemoteConn, nil
+}
+
+// getTargetAndIncrement returns the management IP of the node at the current
+// round-robin position and advances the position. The boolean result is true
+// when that node's ID differs from cluster.NodeId. cluster.Nodes must not be
+// empty.
+func (rr *roundRobin) getTargetAndIncrement(cluster *api.Cluster) (string, bool) {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+
+	// The stored position was computed against the node count of an earlier
+	// call. Wrap it so it stays in range if the cluster has shrunk since.
+	targetNodeNumber := rr.nextCreateNodeNumber % len(cluster.Nodes)
+	targetNode := cluster.Nodes[targetNodeNumber]
+	rr.nextCreateNodeNumber = (targetNodeNumber + 1) % len(cluster.Nodes)
+
+	// NodeId set on the cluster object is this node's ID. If the target
+	// node's ID does not match it, this will be a remote connection.
+	isRemoteConn := targetNode.Id != cluster.NodeId
+	return targetNode.MgmtIp, isRemoteConn
+}
+
+// getNodeConnection returns the cached connection for targetNodeEndpoint, if
+// any, and marks it as used now. The timestamp is refreshed under the lock
+// because the cleanup routine reads it concurrently.
+func (rr *roundRobin) getNodeConnection(targetNodeEndpoint string) (*TimedSDKConn, bool) {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+
+	timedSDKConn, ok := rr.connMap[targetNodeEndpoint]
+	if ok {
+		timedSDKConn.LastUsage = time.Now()
+	}
+	return timedSDKConn, ok
+}
+
+// setNodeConnection caches tsc for targetNodeEndpoint and returns the entry
+// that is cached afterwards. If another connection is already cached for the
+// endpoint, that one is kept and tsc's connection is closed instead, so that
+// a connection already handed out to callers is never dropped unclosed.
+func (rr *roundRobin) setNodeConnection(targetNodeEndpoint string, tsc *TimedSDKConn) *TimedSDKConn {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+
+	if rr.connMap == nil {
+		rr.connMap = make(map[string]*TimedSDKConn)
+	}
+	if existing, ok := rr.connMap[targetNodeEndpoint]; ok && existing != nil && existing != tsc {
+		if tsc.Conn != nil {
+			if err := tsc.Conn.Close(); err != nil {
+				rrlogger.Errorf("failed to close duplicate connection to %s: %v", targetNodeEndpoint, err)
+			}
+		}
+		existing.LastUsage = time.Now()
+		return existing
+	}
+	tsc.LastUsage = time.Now()
+	rr.connMap[targetNodeEndpoint] = tsc
+	return tsc
+}
+
+// connCount returns the number of cached connections.
+func (rr *roundRobin) connCount() int {
+	rr.mu.RLock()
+	defer rr.mu.RUnlock()
+	return len(rr.connMap)
+}
+
+// cleanupMissingNodeConnections closes and forgets every cached connection
+// whose endpoint is not the management IP of one of the given nodes. It
+// returns the number of connections removed.
+func (rr *roundRobin) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) int {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+
+	numConnsClosed := 0
+	nodesMap := make(map[string]bool, len(nodes))
+	for _, node := range nodes {
+		nodesMap[node.MgmtIp] = true
+	}
+	for ip, timedConn := range rr.connMap {
+		if nodesMap[ip] {
+			continue
+		}
+		// If key in connmap is not in current nodes, close and remove it
+		if err := timedConn.Conn.Close(); err != nil {
+			rrlogger.WithContext(ctx).Errorf("failed to close conn to %s: %v", ip, err)
+		}
+		delete(rr.connMap, ip)
+		numConnsClosed++
+	}
+
+	return numConnsClosed
+}
+
+// cleanupExpiredConnections closes and forgets every cached connection that
+// has not been used for connIdleConnLength. It returns the number of
+// connections removed.
+func (rr *roundRobin) cleanupExpiredConnections() int {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+
+	numConnsClosed := 0
+	now := time.Now()
+	for ip, timedConn := range rr.connMap {
+		if !timedConn.LastUsage.Add(connIdleConnLength).Before(now) {
+			continue
+		}
+		// Connection has been idle for longer than connIdleConnLength.
+		// Close connection and remove from connMap
+		rrlogger.Infof("SDK gRPC connection to %s is has expired after %v minutes of no usage. Closing this connection", ip, connIdleConnLength.Minutes())
+		if err := timedConn.Conn.Close(); err != nil {
+			rrlogger.Errorf("failed to close connection to %s: %v", ip, err)
+		}
+		delete(rr.connMap, ip)
+		numConnsClosed++
+	}
+
+	return numConnsClosed
+}
+
+// cleanupConnections is the periodic routine scheduled by
+// NewRoundRobinBalancer. It closes idle connections first and then
+// connections to nodes that the cluster no longer reports.
+func (rr *roundRobin) cleanupConnections() {
+	ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentRoundRobinBalancer)
+	rrlogger.Tracef("Cleaning up open gRPC connections created for round-robin balancing.")
+
+	// Clean all expired connections
+	expiredConnsClosed := rr.cleanupExpiredConnections()
+	if expiredConnsClosed > 0 {
+		rrlogger.Infof("Cleaned up %v expired node connections created for round-robin balancing. %v connections remaining", expiredConnsClosed, rr.connCount())
+	}
+
+	// Get all nodes and cleanup conns for missing/decommissioned nodes
+	nodesResp, err := rr.cluster.Enumerate()
+	if err != nil {
+		rrlogger.Errorf("failed to get all nodes for connection cleanup: %v", err)
+		return
+	}
+	if len(nodesResp.Nodes) < 1 {
+		rrlogger.Error("no nodes available to cleanup")
+		return
+	}
+	missingNodeConnsClosed := rr.cleanupMissingNodeConnections(ctx, nodesResp.Nodes)
+	if missingNodeConnsClosed > 0 {
+		rrlogger.Infof("Cleaned up %v connections for missing nodes created for round-robin balancing. %v connections remaining", missingNodeConnsClosed, rr.connCount())
+	}
+}
+
+func init() {
+	rrlogger = correlation.NewPackageLogger(correlation.ComponentRoundRobinBalancer)
+}