From 59a4ff93114c005ae0e310fa6de5a30e184654b2 Mon Sep 17 00:00:00 2001 From: Bob Stasyszyn Date: Fri, 1 Nov 2024 14:03:44 -0400 Subject: [PATCH] feat: Add function to set correlation ID Signed-off-by: Bob Stasyszyn --- pkg/otel/correlationid/correlationid.go | 86 +++++++++++++++++ .../correlationid_test.go} | 35 +++++-- .../correlationidtransport.go | 92 ------------------- 3 files changed, 112 insertions(+), 101 deletions(-) create mode 100644 pkg/otel/correlationid/correlationid.go rename pkg/otel/{correlationidtransport/correlationidtransport_test.go => correlationid/correlationid_test.go} (58%) delete mode 100644 pkg/otel/correlationidtransport/correlationidtransport.go diff --git a/pkg/otel/correlationid/correlationid.go b/pkg/otel/correlationid/correlationid.go new file mode 100644 index 0000000..60e5185 --- /dev/null +++ b/pkg/otel/correlationid/correlationid.go @@ -0,0 +1,86 @@ +/* +Copyright Gen Digital Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package correlationid + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "strings" + + "go.opentelemetry.io/otel/trace" + + "github.com/trustbloc/logutil-go/pkg/otel/api" +) + +const ( + nilTraceID = "00000000000000000000000000000000" + correlationIDLength = 8 +) + +type contextKey struct{} + +// Set derives the correlation ID from the OpenTelemetry trace ID and sets it on the returned context. +// If no trace ID is available, a random correlation ID is generated. +func Set(ctx context.Context) (context.Context, string, error) { + var correlationID string + + traceID := trace.SpanFromContext(ctx).SpanContext().TraceID().String() + if traceID != "" && traceID != nilTraceID { + correlationID = deriveID(traceID) + } else { + var err error + correlationID, err = generateID() + if err != nil { + return nil, "", fmt.Errorf("generate correlation ID: %w", err) + } + } + + return context.WithValue(ctx, contextKey{}, correlationID), correlationID, nil +} + +// Transport is an HTTP RoundTripper that adds a correlation ID to the request header. +type Transport struct { + defaultTransport http.RoundTripper +} + +// NewHTTPTransport creates a new HTTP Transport. +func NewHTTPTransport(defaultTransport http.RoundTripper) *Transport { + return &Transport{ + defaultTransport: defaultTransport, + } +} + +// RoundTrip executes a single HTTP transaction. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + correlationID, ok := req.Context().Value(contextKey{}).(string) + if ok { + req = req.Clone(req.Context()) + req.Header.Add(api.CorrelationIDHeader, correlationID) + } + + return t.defaultTransport.RoundTrip(req) +} + +func generateID() (string, error) { + bytes := make([]byte, correlationIDLength/2) //nolint:gomnd + + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + return strings.ToUpper(hex.EncodeToString(bytes)), nil +} + +func deriveID(id string) string { + hash := sha256.Sum256([]byte(id)) + + return strings.ToUpper(hex.EncodeToString(hash[:correlationIDLength/2])) //nolint:gomnd +} diff --git a/pkg/otel/correlationidtransport/correlationidtransport_test.go b/pkg/otel/correlationid/correlationid_test.go similarity index 58% rename from pkg/otel/correlationidtransport/correlationidtransport_test.go rename to pkg/otel/correlationid/correlationid_test.go index 2aa06f7..f935668 100644 --- a/pkg/otel/correlationidtransport/correlationidtransport_test.go +++ b/pkg/otel/correlationid/correlationid_test.go @@ -4,7 +4,7 @@ Copyright Gen Digital Inc. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 */ -package correlationidtransport +package correlationid import ( "context" @@ -19,17 +19,20 @@ import ( ) func TestTransport_RoundTrip(t *testing.T) { - var rt mockRoundTripperFunc = func(req *http.Request) (*http.Response, error) { - correlationID := req.Header.Get(api.CorrelationIDHeader) + t.Run("No span", func(t *testing.T) { + var rt mockRoundTripperFunc = func(req *http.Request) (*http.Response, error) { + require.Len(t, req.Header.Get(api.CorrelationIDHeader), 8) - require.Len(t, correlationID, 8) - return &http.Response{}, nil - } + return &http.Response{}, nil + } - transport := New(rt, WithCorrelationIDLength(8)) + transport := NewHTTPTransport(rt) - t.Run("No span", func(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil) + ctx, correlationID, err := Set(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, correlationID) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil) require.NoError(t, err) resp, err := transport.RoundTrip(req) @@ -38,6 +41,16 @@ func TestTransport_RoundTrip(t *testing.T) { }) t.Run("With span", func(t *testing.T) { + var correlationID string + + var rt mockRoundTripperFunc = func(req *http.Request) (*http.Response, error) { + require.Equal(t, correlationID, req.Header.Get(api.CorrelationIDHeader)) + + return &http.Response{}, nil + } + + transport := NewHTTPTransport(rt) + tp := trace.NewTracerProvider() otel.SetTracerProvider(tp) @@ -45,6 +58,10 @@ func TestTransport_RoundTrip(t *testing.T) { ctx, span := tp.Tracer("test").Start(context.Background(), "test") require.NotNil(t, span) + var err error + ctx, correlationID, err = Set(ctx) + require.NoError(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil) require.NoError(t, err) diff --git a/pkg/otel/correlationidtransport/correlationidtransport.go b/pkg/otel/correlationidtransport/correlationidtransport.go deleted file mode 100644 index 248f5b2..0000000 --- a/pkg/otel/correlationidtransport/correlationidtransport.go +++ /dev/null @@ -1,92 +0,0 @@ -/* -Copyright Gen Digital Inc. All Rights Reserved. - -SPDX-License-Identifier: Apache-2.0 -*/ - -package correlationidtransport - -import ( - "crypto/rand" - "crypto/sha256" - "encoding/hex" - "fmt" - "net/http" - "strings" - - "go.opentelemetry.io/otel/trace" - - "github.com/trustbloc/logutil-go/pkg/otel/api" -) - -const ( - nilTraceID = "00000000000000000000000000000000" - defaultCorrelationIDLength = 8 -) - -// Transport is an http.RoundTripper that adds a correlation ID to the request. -type Transport struct { - defaultTransport http.RoundTripper - correlationIDLength int -} - -type Opt func(*Transport) - -// WithCorrelationIDLength sets the length of the correlation ID. -func WithCorrelationIDLength(length int) Opt { - return func(t *Transport) { - t.correlationIDLength = length - } -} - -// New creates a new Transport. -func New(defaultTransport http.RoundTripper, opts ...Opt) *Transport { - t := &Transport{ - defaultTransport: defaultTransport, - correlationIDLength: defaultCorrelationIDLength, - } - - for _, opt := range opts { - opt(t) - } - - return t -} - -// RoundTrip executes a single HTTP transaction. -func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - var correlationID string - - span := trace.SpanFromContext(req.Context()) - - traceID := span.SpanContext().TraceID().String() - if traceID == "" || traceID == nilTraceID { - var err error - correlationID, err = t.generateID() - if err != nil { - return nil, fmt.Errorf("generate correlation ID: %w", err) - } - } else { - correlationID = t.shortenID(traceID) - } - - clonedReq := req.Clone(req.Context()) - clonedReq.Header.Add(api.CorrelationIDHeader, correlationID) - - return t.defaultTransport.RoundTrip(clonedReq) -} - -func (t *Transport) generateID() (string, error) { - bytes := make([]byte, t.correlationIDLength/2) //nolint:gomnd - - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - return strings.ToUpper(hex.EncodeToString(bytes)), nil -} - -func (t *Transport) shortenID(id string) string { - hash := sha256.Sum256([]byte(id)) - return strings.ToUpper(hex.EncodeToString(hash[:t.correlationIDLength/2])) //nolint:gomnd -}