Skip to content

Commit

Permalink
Merge pull request #183 from mailgun/maxim/develop
Browse files Browse the repository at this point in the history
PIP 2712: Reuse open telemetry interceptors
  • Loading branch information
horkhe authored Nov 3, 2023
2 parents 483542c + 9144513 commit 11436b5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
50 changes: 33 additions & 17 deletions grpcconn/grpcconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type ConnMgr[T any] struct {
cancel context.CancelFunc
closeWG sync.WaitGroup
idPool *IDPool
log *logrus.Entry

connPoolMu sync.RWMutex
connPool []*Conn[T]
Expand All @@ -101,22 +102,31 @@ func (c *Conn[T]) ID() string { return c.id.String() }

// NewConnMgr instantiates a connection manager that maintains a gRPC
// connection pool.
func NewConnMgr[T any](cfg *Config, httpClient *http.Client, connFactory ConnFactory[T]) *ConnMgr[T] {
func NewConnMgr[T any](cfg *Config, httpClient *http.Client, connFactory ConnFactory[T], opts ...option[T]) *ConnMgr[T] {
// This ensures NumConnections is always at least 1
setter.SetDefault(&cfg.NumConnections, defaultNumConnections)
gc := ConnMgr[T]{
cm := ConnMgr[T]{
cfg: cfg,
getEndpointsURL: connFactory.GetServerListURL() + "?zone=" + cfg.Zone,
connFactory: connFactory,
httpClt: httpClient,
reconnectCh: make(chan struct{}, 1),
connPool: make([]*Conn[T], 0, defaultConnPoolCapacity),
idPool: NewIDPool(),
log: logrus.WithField("category", "grpcconn"),
}
cm.ctx, cm.cancel = context.WithCancel(context.Background())
cm.closeWG.Add(1)
go cm.run()
return &cm
}

type option[T any] func(cm *ConnMgr[T])

func WithLogger[T any](log *logrus.Entry) option[T] {
return func(cm *ConnMgr[T]) {
cm.log = log
}
gc.ctx, gc.cancel = context.WithCancel(context.Background())
gc.closeWG.Add(1)
go gc.run()
return &gc
}

func (cm *ConnMgr[T]) AcquireConn(ctx context.Context) (_ *Conn[T], err error) {
Expand Down Expand Up @@ -212,12 +222,12 @@ func (cm *ConnMgr[T]) ReleaseConn(conn *Conn[T], err error) bool {
cm.connPoolMu.Unlock()

if removedFromPool {
logrus.WithError(err).Warnf("Server removed from %s pool: %s, poolSize=%d, reason=%s",
cm.log.WithError(err).Warnf("Server removed from %s pool: %s, poolSize=%d, reason=%s",
cm.connFactory.ServiceName(), conn.Target(), connPoolSize, err)
}
if closeConn {
_ = conn.inner.Close()
logrus.Warnf("Disconnected from %s server %s", cm.connFactory.ServiceName(), conn.Target())
cm.log.Warnf("Disconnected from %s server %s", cm.connFactory.ServiceName(), conn.Target())
return true
}
return false
Expand Down Expand Up @@ -275,7 +285,7 @@ func (cm *ConnMgr[T]) run() {
select {
case <-nilOrReconnectCh:
case <-cm.reconnectCh:
logrus.Info("Force connection pool refresh")
cm.log.Info("Force connection pool refresh")
case <-cm.ctx.Done():
return
}
Expand All @@ -285,7 +295,7 @@ func (cm *ConnMgr[T]) run() {
if errors.Is(err, context.Canceled) {
return
}
logrus.WithError(err).Errorf("Failed to refresh connection pool")
cm.log.WithError(err).Errorf("Failed to refresh connection pool")
reconnectPeriod = cm.cfg.BackOffTimeout
}
// If a server returns zero TTL it means that periodic server list
Expand Down Expand Up @@ -316,7 +326,7 @@ func (cm *ConnMgr[T]) refreshConnPool() (clock.Duration, error) {

newConnCount := 0
crossZoneCount := 0
logrus.Infof("Connecting to %d %s servers", len(getGRPCEndpointRs.Servers), cm.connFactory.ServiceName())
cm.log.Infof("Connecting to %d %s servers", len(getGRPCEndpointRs.Servers), cm.connFactory.ServiceName())
for _, serverSpec := range getGRPCEndpointRs.Servers {
if serverSpec.Zone != cm.cfg.Zone {
crossZoneCount++
Expand All @@ -334,7 +344,7 @@ func (cm *ConnMgr[T]) refreshConnPool() (clock.Duration, error) {
if errors.Is(err, context.Canceled) {
return 0, err
}
logrus.WithError(err).Errorf("Failed to dial %s server: %s",
cm.log.WithError(err).Errorf("Failed to dial %s server: %s",
cm.connFactory.ServiceName(), serverSpec.Endpoint)
break
}
Expand All @@ -350,7 +360,7 @@ func (cm *ConnMgr[T]) refreshConnPool() (clock.Duration, error) {
MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Inc()
cm.connPoolMu.Unlock()
newConnCount++
logrus.Infof("Connected to %s server: %s, zone=%s", cm.connFactory.ServiceName(), serverSpec.Endpoint, serverSpec.Zone)
cm.log.Infof("Connected to %s server: %s, zone=%s", cm.connFactory.ServiceName(), serverSpec.Endpoint, serverSpec.Zone)
}
}
cm.connPoolMu.Lock()
Expand All @@ -363,7 +373,7 @@ func (cm *ConnMgr[T]) refreshConnPool() (clock.Duration, error) {
}
cm.connPoolMu.Unlock()
took := clock.Since(begin).Truncate(clock.Millisecond)
logrus.Warnf("Connection pool refreshed: took=%s, zone=%s, poolSize=%d, newConnCount=%d, knownServerCount=%d, crossZoneCount=%d, ttl=%s",
cm.log.Warnf("Connection pool refreshed: took=%s, zone=%s, poolSize=%d, newConnCount=%d, knownServerCount=%d, crossZoneCount=%d, ttl=%s",
took, cm.cfg.Zone, connPoolSize, newConnCount, len(getGRPCEndpointRs.Servers), crossZoneCount, ttl)
if connPoolSize < 1 {
return 0, errConnPoolEmpty
Expand All @@ -389,9 +399,9 @@ func (cm *ConnMgr[T]) newConnection(endpoint string) (*Conn[T], error) {
ctx, cancel := context.WithTimeout(cm.ctx, cm.cfg.RPCTimeout)
opts := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor()),
grpc.WithStreamInterceptor(otelgrpc.StreamClientInterceptor()),
grpc.WithTransportCredentials(insecureCredentials),
grpc.WithUnaryInterceptor(otelUnaryInterceptor),
grpc.WithStreamInterceptor(otelStreamInterceptor),
}
grpcConn, err := grpc.DialContext(ctx, endpoint, opts...)
cancel()
Expand All @@ -406,6 +416,12 @@ func (cm *ConnMgr[T]) newConnection(endpoint string) (*Conn[T], error) {
}, nil
}

var (
insecureCredentials = insecure.NewCredentials()
otelUnaryInterceptor = otelgrpc.UnaryClientInterceptor()
otelStreamInterceptor = otelgrpc.StreamClientInterceptor()
)

func (cm *ConnMgr[T]) getServerEndpoints(ctx context.Context) (*GetGRPCEndpointsRs, error) {
rq, err := http.NewRequestWithContext(ctx, "GET", cm.getEndpointsURL, http.NoBody)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions mxresolv/mxresolv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ func TestLookup(t *testing.T) {
inDomainName: "test-mx-ipv6.definbox.com",
outMXHosts: []string{"::ffff:2296:b0e1"},
outImplicitMX: false,
}, {
inDomainName: "arenhp.co.uk",
outMXHosts: []string{"arenhp.co.uk"},
outImplicitMX: false,
}} {
t.Run(tc.inDomainName, func(t *testing.T) {
defer mxresolv.SetDeterministic()()
Expand Down

0 comments on commit 11436b5

Please sign in to comment.