Skip to content

Commit

Permalink
Merge branch 'slack-vitess-r14.0.5' into slack-vitess-r14.0.5-dsdefense
Browse files Browse the repository at this point in the history
  • Loading branch information
ejortegau committed Feb 13, 2024
2 parents 534b326 + b878269 commit 2390213
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 38 deletions.
1 change: 1 addition & 0 deletions go/flags/endtoend/vtgate.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Usage of vtgate:
--grpc_server_keepalive_enforcement_policy_min_time duration gRPC server minimum keepalive time (default 10s)
--grpc_server_keepalive_enforcement_policy_permit_without_stream gRPC server permit client keepalive pings even when there are no active streams (RPCs)
--grpc_use_effective_callerid If set, and SSL is not used, will set the immediate caller id from the effective caller id's principal.
--healthcheck-dial-concurrency int Maxiumum concurrency of new healthcheck connections. This should be less than the golang max thread limit of 10000. (default 1024)
--healthcheck_retry_delay duration health check retry delay (default 2ms)
--healthcheck_timeout duration the health check timeout period (default 1m0s)
-h, --help display usage and exit
Expand Down
1 change: 1 addition & 0 deletions go/flags/endtoend/vttablet.txt
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ Usage of vttablet:
--grpc_server_keepalive_enforcement_policy_min_time duration gRPC server minimum keepalive time (default 10s)
--grpc_server_keepalive_enforcement_policy_permit_without_stream gRPC server permit client keepalive pings even when there are no active streams (RPCs)
--health_check_interval duration Interval between health checks (default 20s)
--healthcheck-dial-concurrency int Maxiumum concurrency of new healthcheck connections. This should be less than the golang max thread limit of 10000. (default 1024)
--heartbeat_enable If true, vttablet records (if master) or checks (if replica) the current time of a replication heartbeat in the table _vt.heartbeat. The result is used to inform the serving state of the vttablet via healthchecks.
--heartbeat_interval duration How frequently to read and write replication heartbeat. (default 1s)
--heartbeat_on_demand_duration duration If non-zero, heartbeats are only written upon consumer request, and only run for up to given duration following the request. Frequent requests can keep the heartbeat running consistently; when requests are infrequent heartbeat may completely stop between requests (default 0s)
Expand Down
8 changes: 7 additions & 1 deletion go/vt/discovery/healthcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (

"vitess.io/vitess/go/flagutil"
"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/sync2"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/topodata"
Expand Down Expand Up @@ -81,6 +82,8 @@ var (
refreshKnownTablets = flag.Bool("tablet_refresh_known_tablets", true, "tablet refresh reloads the tablet address/port map from topo in case it changes")
// topoReadConcurrency tells us how many topo reads are allowed in parallel
topoReadConcurrency = flag.Int("topo_read_concurrency", 32, "concurrent topo reads")
// healthCheckDialConcurrency tells us how many healthcheck connections can be opened to tablets at once. This should be less than the golang max thread limit of 10000.
healthCheckDialConcurrency = flag.Int("healthcheck-dial-concurrency", 1024, "Maxiumum concurrency of new healthcheck connections. This should be less than the golang max thread limit of 10000.")
)

// See the documentation for NewHealthCheck below for an explanation of these parameters.
Expand Down Expand Up @@ -260,6 +263,8 @@ type HealthCheckImpl struct {
subMu sync.Mutex
// subscribers
subscribers map[chan *TabletHealth]struct{}
// healthCheckDialSem is used to limit how many healthchecks initiate in parallel.
healthCheckDialSem *sync2.Semaphore
}

// NewHealthCheck creates a new HealthCheck object.
Expand Down Expand Up @@ -294,6 +299,7 @@ func NewHealthCheck(ctx context.Context, retryDelay, healthCheckTimeout time.Dur
cell: localCell,
retryDelay: retryDelay,
healthCheckTimeout: healthCheckTimeout,
healthCheckDialSem: sync2.NewSemaphore(*healthCheckDialConcurrency, 0),
healthByAlias: make(map[tabletAliasString]*tabletHealthCheck),
healthData: make(map[KeyspaceShardTabletType]map[tabletAliasString]*TabletHealth),
healthy: make(map[KeyspaceShardTabletType][]*TabletHealth),
Expand Down Expand Up @@ -780,7 +786,7 @@ func (hc *HealthCheckImpl) TabletConnection(alias *topodata.TabletAlias, target
// TODO: test that throws this error
return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "tablet: %v is either down or nonexistent", alias)
}
return thc.Connection(), nil
return thc.Connection(hc), nil
}

