diff --git a/go/vt/grpcclient/client_auth_static.go b/go/vt/grpcclient/client_auth_static.go index 53be18cc4f..5b0596d635 100644 --- a/go/vt/grpcclient/client_auth_static.go +++ b/go/vt/grpcclient/client_auth_static.go @@ -17,29 +17,39 @@ limitations under the License. package grpcclient import ( + "context" "encoding/json" "flag" "os" - - "context" + "os/signal" + "sync" + "syscall" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + + "vitess.io/vitess/go/vt/servenv" ) var ( credsFile = flag.String("grpc_auth_static_client_creds", "", "when using grpc_static_auth in the server, this file provides the credentials to use to authenticate with server") // StaticAuthClientCreds implements client interface to be able to WithPerRPCCredentials _ credentials.PerRPCCredentials = (*StaticAuthClientCreds)(nil) + + clientCreds *StaticAuthClientCreds + clientCredsCancel context.CancelFunc + clientCredsErr error + clientCredsMu sync.Mutex + clientCredsSigChan chan os.Signal ) -// StaticAuthClientCreds holder for client credentials +// StaticAuthClientCreds holder for client credentials. type StaticAuthClientCreds struct { Username string Password string } -// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds +// GetRequestMetadata gets the request metadata as a map from StaticAuthClientCreds. func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { return map[string]string{ "username": c.Username, @@ -49,30 +59,82 @@ func (c *StaticAuthClientCreds) GetRequestMetadata(context.Context, ...string) ( // RequireTransportSecurity indicates whether the credentials requires transport security. // Given that people can use this with or without TLS, at the moment we are not enforcing -// transport security +// transport security. func (c *StaticAuthClientCreds) RequireTransportSecurity() bool { return false } // AppendStaticAuth optionally appends static auth credentials if provided. func AppendStaticAuth(opts []grpc.DialOption) ([]grpc.DialOption, error) { - if *credsFile == "" { - return opts, nil - } - data, err := os.ReadFile(*credsFile) + creds, err := getStaticAuthCreds() if err != nil { return nil, err } - clientCreds := &StaticAuthClientCreds{} - err = json.Unmarshal(data, clientCreds) + if creds != nil { + grpcCreds := grpc.WithPerRPCCredentials(creds) + opts = append(opts, grpcCreds) + } + return opts, nil +} + +// ResetStaticAuth resets the static auth credentials. +func ResetStaticAuth() { + clientCredsMu.Lock() + defer clientCredsMu.Unlock() + if clientCredsCancel != nil { + clientCredsCancel() + clientCredsCancel = nil + } + clientCreds = nil + clientCredsErr = nil +} + +// getStaticAuthCreds returns the static auth creds and error. +func getStaticAuthCreds() (*StaticAuthClientCreds, error) { + clientCredsMu.Lock() + defer clientCredsMu.Unlock() + if *credsFile != "" && clientCreds == nil { + var ctx context.Context + ctx, clientCredsCancel = context.WithCancel(context.Background()) + go handleClientCredsSignals(ctx) + clientCreds, clientCredsErr = loadStaticAuthCredsFromFile(*credsFile) + } + return clientCreds, clientCredsErr +} + +// handleClientCredsSignals handles signals to reload client creds. +func handleClientCredsSignals(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-clientCredsSigChan: + if newCreds, err := loadStaticAuthCredsFromFile(*credsFile); err == nil { + clientCredsMu.Lock() + clientCreds = newCreds + clientCredsErr = err + clientCredsMu.Unlock() + } + } + } +} + +// loadStaticAuthCredsFromFile loads static auth credentials from a file. +func loadStaticAuthCredsFromFile(path string) (*StaticAuthClientCreds, error) { + data, err := os.ReadFile(path) if err != nil { return nil, err } - creds := grpc.WithPerRPCCredentials(clientCreds) - opts = append(opts, creds) - return opts, nil + creds := &StaticAuthClientCreds{} + err = json.Unmarshal(data, creds) + return creds, err } func init() { + servenv.OnInit(func() { + clientCredsSigChan = make(chan os.Signal, 1) + signal.Notify(clientCredsSigChan, syscall.SIGHUP) + _, _ = getStaticAuthCreds() // preload static auth credentials + }) RegisterGRPCDialOptions(AppendStaticAuth) } diff --git a/go/vt/grpcclient/client_auth_static_test.go b/go/vt/grpcclient/client_auth_static_test.go new file mode 100644 index 0000000000..99c8db5e2f --- /dev/null +++ b/go/vt/grpcclient/client_auth_static_test.go @@ -0,0 +1,127 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package grpcclient + +import ( + "errors" + "fmt" + "os" + "reflect" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +func TestAppendStaticAuth(t *testing.T) { + { + clientCreds = nil + clientCredsErr = nil + opts, err := AppendStaticAuth([]grpc.DialOption{}) + assert.Nil(t, err) + assert.Len(t, opts, 0) + } + { + clientCreds = nil + clientCredsErr = errors.New("test err") + opts, err := AppendStaticAuth([]grpc.DialOption{}) + assert.NotNil(t, err) + assert.Len(t, opts, 0) + } + { + clientCreds = &StaticAuthClientCreds{Username: "test", Password: "123456"} + clientCredsErr = nil + opts, err := AppendStaticAuth([]grpc.DialOption{}) + assert.Nil(t, err) + assert.Len(t, opts, 1) + } +} + +func TestGetStaticAuthCreds(t *testing.T) { + tmp, err := os.CreateTemp("", t.Name()) + assert.Nil(t, err) + defer os.Remove(tmp.Name()) + credsFileTmp := tmp.Name() + credsFile = &credsFileTmp + clientCredsSigChan = make(chan os.Signal, 1) + + // load old creds + fmt.Fprint(tmp, `{"Username": "old", "Password": "123456"}`) + ResetStaticAuth() + creds, err := getStaticAuthCreds() + assert.Nil(t, err) + assert.Equal(t, &StaticAuthClientCreds{Username: "old", Password: "123456"}, creds) + + // write new creds to the same file + _ = tmp.Truncate(0) + _, _ = tmp.Seek(0, 0) + fmt.Fprint(tmp, `{"Username": "new", "Password": "123456789"}`) + + // test the creds did not change yet + creds, err = getStaticAuthCreds() + assert.Nil(t, err) + assert.Equal(t, &StaticAuthClientCreds{Username: "old", Password: "123456"}, creds) + + // test SIGHUP signal triggers reload + credsOld := creds + clientCredsSigChan <- syscall.SIGHUP + timeoutChan := time.After(time.Second * 10) + for { + select { + case <-timeoutChan: + assert.Fail(t, "timed out waiting for SIGHUP reload of static auth creds") + return + default: + // confirm new creds get loaded + creds, err = getStaticAuthCreds() + if reflect.DeepEqual(creds, credsOld) { + continue // not changed yet + } + assert.Nil(t, err) + assert.Equal(t, &StaticAuthClientCreds{Username: "new", Password: "123456789"}, creds) + return + } + } +} + +func TestLoadStaticAuthCredsFromFile(t *testing.T) { + { + f, err := os.CreateTemp("", t.Name()) + if !assert.Nil(t, err) { + assert.FailNowf(t, "cannot create temp file: %s", err.Error()) + } + defer os.Remove(f.Name()) + fmt.Fprint(f, `{ + "Username": "test", + "Password": "correct horse battery staple" + }`) + if !assert.Nil(t, err) { + assert.FailNowf(t, "cannot read auth file: %s", err.Error()) + } + + creds, err := loadStaticAuthCredsFromFile(f.Name()) + assert.Nil(t, err) + assert.Equal(t, "test", creds.Username) + assert.Equal(t, "correct horse battery staple", creds.Password) + } + { + _, err := loadStaticAuthCredsFromFile(`does-not-exist`) + assert.NotNil(t, err) + } +} diff --git a/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go b/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go index 28f9573634..2c908c980d 100644 --- a/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go +++ b/go/vt/vtgate/grpcvtgateconn/conn_rpc_test.go @@ -27,6 +27,7 @@ import ( "context" + "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vtgate/grpcvtgateservice" "vitess.io/vitess/go/vt/vtgate/vtgateconn" @@ -105,6 +106,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { // Create a Go RPC client connecting to the server ctx := context.Background() flag.Set("grpc_auth_static_client_creds", f.Name()) + grpcclient.ResetStaticAuth() client, err := dial(ctx, listener.Addr().String()) if err != nil { t.Fatalf("dial failed: %v", err) @@ -138,6 +140,7 @@ func TestGRPCVTGateConnAuth(t *testing.T) { // Create a Go RPC client connecting to the server ctx = context.Background() flag.Set("grpc_auth_static_client_creds", f.Name()) + grpcclient.ResetStaticAuth() client, err = dial(ctx, listener.Addr().String()) if err != nil { t.Fatalf("dial failed: %v", err)