Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PWX-33631: Hold round-robin lock only when needed/remove global lock #2342

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions cmd/osd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
// This document represents the API documentation of Openstorage; for the Go client please visit:
// https://github.com/libopenstorage/openstorage
//
// Schemes: http, https
// Host: localhost
// BasePath: /v1
// Version: 2.0.0
// License: APACHE2 https://opensource.org/licenses/Apache-2.0
// Contact: https://github.com/libopenstorage/openstorage
// Schemes: http, https
// Host: localhost
// BasePath: /v1
// Version: 2.0.0
// License: APACHE2 https://opensource.org/licenses/Apache-2.0
// Contact: https://github.com/libopenstorage/openstorage
//
// Consumes:
// - application/json
// Consumes:
// - application/json
//
// Produces:
// - application/json
// Produces:
// - application/json
//
// swagger:meta
package main
Expand All @@ -33,6 +33,7 @@ import (
"runtime"
"strconv"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
Expand All @@ -54,7 +55,9 @@ import (
"github.com/libopenstorage/openstorage/objectstore"
"github.com/libopenstorage/openstorage/pkg/auth"
"github.com/libopenstorage/openstorage/pkg/auth/systemtoken"
"github.com/libopenstorage/openstorage/pkg/loadbalancer"
"github.com/libopenstorage/openstorage/pkg/role"
"github.com/libopenstorage/openstorage/pkg/sched"
policy "github.com/libopenstorage/openstorage/pkg/storagepolicy"
"github.com/libopenstorage/openstorage/schedpolicy"
"github.com/libopenstorage/openstorage/volume"
Expand Down Expand Up @@ -488,14 +491,23 @@ func start(c *cli.Context) error {
return fmt.Errorf("Unable to find cluster instance: %v", err)
}
sdkPort := c.String("sdkport")
// Initialize the scheduler
sched.Init(time.Second)

roundRobinBalancer, err := loadbalancer.NewRoundRobinBalancer(cm, sdkPort)
if err != nil {
return fmt.Errorf("Unable to get round robin balancer: %v", err)
}

csiServer, err := csi.NewOsdCsiServer(&csi.OsdCsiServerConfig{
Net: "unix",
Address: csisock,
DriverName: d,
Cluster: cm,
SdkUds: sdksocket,
SdkPort: sdkPort,
CsiDriverName: c.String("csidrivername"),
Net: "unix",
Address: csisock,
DriverName: d,
Cluster: cm,
SdkUds: sdksocket,
SdkPort: sdkPort,
CsiDriverName: c.String("csidrivername"),
RoundRobinBalancer: roundRobinBalancer,
})
if err != nil {
return fmt.Errorf("Failed to create CSI server for driver %s: %v", d, err)
Expand Down
186 changes: 22 additions & 164 deletions csi/csi.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@ limitations under the License.
package csi

import (
"errors"
"fmt"
"sort"
"sync"
"time"

csi "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/libopenstorage/openstorage/api"
"github.com/libopenstorage/openstorage/csi/sched/k8s"
"github.com/libopenstorage/openstorage/pkg/correlation"
"github.com/libopenstorage/openstorage/pkg/loadbalancer"
"github.com/libopenstorage/openstorage/pkg/options"
"github.com/portworx/kvdb"
"github.com/sirupsen/logrus"
Expand All @@ -49,25 +47,21 @@ var (
clogger *logrus.Logger
)

const (
// connCleanupInterval is the period of the ticker that drives the
// cleanupConnections loop.
connCleanupInterval = 15 * time.Minute
// connIdleConnLength is how long a cached SDK connection may go unused,
// measured from its LastUsage, before cleanupConnections closes it and
// removes it from connMap.
connIdleConnLength = 30 * time.Minute
)

// init sets the package-level logger clogger to a package logger created
// for the CSI driver correlation component.
func init() {
clogger = correlation.NewPackageLogger(correlation.ComponentCSIDriver)
}

// OsdCsiServerConfig provides the configuration to
// the gRPC CSI server created by NewOsdCsiServer()
type OsdCsiServerConfig struct {
Net string
Address string
DriverName string
Cluster cluster.Cluster
SdkUds string
SdkPort string
SchedulerName string
Net string
Address string
DriverName string
Cluster cluster.Cluster
RoundRobinBalancer loadbalancer.Balancer
SdkUds string
SdkPort string
SchedulerName string

// Name to be reported back to the CO. If not provided,
// the name will be in the format of <driver>.openstorage.org
Expand All @@ -78,12 +72,6 @@ type OsdCsiServerConfig struct {
EnableInlineVolumes bool
}

// TimedSDKConn represents a gRPC connection and the last time it was used
type TimedSDKConn struct {
// Conn is the cached gRPC client connection to a node's SDK endpoint.
Conn *grpc.ClientConn
// LastUsage is set each time getRemoteConn hands this connection out;
// cleanupConnections compares it against connIdleConnLength to decide
// whether the connection has been idle long enough to be closed.
LastUsage time.Time
}

// OsdCsiServer is a OSD CSI compliant server which
// proxies CSI requests for a single specific driver
type OsdCsiServer struct {
Expand All @@ -92,18 +80,16 @@ type OsdCsiServer struct {
csi.IdentityServer

*grpcserver.GrpcServer
specHandler spec.SpecHandler
driver volume.VolumeDriver
cluster cluster.Cluster
sdkUds string
sdkPort string
conn *grpc.ClientConn
connMap map[string]*TimedSDKConn
nextCreateNodeNumber int
mu sync.Mutex
csiDriverName string
allowInlineVolumes bool
stopCleanupCh chan bool
specHandler spec.SpecHandler
driver volume.VolumeDriver
cluster cluster.Cluster
sdkUds string
sdkPort string
conn *grpc.ClientConn
mu sync.Mutex
csiDriverName string
allowInlineVolumes bool
roundRobinBalancer loadbalancer.Balancer
}

// NewOsdCsiServer creates a gRPC CSI compliant server on the
Expand Down Expand Up @@ -169,6 +155,7 @@ func NewOsdCsiServer(config *OsdCsiServerConfig) (grpcserver.Server, error) {
sdkPort: config.SdkPort,
csiDriverName: config.CsiDriverName,
allowInlineVolumes: config.EnableInlineVolumes,
roundRobinBalancer: config.RoundRobinBalancer,
}, nil
}

Expand All @@ -193,59 +180,8 @@ func (s *OsdCsiServer) getConn() (*grpc.ClientConn, error) {
}

func (s *OsdCsiServer) getRemoteConn(ctx context.Context) (*grpc.ClientConn, error) {
s.mu.Lock()
defer s.mu.Unlock()

// Get all nodes and sort them
nodesResp, err := s.cluster.Enumerate()
if err != nil {
return nil, err
}
if len(nodesResp.Nodes) < 1 {
return nil, errors.New("cluster nodes for remote connection not found")
}
sort.Slice(nodesResp.Nodes, func(i, j int) bool {
return nodesResp.Nodes[i].Id < nodesResp.Nodes[j].Id
})

// Clean up connections for missing nodes
s.cleanupMissingNodeConnections(ctx, nodesResp.Nodes)

// Get target node info and set next round robbin node.
// nextNode is always lastNode + 1 mod (numOfNodes), to loop back to zero
var targetNodeNumber int
if s.nextCreateNodeNumber != 0 {
targetNodeNumber = s.nextCreateNodeNumber
}
targetNodeEndpoint := nodesResp.Nodes[targetNodeNumber].MgmtIp
s.nextCreateNodeNumber = (targetNodeNumber + 1) % len(nodesResp.Nodes)

// Get conn for this node, otherwise create new conn
if len(s.connMap) == 0 {
s.connMap = make(map[string]*TimedSDKConn)
}
if s.connMap[targetNodeEndpoint] == nil {
var err error
clogger.WithContext(ctx).Infof("Round-robin connecting to node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, s.sdkPort)
remoteConn, err := grpcserver.ConnectWithTimeout(
fmt.Sprintf("%s:%s", targetNodeEndpoint, s.sdkPort),
[]grpc.DialOption{
grpc.WithInsecure(),
grpc.WithUnaryInterceptor(correlation.ContextUnaryClientInterceptor),
}, 10*time.Second)
if err != nil {
return nil, err
}

s.connMap[targetNodeEndpoint] = &TimedSDKConn{
Conn: remoteConn,
}
}

// Keep track of when this conn was last accessed
clogger.WithContext(ctx).Infof("Using remote connection to SDK node %v - %s:%s", targetNodeNumber, targetNodeEndpoint, s.sdkPort)
s.connMap[targetNodeEndpoint].LastUsage = time.Now()
return s.connMap[targetNodeEndpoint].Conn, nil
remoteConn, _, err := s.roundRobinBalancer.GetRemoteNodeConnection(ctx)
return remoteConn, err
}

// driverGetVolume returns a volume for a given ID. This function skips
Expand Down Expand Up @@ -315,8 +251,6 @@ func (s *OsdCsiServer) addEncryptionInfoToLabels(labels, csiSecrets map[string]s
// It will return an error if the server is already running.
func (s *OsdCsiServer) Start() error {
return s.GrpcServer.Start(func(grpcServer *grpc.Server) {
go s.cleanupConnections()

csi.RegisterIdentityServer(grpcServer, s)
csi.RegisterControllerServer(grpcServer, s)
csi.RegisterNodeServer(grpcServer, s)
Expand All @@ -325,85 +259,9 @@ func (s *OsdCsiServer) Start() error {

// Stop is used to stop the server. If the cleanup channel has been created,
// it is closed first so that the cleanupConnections loop returns; the
// underlying gRPC server is then stopped.
func (s *OsdCsiServer) Stop() {
if s.stopCleanupCh != nil {
close(s.stopCleanupCh)
}
s.GrpcServer.Stop()
}

// cleanupConnections runs until stopCleanupCh is closed. Every
// connCleanupInterval it closes cached SDK gRPC connections that have been
// idle for longer than connIdleConnLength, as well as connections to nodes
// that are no longer part of the cluster.
//
// NOTE(review): stopCleanupCh is created here, on the goroutine launched by
// Start, while Stop reads it without synchronization. Creating the channel
// before the goroutine is started would avoid that race; it is left as-is
// here because the fix belongs in the constructor/Start.
func (s *OsdCsiServer) cleanupConnections() {
	s.stopCleanupCh = make(chan bool)
	ticker := time.NewTicker(connCleanupInterval)
	defer ticker.Stop()

	// Check every so often and delete/close connections
	for {
		select {
		case <-s.stopCleanupCh:
			return

		case <-ticker.C:
			ctx := correlation.WithCorrelationContext(context.Background(), correlation.ComponentCSIDriver)
			s.cleanupConnectionsOnce(ctx)
		}
	}
}

// cleanupConnectionsOnce performs a single cleanup pass over connMap while
// holding s.mu: expired connections are closed and removed first, then
// connections to nodes missing from the cluster enumeration.
func (s *OsdCsiServer) cleanupConnectionsOnce(ctx context.Context) {
	s.mu.Lock()
	defer s.mu.Unlock()
	clogger.Tracef("Cleaning up open gRPC connections for CSI distributed provisioning")

	// Clean all expired connections
	numConnsClosed := 0
	for ip, timedConn := range s.connMap {
		expiryTime := timedConn.LastUsage.Add(connIdleConnLength)
		if !expiryTime.Before(time.Now()) {
			continue
		}

		// Connection has gone unused for longer than connIdleConnLength.
		// Close it and remove it from connMap.
		clogger.Infof("SDK gRPC connection to %s is has expired after %v minutes of no usage. Closing this connection", ip, connIdleConnLength.Minutes())
		if err := timedConn.Conn.Close(); err != nil {
			// Report the close error itself, not the connection object.
			clogger.Errorf("failed to close connection to %s: %v", ip, err)
		}
		delete(s.connMap, ip)
		numConnsClosed++
	}

	// Get all nodes and cleanup conns for missing/deprovisioned nodes
	nodesResp, err := s.cluster.Enumerate()
	if err != nil {
		clogger.Errorf("failed to get all nodes for connection cleanup: %v", err)
		return
	}
	if len(nodesResp.Nodes) < 1 {
		// err is nil on this path, so there is nothing useful to format.
		clogger.Error("no nodes available to cleanup")
		return
	}
	s.cleanupMissingNodeConnections(ctx, nodesResp.Nodes)

	if numConnsClosed > 0 {
		clogger.Infof("Cleaned up %v connections for CSI distributed provisioning. %v connections remaining", numConnsClosed, len(s.connMap))
	}
}

// cleanupMissingNodeConnections closes and forgets every cached connection
// whose connMap key does not match the MgmtIp of any node in nodes. A failure
// to close a connection is logged and the entry is removed regardless.
//
// The method reads and mutates s.connMap without taking s.mu itself; the
// callers visible in this file invoke it with s.mu already held.
func (s *OsdCsiServer) cleanupMissingNodeConnections(ctx context.Context, nodes []*api.Node) {
	// Index the management IPs of the nodes that still exist.
	liveIPs := make(map[string]bool)
	for i := range nodes {
		liveIPs[nodes[i].MgmtIp] = true
	}

	for endpoint, cached := range s.connMap {
		if liveIPs[endpoint] {
			continue
		}

		// The endpoint no longer belongs to a known node: close and drop it.
		closeErr := cached.Conn.Close()
		if closeErr != nil {
			clogger.WithContext(ctx).Errorf("failed to close conn to %s: %v", endpoint, closeErr)
		}
		delete(s.connMap, endpoint)
	}
}

// adjustFinalErrors adjusts certain gRPC status to make CSI callers
// (csi-provisioner, kubelet, etc) retry instead of being marked as a failure.
// See https://github.com/kubernetes/kubernetes/blob/64ed9145452d2d1d324d2437566f1ea1ce76f226/pkg/volume/csi/csi_client.go#L718-L724
Expand Down
2 changes: 2 additions & 0 deletions csi/csi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/libopenstorage/openstorage/config"
"github.com/libopenstorage/openstorage/pkg/auth"
"github.com/libopenstorage/openstorage/pkg/grpcserver"
"github.com/libopenstorage/openstorage/pkg/loadbalancer"
"github.com/libopenstorage/openstorage/pkg/options"
"github.com/libopenstorage/openstorage/pkg/role"
"github.com/libopenstorage/openstorage/pkg/storagepolicy"
Expand Down Expand Up @@ -152,6 +153,7 @@ func newTestServerWithConfig(t *testing.T, config *OsdCsiServerConfig) *testServ
if config.Cluster == nil {
config.Cluster = tester.c
}
config.RoundRobinBalancer = loadbalancer.NewNullBalancer()

setupMockDriver(tester, t)

Expand Down
9 changes: 5 additions & 4 deletions pkg/correlation/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ const (
// ContextOriginKey represents the key for the correlation origin
ContextOriginKey = "correlation-context-origin"

ComponentUnknown = Component("unknown")
ComponentCSIDriver = Component("csi-driver")
ComponentSDK = Component("sdk-server")
ComponentAuth = Component("openstorage/pkg/auth")
ComponentUnknown = Component("unknown")
ComponentCSIDriver = Component("csi-driver")
ComponentSDK = Component("sdk-server")
ComponentRoundRobinBalancer = Component("round-robin-balancer")
ComponentAuth = Component("openstorage/pkg/auth")
)

// RequestContext represents the context for a given a request.
Expand Down
30 changes: 30 additions & 0 deletions pkg/loadbalancer/balancer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package loadbalancer

import (
"context"
"fmt"

"google.golang.org/grpc"
)

// Balancer provides APIs to load balance gRPC connections across the nodes
// of a given cluster.
type Balancer interface {
// GetRemoteNodeConnection returns a gRPC client connection to a node
// in the cluster using a round-robin algorithm.
//
// The API will return an error if it fails to create a connection to a
// node in the cluster. The boolean return argument is set to false if the
// connection is created to the local node; presumably it is true when the
// connection targets a remote node (confirm against the concrete
// implementation).
GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error)
}

// nullBalancer is a Balancer that never produces a connection.
type nullBalancer struct{}

// NewNullBalancer returns a no-op Balancer whose GetRemoteNodeConnection
// always fails. The CSI test server in this change set uses it in place of a
// real balancer.
func NewNullBalancer() Balancer {
	return new(nullBalancer)
}

// GetRemoteNodeConnection always returns a nil connection, false, and an
// error stating that remote connections are not supported.
func (*nullBalancer) GetRemoteNodeConnection(ctx context.Context) (*grpc.ClientConn, bool, error) {
	var noConn *grpc.ClientConn
	return noConn, false, fmt.Errorf("remote connections not supported")
}
Loading