// getAliasByCell should only be called while holding hc.mu
Expand Down
30 changes: 24 additions & 6 deletions go/vt/discovery/tablet_health_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package discovery
import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"
Expand All @@ -34,6 +35,7 @@ import (
"vitess.io/vitess/go/vt/vttablet/queryservice"
"vitess.io/vitess/go/vt/vttablet/tabletconn"

"google.golang.org/grpc"
"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -123,8 +125,8 @@ func (thc *tabletHealthCheck) setServingState(serving bool, reason string) {
}

// stream streams healthcheck responses to callback.
func (thc *tabletHealthCheck) stream(ctx context.Context, callback func(*query.StreamHealthResponse) error) error {
conn := thc.Connection()
func (thc *tabletHealthCheck) stream(ctx context.Context, hc *HealthCheckImpl, callback func(*query.StreamHealthResponse) error) error {
conn := thc.Connection(hc)
if conn == nil {
// This signals the caller to retry
return nil
Expand All @@ -137,14 +139,30 @@ func (thc *tabletHealthCheck) stream(ctx context.Context, callback func(*query.S
return err
}

func (thc *tabletHealthCheck) Connection() queryservice.QueryService {
func (thc *tabletHealthCheck) Connection(hc *HealthCheckImpl) queryservice.QueryService {
thc.connMu.Lock()
defer thc.connMu.Unlock()
return thc.connectionLocked()
return thc.connectionLocked(hc)
}

func (thc *tabletHealthCheck) connectionLocked() queryservice.QueryService {
func healthCheckDialerFactory(hc *HealthCheckImpl) func(ctx context.Context, addr string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
// Limit the number of healthcheck connections opened in parallel to avoid high OS-thread
// usage due to blocking networking syscalls (eg: DNS lookups, TCP connection opens,
// etc). Without this limit it is possible for vtgates watching >10k tablets to hit
// the panic: 'runtime: program exceeds 10000-thread limit'.
hc.healthCheckDialSem.Acquire()
defer hc.healthCheckDialSem.Release()
var dialer net.Dialer
return dialer.DialContext(ctx, "tcp", addr)
}
}

func (thc *tabletHealthCheck) connectionLocked(hc *HealthCheckImpl) queryservice.QueryService {
if thc.Conn == nil {
grpcclient.RegisterGRPCDialOptions(func(opts []grpc.DialOption) ([]grpc.DialOption, error) {
return append(opts, grpc.WithContextDialer(healthCheckDialerFactory(hc))), nil
})
conn, err := tabletconn.GetDialer()(thc.Tablet, grpcclient.FailFast(true))
if err != nil {
thc.LastError = err
Expand Down Expand Up @@ -273,7 +291,7 @@ func (thc *tabletHealthCheck) checkConn(hc *HealthCheckImpl) {
}()

// Read stream health responses.
err := thc.stream(streamCtx, func(shr *query.StreamHealthResponse) error {
err := thc.stream(streamCtx, hc, func(shr *query.StreamHealthResponse) error {
// We received a message. Reset the back-off.
retryDelay = hc.retryDelay
// Don't block on send to avoid deadlocks.
Expand Down
6 changes: 6 additions & 0 deletions go/vt/grpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"crypto/tls"
"flag"
"sync"
"time"

grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
Expand All @@ -39,6 +40,7 @@ import (
)

var (
grpcDialOptionsMu sync.Mutex
keepaliveTime = flag.Duration("grpc_keepalive_time", 10*time.Second, "After a duration of this time, if the client doesn't see any activity, it pings the server to see if the transport is still alive.")
keepaliveTimeout = flag.Duration("grpc_keepalive_timeout", 10*time.Second, "After having pinged for keepalive check, the client waits for a duration of Timeout and if no activity is seen even after that the connection is closed.")
initialConnWindowSize = flag.Int("grpc_initial_conn_window_size", 0, "gRPC initial connection window size")
Expand All @@ -53,6 +55,8 @@ var grpcDialOptions []func(opts []grpc.DialOption) ([]grpc.DialOption, error)

// RegisterGRPCDialOptions registers an implementation of AuthServer.
func RegisterGRPCDialOptions(grpcDialOptionsFunc func(opts []grpc.DialOption) ([]grpc.DialOption, error)) {
grpcDialOptionsMu.Lock()
defer grpcDialOptionsMu.Unlock()
grpcDialOptions = append(grpcDialOptions, grpcDialOptionsFunc)
}

Expand Down Expand Up @@ -101,12 +105,14 @@ func DialContext(ctx context.Context, target string, failFast FailFast, opts ...

newopts = append(newopts, opts...)
var err error
grpcDialOptionsMu.Lock()
for _, grpcDialOptionInitializer := range grpcDialOptions {
newopts, err = grpcDialOptionInitializer(newopts)
if err != nil {
log.Fatalf("There was an error initializing client grpc.DialOption: %v", err)
}
}
grpcDialOptionsMu.Unlock()

newopts = append(newopts, interceptors()...)

Expand Down
90 changes: 76 additions & 14 deletions go/vt/grpcclient/client_auth_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,39 @@ limitations under the License.
package grpcclient

import (
"context"
"encoding/json"
"flag"
"os"

"context"
"os/signal"
"sync"
"syscall"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"vitess.io/vitess/go/vt/servenv"
)

var (
credsFile = flag.String("grpc_auth_static_client_creds", "", "when using grpc_static_auth in the server, this file provides the credentials to use to authenticate with server")
// StaticAuthClientCreds implements client interface to be able to WithPerRPCCredentials
_ credentials.PerRPCCredentials = (*StaticAuthClientCreds)(nil)

clientCreds *StaticAuthClientCreds
clientCredsCancel context.CancelFunc
clientCredsErr error
clientCredsMu sync.Mutex
clientCredsSigChan chan os.Signal
)

// StaticAuthClientCreds holder for client credentials
// StaticAuthClientCreds holder for client credentials.
type StaticAuthClientCreds struct {
Username string
Password string
}

// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds
// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds.
func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
return map[string]string{
"username": c.Username,
Expand All @@ -49,30 +59,82 @@ func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (

// RequireTransportSecurity indicates whether the credentials requires transport security.
// Given that people can use this with or without TLS, at the moment we are not enforcing
// transport security
// transport security.
func (c *StaticAuthClientCreds) RequireTransportSecurity() bool {
return false
}

// AppendStaticAuth optionally appends static auth credentials if provided.
func AppendStaticAuth(opts []grpc.DialOption) ([]grpc.DialOption, error) {
if *credsFile == "" {
return opts, nil
}
data, err := os.ReadFile(*credsFile)
creds, err := getStaticAuthCreds()
if err != nil {
return nil, err
}
clientCreds := &StaticAuthClientCreds{}
err = json.Unmarshal(data, clientCreds)
if creds != nil {
grpcCreds := grpc.WithPerRPCCredentials(creds)
opts = append(opts, grpcCreds)
}
return opts, nil
}

// ResetStaticAuth resets the static auth credentials.
func ResetStaticAuth() {
clientCredsMu.Lock()
defer clientCredsMu.Unlock()
if clientCredsCancel != nil {
clientCredsCancel()
clientCredsCancel = nil
}
clientCreds = nil
clientCredsErr = nil
}

// getStaticAuthCreds returns the static auth creds and error.
func getStaticAuthCreds() (*StaticAuthClientCreds, error) {
clientCredsMu.Lock()
defer clientCredsMu.Unlock()
if *credsFile != "" && clientCreds == nil {
var ctx context.Context
ctx, clientCredsCancel = context.WithCancel(context.Background())
go handleClientCredsSignals(ctx)
clientCreds, clientCredsErr = loadStaticAuthCredsFromFile(*credsFile)
}
return clientCreds, clientCredsErr
}

// handleClientCredsSignals handles signals to reload client creds.
func handleClientCredsSignals(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-clientCredsSigChan:
if newCreds, err := loadStaticAuthCredsFromFile(*credsFile); err == nil {
clientCredsMu.Lock()
clientCreds = newCreds
clientCredsErr = err
clientCredsMu.Unlock()
}
}
}
}

// loadStaticAuthCredsFromFile loads static auth credentials from a file.
func loadStaticAuthCredsFromFile(path string) (*StaticAuthClientCreds, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
creds := grpc.WithPerRPCCredentials(clientCreds)
opts = append(opts, creds)
return opts, nil
creds := &StaticAuthClientCreds{}
err = json.Unmarshal(data, creds)
return creds, err
}

func init() {
servenv.OnInit(func() {
clientCredsSigChan = make(chan os.Signal, 1)
signal.Notify(clientCredsSigChan, syscall.SIGHUP)
_, _ = getStaticAuthCreds() // preload static auth credentials
})
RegisterGRPCDialOptions(AppendStaticAuth)
}
Loading

0 comments on commit 2390213

Please sign in to comment.