Skip to content

Commit

Permalink
shorter timeout when connecting to s2a service for handshake (#98)
Browse files Browse the repository at this point in the history
* shorter timeout when connecting to s2a service for handshake
  • Loading branch information
xmenxk authored Mar 17, 2023
1 parent 40db1c2 commit 81a6f4a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
12 changes: 7 additions & 5 deletions internal/v2/s2av2.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"context"
"crypto/tls"
"errors"
"flag"
"net"
"time"

Expand All @@ -42,9 +43,10 @@ import (

const (
s2aSecurityProtocol = "tls"
defaultTimeout = 20.0 * time.Second
)

var S2ATimeout = flag.Duration("s2a_timeout", 3*time.Second, "Timeout enforced on the connection to the S2A service for handshake.")

type s2av2TransportCreds struct {
info *credentials.ProtocolInfo
isClient bool
Expand Down Expand Up @@ -119,9 +121,9 @@ func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthori
}
// Remove the port from serverAuthority.
serverName := removeServerNamePort(serverAuthority)
ctx, cancel := context.WithTimeout(ctx, defaultTimeout)
timeoutCtx, cancel := context.WithTimeout(ctx, *S2ATimeout)
defer cancel()
cstream, err := createStream(ctx, c.s2av2Address)
cstream, err := createStream(timeoutCtx, c.s2av2Address)
if err != nil {
grpclog.Infof("Failed to connect to S2Av2: %v", err)
if c.fallbackClientHandshake != nil {
Expand Down Expand Up @@ -165,7 +167,7 @@ func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthori
}
creds := credentials.NewTLS(config)

conn, authInfo, err := creds.ClientHandshake(context.Background(), serverName, rawConn)
conn, authInfo, err := creds.ClientHandshake(ctx, serverName, rawConn)
if err != nil {
grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
if c.fallbackClientHandshake != nil {
Expand All @@ -183,7 +185,7 @@ func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, crede
if c.isClient {
return nil, nil, errors.New("server handshake called using client transport credentials")
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
ctx, cancel := context.WithTimeout(context.Background(), *S2ATimeout)
defer cancel()
cstream, err := createStream(ctx, c.s2av2Address)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion s2a.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, net
if err != nil {
serverName = addr
}
s2aTLSConfig, err := factory.Build(ctx, &TLSClientConfigOptions{
timeoutCtx, cancel := context.WithTimeout(ctx, *v2.S2ATimeout)
defer cancel()
s2aTLSConfig, err := factory.Build(timeoutCtx, &TLSClientConfigOptions{
ServerName: serverName,
})
if err != nil {
Expand Down

0 comments on commit 81a6f4a

Please sign in to comment.