diff --git a/lib/srv/alpnproxy/azure_msi_middleware.go b/lib/srv/alpnproxy/azure_msi_middleware.go index b26adc80a78fd..0156e18b49f7f 100644 --- a/lib/srv/alpnproxy/azure_msi_middleware.go +++ b/lib/srv/alpnproxy/azure_msi_middleware.go @@ -23,6 +23,7 @@ import ( "encoding/json" "fmt" "net/http" + "sync" "time" "github.com/gravitational/trace" @@ -45,15 +46,16 @@ type AzureMSIMiddleware struct { // ClientID to be returned in a claim. ClientID string - // Key used to sign JWT - Key crypto.Signer - // Clock is used to override time in tests. Clock clockwork.Clock // Log is the Logger. Log logrus.FieldLogger // Secret to be provided by the client. Secret string + + // privateKey used to sign JWT + privateKey crypto.Signer + privateKeyMu sync.RWMutex } var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{} @@ -66,9 +68,6 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error { m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi") } - if m.Key == nil { - return trace.BadParameter("missing Key") - } if m.Secret == "" { return trace.BadParameter("missing Secret") } @@ -96,6 +95,22 @@ func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Req return false } +// SetPrivateKey updates the private key. +func (m *AzureMSIMiddleware) SetPrivateKey(privateKey crypto.Signer) { + m.privateKeyMu.Lock() + defer m.privateKeyMu.Unlock() + m.privateKey = privateKey +} +func (m *AzureMSIMiddleware) getPrivateKey() (crypto.Signer, error) { + m.privateKeyMu.RLock() + defer m.privateKeyMu.RUnlock() + if m.privateKey == nil { + // Use a plain error to return status code 500. + return nil, trace.Errorf("missing private key set in AzureMSIMiddleware") + } + return m.privateKey, nil +} + func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error { // request validation if req.URL.Path != ("/" + m.Secret) { @@ -173,10 +188,14 @@ func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error) } func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) { + privateKey, err := m.getPrivateKey() + if err != nil { + return "", trace.Wrap(err) + } // Create a new key that can sign and verify tokens. key, err := jwt.New(&jwt.Config{ Clock: m.Clock, - PrivateKey: m.Key, + PrivateKey: privateKey, ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name }) if err != nil { diff --git a/lib/srv/alpnproxy/azure_msi_middleware_test.go b/lib/srv/alpnproxy/azure_msi_middleware_test.go index 57bc412560178..8104186a3cb1e 100644 --- a/lib/srv/alpnproxy/azure_msi_middleware_test.go +++ b/lib/srv/alpnproxy/azure_msi_middleware_test.go @@ -53,13 +53,13 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith require.NoError(t, err) return privateKey } + privateKey := newPrivateKey() m := &AzureMSIMiddleware{ Identity: "azureTestIdentity", TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe", ClientID: "decaffff-cafe-4aaa-cafe-cafecafecafe", Log: logrus.WithField(teleport.ComponentKey, "msi"), Clock: clockwork.NewFakeClockAt(time.Date(2022, 1, 1, 9, 0, 0, 0, time.UTC)), - Key: newPrivateKey(), Secret: "my-secret", } require.NoError(t, m.CheckAndSetDefaults()) @@ -68,6 +68,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name string url string headers map[string]string + privateKey crypto.Signer expectedHandle bool expectedCode int expectedBody string @@ -76,12 +77,14 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith { name: "ignore non-msi requests", url: "https://graph.windows.net/foo/bar/baz", + privateKey: privateKey, expectedHandle: false, }, { name: "invalid request, wrong secret", url: "https://azure-msi.teleport.dev/bad-secret", headers: nil, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"invalid secret\"\n }\n}", @@ -90,6 +93,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, missing secret", url: "https://azure-msi.teleport.dev", headers: nil, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"invalid secret\"\n }\n}", @@ -98,6 +102,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, missing metadata", url: "https://azure-msi.teleport.dev/my-secret", headers: nil, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}", @@ -106,6 +111,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, bad metadata value", url: "https://azure-msi.teleport.dev/my-secret", headers: map[string]string{"Metadata": "false"}, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}", @@ -114,6 +120,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, missing arguments", url: "https://azure-msi.teleport.dev/my-secret", headers: map[string]string{"Metadata": "true"}, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}", @@ -122,6 +129,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, missing resource", url: "https://azure-msi.teleport.dev/my-secret?msi_res_id=azureTestIdentity", headers: map[string]string{"Metadata": "true"}, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}", @@ -130,6 +138,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, missing identity", url: "https://azure-msi.teleport.dev/my-secret?resource=myresource", headers: map[string]string{"Metadata": "true"}, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': \"\n }\n}", @@ -138,6 +147,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "invalid request, wrong identity", url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestWrongIdentity", headers: map[string]string{"Metadata": "true"}, + privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': azureTestWrongIdentity\"\n }\n}", @@ -146,6 +156,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith name: "well-formatted request", url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity", headers: map[string]string{"Metadata": "true"}, + privateKey: privateKey, expectedHandle: true, expectedCode: 200, verifyBody: func(t *testing.T, body []byte) { @@ -182,7 +193,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith return key.VerifyAzureToken(token) } - claims, err := fromJWT(req.AccessToken, m.Key) + claims, err := fromJWT(req.AccessToken, privateKey) require.NoError(t, err) require.Equal(t, jwt.AzureTokenClaims{ TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe", @@ -202,10 +213,21 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith require.Equal(t, expected.NotBefore, req.NotBefore) }, }, + { + name: "no private key set", + url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity", + headers: map[string]string{"Metadata": "true"}, + privateKey: nil, + expectedHandle: true, + expectedCode: 500, + expectedBody: "{\n \"error\": {\n \"message\": \"missing private key set in AzureMSIMiddleware\"\n }\n}", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + m.SetPrivateKey(tt.privateKey) + // prepare request req, err := http.NewRequest("GET", tt.url, strings.NewReader("")) require.NoError(t, err) diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 0aa7e75a193f2..481e0b0813a6d 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -92,6 +92,8 @@ type LocalProxyConfig struct { CheckCertNeeded bool // verifyUpstreamConnection is a callback function to verify upstream connection state. verifyUpstreamConnection func(tls.ConnectionState) error + // onSetCert is a callback when lp.SetCert is called. + onSetCert func(tls.Certificate) } // LocalProxyMiddleware provides callback functions for LocalProxy. @@ -484,6 +486,11 @@ func (l *LocalProxy) SetCert(cert tls.Certificate) { l.certMu.Lock() defer l.certMu.Unlock() l.cfg.Cert = cert + + // Callback, if any. + if l.cfg.onSetCert != nil { + l.cfg.onSetCert(cert) + } } // getCertForConn determines if certificates should be used when dialing diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index 7b520ce37611f..50ff7de9538cf 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -170,3 +170,11 @@ func mySQLVersionToProto(database types.Database) string { // Include MySQL server version return string(common.ProtocolMySQLWithVerPrefix) + versionBase64 } + +// WithOnSetCert provides a callback when lp.SetCert is called. +func WithOnSetCert(callback func(tls.Certificate)) LocalProxyConfigOpt { + return func(config *LocalProxyConfig) error { + config.onSetCert = callback + return nil + } +} diff --git a/tool/tsh/common/app_azure.go b/tool/tsh/common/app_azure.go index 5a8494aae5d3e..6a31c014f802d 100644 --- a/tool/tsh/common/app_azure.go +++ b/tool/tsh/common/app_azure.go @@ -21,6 +21,7 @@ package common import ( "context" "crypto" + "crypto/tls" "fmt" "os" "os/exec" @@ -72,17 +73,11 @@ type azureApp struct { *localProxyApp cf *CLIConf - signer crypto.Signer msiSecret string } // newAzureApp creates a new Azure app. func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) { - keyRing, err := tc.LocalAgent().GetCoreKeyRing() - if err != nil { - return nil, trace.Wrap(err) - } - msiSecret, err := getMSISecret() if err != nil { return nil, err @@ -91,7 +86,6 @@ func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azu return &azureApp{ localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify), cf: cf, - signer: keyRing.TLSPrivateKey, msiSecret: msiSecret, }, nil } @@ -133,7 +127,6 @@ func getMSISecret() (string, error) { // These calls are served entirely locally, which helps the overall performance experienced by the user. func (a *azureApp) StartLocalProxies(ctx context.Context) error { azureMiddleware := &alpnproxy.AzureMSIMiddleware{ - Key: a.signer, Secret: a.msiSecret, // we could, in principle, get the actual TenantID either from live data or from static configuration, // but at this moment there is no clear advantage over simply issuing a new random identifier. @@ -143,7 +136,19 @@ func (a *azureApp) StartLocalProxies(ctx context.Context) error { } // HTTPS proxy mode - err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAzureRequests, alpnproxy.WithHTTPMiddleware(azureMiddleware)) + err := a.StartLocalProxyWithForwarder(ctx, + alpnproxy.MatchAzureRequests, + alpnproxy.WithHTTPMiddleware(azureMiddleware), + alpnproxy.WithOnSetCert(func(cert tls.Certificate) { + // Note that the PrivateKey is most likely set by api/utils/keys.TLSCertificateForSigner. + signer, ok := cert.PrivateKey.(crypto.Signer) + if ok { + azureMiddleware.SetPrivateKey(signer) + } else { + log.Warn("Provided tls.Certificate has no valid private key.") + } + }), + ) return trace.Wrap(err) }