Skip to content

Commit

Permalink
Thread context.Context through auth.Register and add OTel Tracing (
Browse files Browse the repository at this point in the history
…#40597)

* Wire context.Context into `auth/register.go`

* Thread `context.Context` through `auth.Register` and add tracing

* Unwrap error before recording it in the span
  • Loading branch information
strideynet authored Apr 17, 2024
1 parent 9fa4331 commit 955f86b
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 43 deletions.
22 changes: 22 additions & 0 deletions api/observability/tracing/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package tracing
import (
"context"

"github.com/gravitational/trace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/propagation"
oteltrace "go.opentelemetry.io/otel/trace"
)
Expand Down Expand Up @@ -53,3 +55,23 @@ func DefaultProvider() oteltrace.TracerProvider {
func NewTracer(name string) oteltrace.Tracer {
return DefaultProvider().Tracer(name)
}

// EndSpan ends the given span and if an error has occurred, set's the span's
// status to error and additionally records the error.
//
// Example usage:
//
// func myFunc() (err error) {
// ctx, span := tracer.NewSpan(ctx, "myFunc")
// defer func() {
// tracing.EndSpan(span, err)
// }()
// ...
// }
func EndSpan(span oteltrace.Span, err error) {
if err != nil {
span.SetStatus(codes.Error, err.Error())
span.RecordError(trace.Unwrap(err))
}
span.End()
}
2 changes: 1 addition & 1 deletion integration/proxy/proxy_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token str
require.NoError(t, err)

node := uuid.NewString()
_, err = auth.Register(auth.RegisterParams{
_, err = auth.Register(context.TODO(), auth.RegisterParams{
Token: token,
ID: auth.IdentityID{
Role: types.RoleNode,
Expand Down
3 changes: 3 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/maps"
Expand Down Expand Up @@ -153,6 +154,8 @@ const (
"(hint: use 'tctl get roles' to find roles that need updating)"
)

var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth")

var ErrRequiresEnterprise = services.ErrRequiresEnterprise

// ServerOption allows setting options as functional arguments to Server
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)

certs, err := Register(RegisterParams{
certs, err := Register(ctx, RegisterParams{
Token: token.GetName(),
ID: IdentityID{
Role: types.RoleBot,
Expand Down Expand Up @@ -189,7 +189,7 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)

certs, err := Register(RegisterParams{
certs, err := Register(ctx, RegisterParams{
Token: token.GetName(),
ID: IdentityID{
Role: types.RoleBot,
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestRegisterBotCertificateExtensions(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)

certs, err := Register(RegisterParams{
certs, err := Register(ctx, RegisterParams{
Token: token.GetName(),
ID: IdentityID{
Role: types.RoleBot,
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func TestRegister_Bot(t *testing.T) {
} {
t.Run(test.desc, func(t *testing.T) {
start := srv.Clock().Now()
certs, err := Register(RegisterParams{
certs, err := Register(ctx, RegisterParams{
Token: test.token.GetName(),
ID: IdentityID{
Role: types.RoleBot,
Expand Down Expand Up @@ -473,7 +473,7 @@ func TestRegister_Bot_Expiry(t *testing.T) {
tok := newBotToken(t, t.Name(), botName, types.RoleBot, srv.Clock().Now().Add(time.Hour))
require.NoError(t, srv.Auth().UpsertToken(ctx, tok))

certs, err := Register(RegisterParams{
certs, err := Register(ctx, RegisterParams{
Token: tok.GetName(),
ID: IdentityID{
Role: types.RoleBot,
Expand Down
77 changes: 49 additions & 28 deletions lib/auth/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/auth/native"
Expand Down Expand Up @@ -195,8 +196,12 @@ type HostCredentials func(context.Context, string, bool, types.RegisterUsingToke
// different hosts than the auth server. This method requires provisioning
// tokens to prove a valid auth server was used to issue the joining request
// as well as a method for the node to validate the auth server.
func Register(params RegisterParams) (*proto.Certs, error) {
ctx := context.TODO()
func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, err error) {
ctx, span := tracer.Start(ctx, "Register")
defer func() {
tracing.EndSpan(span, err)
}()

if err := params.checkAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -254,7 +259,7 @@ func Register(params RegisterParams) (*proto.Certs, error) {
}

type registerMethod struct {
call func(token string, params RegisterParams) (*proto.Certs, error)
call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error)
desc string
}

Expand Down Expand Up @@ -286,7 +291,7 @@ func Register(params RegisterParams) (*proto.Certs, error) {
var collectedErrs []error
for _, method := range registerMethods {
log.Infof("Attempting registration %s.", method.desc)
certs, err := method.call(token, params)
certs, err := method.call(ctx, token, params)
if err != nil {
collectedErrs = append(collectedErrs, err)
log.WithError(err).Debugf("Registration %s failed.", method.desc)
Expand Down Expand Up @@ -316,23 +321,30 @@ func proxyServerIsAuth(server utils.NetAddr) bool {
}

// registerThroughProxy is used to register through the proxy server.
func registerThroughProxy(token string, params RegisterParams) (*proto.Certs, error) {
var certs *proto.Certs
func registerThroughProxy(
ctx context.Context,
token string,
params RegisterParams,
) (certs *proto.Certs, err error) {
ctx, span := tracer.Start(ctx, "registerThroughProxy")
defer func() {
tracing.EndSpan(span, err)
}()

switch params.JoinMethod {
case types.JoinMethodIAM, types.JoinMethodAzure:
// IAM and Azure join methods require gRPC client
conn, err := proxyJoinServiceConn(params, params.Insecure)
conn, err := proxyJoinServiceConn(ctx, params, params.Insecure)
if err != nil {
return nil, trace.Wrap(err)
}
defer conn.Close()

joinServiceClient := client.NewJoinServiceClient(proto.NewJoinServiceClient(conn))
if params.JoinMethod == types.JoinMethodIAM {
certs, err = registerUsingIAMMethod(joinServiceClient, token, params)
certs, err = registerUsingIAMMethod(ctx, joinServiceClient, token, params)
} else {
certs, err = registerUsingAzureMethod(joinServiceClient, token, params)
certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params)
}

if err != nil {
Expand All @@ -342,7 +354,7 @@ func registerThroughProxy(token string, params RegisterParams) (*proto.Certs, er
// The rest of the join methods use GetHostCredentials function passed through
// params to call proxy HTTP endpoint
var err error
certs, err = params.GetHostCredentials(context.Background(),
certs, err = params.GetHostCredentials(ctx,
getHostAddresses(params)[0],
params.Insecure,
types.RegisterUsingTokenRequest{
Expand Down Expand Up @@ -374,10 +386,15 @@ func getHostAddresses(params RegisterParams) []string {
}

// registerThroughAuth is used to register through the auth server.
func registerThroughAuth(token string, params RegisterParams) (*proto.Certs, error) {
var client *Client
var err error
func registerThroughAuth(
ctx context.Context, token string, params RegisterParams,
) (certs *proto.Certs, err error) {
ctx, span := tracer.Start(ctx, "registerThroughAuth")
defer func() {
tracing.EndSpan(span, err)
}()

var client *Client
// Build a client for the Auth Server with different certificate validation
// depending on the configured values for Insecure, CAPins and CAPath.
switch {
Expand All @@ -386,7 +403,7 @@ func registerThroughAuth(token string, params RegisterParams) (*proto.Certs, err
client, err = insecureRegisterClient(params)
case len(params.CAPins) != 0:
// CAPins takes precedence over CAPath
client, err = pinRegisterClient(params)
client, err = pinRegisterClient(ctx, params)
case params.CAPath != "":
client, err = caPathRegisterClient(params)
default:
Expand All @@ -401,18 +418,17 @@ func registerThroughAuth(token string, params RegisterParams) (*proto.Certs, err
}
defer client.Close()

var certs *proto.Certs
switch params.JoinMethod {
// IAM and Azure methods use unique gRPC endpoints
case types.JoinMethodIAM:
certs, err = registerUsingIAMMethod(client, token, params)
certs, err = registerUsingIAMMethod(ctx, client, token, params)
case types.JoinMethodAzure:
certs, err = registerUsingAzureMethod(client, token, params)
certs, err = registerUsingAzureMethod(ctx, client, token, params)
default:
// non-IAM join methods use HTTP endpoint
// Get the SSH and X509 certificates for a node.
certs, err = client.RegisterUsingToken(
context.Background(),
ctx,
&types.RegisterUsingTokenRequest{
Token: token,
HostID: params.ID.HostUUID,
Expand All @@ -433,7 +449,9 @@ func registerThroughAuth(token string, params RegisterParams) (*proto.Certs, err
// proxyJoinServiceConn attempts to connect to the join service running on the
// proxy. The Proxy's TLS cert will be verified using the host's root CA pool
// (PKI) unless the --insecure flag was passed.
func proxyJoinServiceConn(params RegisterParams, insecure bool) (*grpc.ClientConn, error) {
func proxyJoinServiceConn(
ctx context.Context, params RegisterParams, insecure bool,
) (*grpc.ClientConn, error) {
tlsConfig := utils.TLSConfig(params.CipherSuites)
tlsConfig.Time = params.Clock.Now
// set NextProtos for TLS routing, the actual protocol will be h2
Expand All @@ -454,14 +472,14 @@ func proxyJoinServiceConn(params RegisterParams, insecure bool) (*grpc.ClientCon
// skip verify as the Proxy server will present its host cert which is not
// fully verifiable at this point since the client does not have the host
// CAs yet before completing registration.
alpnConnUpgrade := client.IsALPNConnUpgradeRequired(context.TODO(), getHostAddresses(params)[0], insecure)
alpnConnUpgrade := client.IsALPNConnUpgradeRequired(ctx, getHostAddresses(params)[0], insecure)
if alpnConnUpgrade && !insecure {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock)
}

dialer := client.NewDialer(
context.Background(),
ctx,
apidefaults.DefaultIdleTimeout,
apidefaults.DefaultIOTimeout,
client.WithInsecureSkipVerify(insecure),
Expand Down Expand Up @@ -544,7 +562,9 @@ func readCA(path string) (*x509.Certificate, error) {
// pin, a connection will be re-established and the root CA will be used to
// validate the certificate presented. If both conditions hold true, then we
// know we are connecting to the expected Auth Server.
func pinRegisterClient(params RegisterParams) (*Client, error) {
func pinRegisterClient(
ctx context.Context, params RegisterParams,
) (*Client, error) {
// Build a insecure client to the Auth Server. This is safe because even if
// an attacker were to MITM this connection the CA pin will not match below.
tlsConfig := utils.TLSConfig(params.CipherSuites)
Expand All @@ -564,7 +584,7 @@ func pinRegisterClient(params RegisterParams) (*Client, error) {

// Fetch the root CA from the Auth Server. The NOP role has access to the
// GetClusterCACert endpoint.
localCA, err := authClient.GetClusterCACert(context.TODO())
localCA, err := authClient.GetClusterCACert(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -660,9 +680,9 @@ type joinServiceClient interface {

// registerUsingIAMMethod is used to register using the IAM join method. It is
// able to register through a proxy or through the auth server directly.
func registerUsingIAMMethod(joinServiceClient joinServiceClient, token string, params RegisterParams) (*proto.Certs, error) {
ctx := context.Background()

func registerUsingIAMMethod(
ctx context.Context, joinServiceClient joinServiceClient, token string, params RegisterParams,
) (*proto.Certs, error) {
log.Infof("Attempting to register %s with IAM method using regional STS endpoint", params.ID.Role)
// Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request.
certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
Expand Down Expand Up @@ -702,8 +722,9 @@ func registerUsingIAMMethod(joinServiceClient joinServiceClient, token string, p

// registerUsingAzureMethod is used to register using the Azure join method. It
// is able to register through a proxy or through the auth server directly.
func registerUsingAzureMethod(client joinServiceClient, token string, params RegisterParams) (*proto.Certs, error) {
ctx := context.Background()
func registerUsingAzureMethod(
ctx context.Context, client joinServiceClient, token string, params RegisterParams,
) (*proto.Certs, error) {
certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
imds := azure.NewInstanceMetadataClient()
if !imds.IsAvailable(ctx) {
Expand Down
14 changes: 7 additions & 7 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3387,7 +3387,7 @@ func TestRegisterCAPin(t *testing.T) {
caPin := caPins[0]

// Attempt to register with valid CA pin, should work.
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand All @@ -3405,7 +3405,7 @@ func TestRegisterCAPin(t *testing.T) {

// Attempt to register with multiple CA pins where the auth server only
// matches one, should work.
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand All @@ -3422,7 +3422,7 @@ func TestRegisterCAPin(t *testing.T) {
require.NoError(t, err)

// Attempt to register with invalid CA pin, should fail.
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand All @@ -3439,7 +3439,7 @@ func TestRegisterCAPin(t *testing.T) {
require.Error(t, err)

// Attempt to register with multiple invalid CA pins, should fail.
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand Down Expand Up @@ -3475,7 +3475,7 @@ func TestRegisterCAPin(t *testing.T) {
require.Len(t, caPins, 2)

// Attempt to register with multiple CA pins, should work
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand Down Expand Up @@ -3520,7 +3520,7 @@ func TestRegisterCAPath(t *testing.T) {
require.NoError(t, err)

// Attempt to register with nothing at the CA path, should work.
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand Down Expand Up @@ -3549,7 +3549,7 @@ func TestRegisterCAPath(t *testing.T) {
require.NoError(t, err)

// Attempt to register with valid CA path, should work.
_, err = Register(RegisterParams{
_, err = Register(ctx, RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: IdentityID{
Expand Down
2 changes: 1 addition & 1 deletion lib/service/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec
}
}

certs, err := auth.Register(registerParams)
certs, err := auth.Register(process.ExitContext(), registerParams)
if err != nil {
if utils.IsUntrustedCertErr(err) {
return nil, trace.WrapWithMessage(err, utils.SelfSignedCertsMsg)
Expand Down
2 changes: 1 addition & 1 deletion lib/tbot/service_bot_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func botIdentityFromToken(ctx context.Context, log logrus.FieldLogger, cfg *conf
}
}

certs, err := auth.Register(params)
certs, err := auth.Register(ctx, params)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down

0 comments on commit 955f86b

Please sign in to comment.