diff --git a/go.mod b/go.mod index 30cc2b99..249583f4 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/kyma-project/warden go 1.21 require ( + github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 github.com/docker/distribution v2.8.1+incompatible github.com/fsnotify/fsnotify v1.6.0 github.com/go-logr/zapr v1.2.4 github.com/google/go-containerregistry v0.12.1 github.com/google/uuid v1.3.0 - github.com/kyma-project/kyma/common/logging v0.0.0-20230202094231-eaaef03503fe github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.1 github.com/theupdateframework/notary v0.7.0 @@ -20,6 +20,7 @@ require ( k8s.io/apiextensions-apiserver v0.27.10 k8s.io/apimachinery v0.27.10 k8s.io/client-go v0.27.10 + k8s.io/klog/v2 v2.90.1 k8s.io/utils v0.0.0-20240102154912-e7106e64919e sigs.k8s.io/controller-runtime v0.15.3 ) @@ -50,6 +51,8 @@ require ( github.com/imdario/mergo v0.3.12 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/miekg/pkcs11 v1.0.2 // indirect @@ -64,6 +67,7 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.42.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/sirupsen/logrus v1.9.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.0 // indirect @@ -81,7 +85,6 @@ require ( gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/component-base v0.27.10 // indirect - k8s.io/klog/v2 v2.90.1 // indirect k8s.io/kube-openapi v0.0.0-20230501164219-8b0f38b5fd1f // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect diff --git a/go.sum b/go.sum index 707e0234..033f6e7c 100644 --- a/go.sum +++ b/go.sum @@ -153,8 +153,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kyma-project/kyma/common/logging v0.0.0-20230202094231-eaaef03503fe h1:DSJ17/jGDSjCDrMY2YLLn5cPBXXchFc+WJqECQ0d3iw= -github.com/kyma-project/kyma/common/logging v0.0.0-20230202094231-eaaef03503fe/go.mod h1:JGb5RBi8Uz+RZ/jf54+qA+RqY6uPQBJ8pO1w3KSwm1Q= github.com/lib/pq v0.0.0-20150723085316-0dad96c0b94f/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/magiconair/properties v1.5.3 h1:C8fxWnhYyME3n0klPOhVM7PtYUB3eV1W3DeFmN3j53Y= github.com/magiconair/properties v1.5.3/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= @@ -196,6 +194,7 @@ github.com/opencontainers/image-spec v1.1.0-rc2 h1:2zx/Stx4Wc5pIPDvIxHXvXtQFW/7X github.com/opencontainers/image-spec v1.1.0-rc2/go.mod h1:3OVijpioIKYWTqjiG0zfF6wvoJ4fAXGbjdZuI2NgsRQ= github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -214,6 +213,7 @@ github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/sirupsen/logrus v1.0.6/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= diff --git a/internal/logging/logger/format.go b/internal/logging/logger/format.go new file mode 100644 index 00000000..032e614e --- /dev/null +++ b/internal/logging/logger/format.go @@ -0,0 +1,44 @@ +package logger + +import ( + "errors" + "fmt" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Format string + +const ( + JSON Format = "json" + TEXT Format = "text" +) + +var allFormats = []Format{JSON, TEXT} + +func MapFormat(input string) (Format, error) { + var format = Format(input) + switch format { + case JSON, TEXT: + return format, nil + default: + return format, fmt.Errorf("given log format: %s, doesn't match with any of %v", format, allFormats) + } +} + +func (f Format) ToZapEncoder() (zapcore.Encoder, error) { + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.EncodeTime = zapcore.RFC3339TimeEncoder + encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder + encoderConfig.TimeKey = "timestamp" + encoderConfig.MessageKey = "message" + switch f { + case JSON: + return zapcore.NewJSONEncoder(encoderConfig), nil + case TEXT: + return zapcore.NewConsoleEncoder(encoderConfig), nil + default: + return nil, errors.New("unknown encoder") + } +} diff --git a/internal/logging/logger/format_test.go b/internal/logging/logger/format_test.go new file mode 100644 index 00000000..685da7f7 --- /dev/null +++ b/internal/logging/logger/format_test.go @@ -0,0 +1,54 @@ +package logger_test + +import ( + "testing" + + "github.com/kyma-project/warden/internal/logging/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatMapping(t *testing.T) { + + testCases := []struct { + name string + input string + expected logger.Format + expectedErr bool + }{ + { + name: "text format", + input: "text", + expected: logger.TEXT, + expectedErr: false, + }, + { + name: "json format", + input: "json", + expected: logger.JSON, + expectedErr: false, + }, + { + name: "not existing format", + input: "csv", + expectedErr: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + //WHEN + + output, err := logger.MapFormat(testCase.input) + + //THEN + if !testCase.expectedErr { + assert.Equal(t, testCase.expected, output) + require.NoError(t, err) + } else { + require.Error(t, err) + } + + }) + } +} diff --git a/internal/logging/logger/level.go b/internal/logging/logger/level.go new file mode 100644 index 00000000..08901383 --- /dev/null +++ b/internal/logging/logger/level.go @@ -0,0 +1,49 @@ +package logger + +import ( + "errors" + "fmt" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Level string + +const ( + DEBUG Level = "debug" + INFO Level = "info" + WARN Level = "warn" + ERROR Level = "error" + FATAL Level = "fatal" +) + +var allLevels = []Level{DEBUG, INFO, WARN, ERROR, FATAL} + +func MapLevel(level string) (Level, error) { + var lvl = Level(level) + + switch lvl { + case DEBUG, INFO, WARN, ERROR, FATAL: + return lvl, nil + default: + return lvl, fmt.Errorf("given log level: %s, doesn't match with any of %v", level, allLevels) + } +} + +func (l Level) ToZapLevel() (zapcore.Level, error) { + switch l { + case DEBUG: + return zap.DebugLevel, nil + case INFO: + return zap.InfoLevel, nil + case WARN: + return zap.WarnLevel, nil + case ERROR: + return zap.ErrorLevel, nil + case FATAL: + return zap.FatalLevel, nil + default: + return zap.DebugLevel, errors.New("unknown level") + } +} diff --git a/internal/logging/logger/level_test.go b/internal/logging/logger/level_test.go new file mode 100644 index 00000000..57c49d33 --- /dev/null +++ b/internal/logging/logger/level_test.go @@ -0,0 +1,66 @@ +package logger_test + +import ( + "testing" + + "github.com/kyma-project/warden/internal/logging/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLevelMapping(t *testing.T) { + + testCases := []struct { + name string + input string + expected logger.Level + expectedErr bool + }{ + { + name: "debug level", + input: "debug", + expected: logger.DEBUG, + expectedErr: false, + }, + { + name: "info level", + input: "info", + expected: logger.INFO, + expectedErr: false, + }, + { + name: "warn level", + input: "warn", + expected: logger.WARN, + expectedErr: false, + }, + { + name: "error level", + input: "error", + expected: logger.ERROR, + expectedErr: false, + }, + { + name: "not existing level", + input: "level", + expectedErr: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + //WHEN + + output, err := logger.MapLevel(testCase.input) + + //THEN + if !testCase.expectedErr { + assert.Equal(t, testCase.expected, output) + require.NoError(t, err) + } else { + require.Error(t, err) + } + + }) + } +} diff --git a/internal/logging/logger/logger.go b/internal/logging/logger/logger.go new file mode 100644 index 00000000..b17e6366 --- /dev/null +++ b/internal/logging/logger/logger.go @@ -0,0 +1,96 @@ +package logger + +import ( + "context" + "os" + + "github.com/pkg/errors" + + "github.com/go-logr/zapr" + "github.com/kyma-project/warden/internal/logging/tracing" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "k8s.io/klog/v2" +) + +type Logger struct { + zapLogger *zap.SugaredLogger +} + +/* +This function creates logger structure based on given format, atomicLevel and additional cores +AtomicLevel structure allows to change level dynamically +*/ +func NewWithAtomicLevel(format Format, atomicLevel zap.AtomicLevel, additionalCores ...zapcore.Core) (*Logger, error) { + return new(format, atomicLevel, additionalCores...) +} + +/* +This function creates logger structure based on given format, level and additional cores +*/ +func New(format Format, level Level, additionalCores ...zapcore.Core) (*Logger, error) { + filterLevel, err := level.ToZapLevel() + if err != nil { + return nil, errors.Wrap(err, "while getting zap log level") + } + + levelEnabler := zap.LevelEnablerFunc(func(incomingLevel zapcore.Level) bool { + return incomingLevel >= filterLevel + }) + + return new(format, levelEnabler, additionalCores...) +} + +func new(format Format, levelEnabler zapcore.LevelEnabler, additionalCores ...zapcore.Core) (*Logger, error) { + encoder, err := format.ToZapEncoder() + if err != nil { + return nil, errors.Wrapf(err, "while getting encoding configuration for %s format", format) + } + + defaultCore := zapcore.NewCore( + encoder, + zapcore.Lock(os.Stderr), + levelEnabler, + ) + cores := append(additionalCores, defaultCore) + return &Logger{zap.New(zapcore.NewTee(cores...), zap.AddCaller()).Sugar()}, nil +} + +func (l *Logger) WithTracing(ctx context.Context) *zap.SugaredLogger { + newLogger := *l + for key, val := range tracing.GetMetadata(ctx) { + newLogger.zapLogger = newLogger.zapLogger.With(key, val) + } + + return newLogger.WithContext() +} + +func (l *Logger) WithContext() *zap.SugaredLogger { + return l.zapLogger.With(zap.Namespace("context")) +} + +/* +By default the Fatal Error log will be in json format, because it's production default. +*/ +func LogFatalError(format string, args ...interface{}) error { + logger, err := New(JSON, ERROR) + if err != nil { + return errors.Wrap(err, "while getting Error Json Logger") + } + logger.zapLogger.Fatalf(format, args...) + return nil +} + +/* +This function initialize klog which is used in k8s/go-client +*/ +func InitKlog(log *Logger, level Level) error { + zaprLogger := zapr.NewLogger(log.WithContext().Desugar()) + lvl, err := level.ToZapLevel() + if err != nil { + return errors.Wrap(err, "while getting zap log level") + } + zaprLogger.V((int)(lvl)) + klog.SetLogger(zaprLogger) + return nil +} diff --git a/internal/logging/logger/logger_test.go b/internal/logging/logger/logger_test.go new file mode 100644 index 00000000..794738d6 --- /dev/null +++ b/internal/logging/logger/logger_test.go @@ -0,0 +1,207 @@ +package logger_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "strings" + "testing" + "time" + + "github.com/kyma-project/warden/internal/logging/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +type logEntry struct { + Context map[string]string `json:"context"` + Msg string `json:"message"` + TraceID string `json:"traceid"` + SpanID string `json:"spanid"` + Timestamp string `json:"timestamp"` + Level string `json:"level"` + Caller string `json:"caller"` +} + +func TestLogger(t *testing.T) { + t.Run("should log anything", func(t *testing.T) { + // given + core, observedLogs := observer.New(zap.DebugLevel) + log, err := logger.New(logger.JSON, logger.DEBUG, core) + require.NoError(t, err) + zapLogger := log.WithContext() + // when + zapLogger.Desugar().WithOptions(zap.AddCaller()) + zapLogger.Debug("something") + + // then + require.NotEqual(t, 0, observedLogs.Len()) + t.Log(observedLogs.All()) + }) + + t.Run("should log debug log after changing atomic level", func(t *testing.T) { + // given + atomic := zap.NewAtomicLevel() + atomic.SetLevel(zapcore.WarnLevel) + core, observedLogs := observer.New(atomic) + log, err := logger.NewWithAtomicLevel(logger.JSON, atomic, core) + require.NoError(t, err) + zapLogger := log.WithContext() + + // when + zapLogger.Info("log anything") + require.Equal(t, 0, observedLogs.Len()) + + atomic.SetLevel(zapcore.InfoLevel) + zapLogger.Info("log anything 2") + + // then + require.Equal(t, 1, observedLogs.Len()) + }) + + t.Run("should log in the right json format", func(t *testing.T) { + // GIVEN + oldStdErr := os.Stderr + defer rollbackStderr(oldStdErr) + r, w, err := os.Pipe() + require.NoError(t, err) + os.Stderr = w + + log, err := logger.New(logger.JSON, logger.DEBUG) + require.NoError(t, err) + + ctx := fixContext(map[string]string{"traceid": "trace", "spanid": "span"}) + // WHEN + log.WithTracing(ctx).With("key", "value").Info("example message") + + // THEN + err = w.Close() + require.NoError(t, err) + var buf bytes.Buffer + _, err = io.Copy(&buf, r) + require.NoError(t, err) + + require.NotEqual(t, 0, buf.Len()) + var entry = logEntry{} + strictEncoder := json.NewDecoder(strings.NewReader(buf.String())) + strictEncoder.DisallowUnknownFields() + err = strictEncoder.Decode(&entry) + require.NoError(t, err) + + assert.Equal(t, "INFO", entry.Level) + assert.Equal(t, "example message", entry.Msg) + assert.Equal(t, "trace", entry.TraceID) + assert.Equal(t, "span", entry.SpanID) + assert.Contains(t, entry.Caller, "logger_test.go") + + assert.NotEmpty(t, entry.Timestamp) + _, err = time.Parse(time.RFC3339, entry.Timestamp) + assert.NoError(t, err) + }) + + t.Run("should log in total separation", func(t *testing.T) { + oldStdErr := os.Stderr + defer rollbackStderr(oldStdErr) + r, w, err := os.Pipe() + require.NoError(t, err) + os.Stderr = w + + log, err := logger.New(logger.JSON, logger.DEBUG) + require.NoError(t, err) + ctx := fixContext(map[string]string{"traceid": "trace", "spanid": "span"}) + + // WHEN + log.WithTracing(ctx).With("key", "first").Info("first message") + log.WithContext().With("key", "second").Error("second message") + + // THEN + err = w.Close() + require.NoError(t, err) + var buf bytes.Buffer + _, err = io.Copy(&buf, r) + require.NoError(t, err) + + require.NotEqual(t, 0, buf.Len()) + + logs := strings.Split(buf.String(), "\n") + + require.Len(t, logs, 3) // 3rd line is new empty line + + var infoEntry = logEntry{} + strictEncoder := json.NewDecoder(strings.NewReader(logs[0])) + strictEncoder.DisallowUnknownFields() + err = strictEncoder.Decode(&infoEntry) + require.NoError(t, err) + + assert.Equal(t, "INFO", infoEntry.Level) + assert.Equal(t, "first message", infoEntry.Msg) + assert.EqualValues(t, map[string]string{"key": "first"}, infoEntry.Context, 0.0) + assert.Equal(t, "span", infoEntry.SpanID) + assert.Equal(t, "trace", infoEntry.TraceID) + + assert.NotEmpty(t, infoEntry.Timestamp) + _, err = time.Parse(time.RFC3339, infoEntry.Timestamp) + assert.NoError(t, err) + + strictEncoder = json.NewDecoder(strings.NewReader(logs[1])) + strictEncoder.DisallowUnknownFields() + + var errorEntry = logEntry{} + err = strictEncoder.Decode(&errorEntry) + require.NoError(t, err) + assert.Equal(t, "ERROR", errorEntry.Level) + assert.Equal(t, "second message", errorEntry.Msg) + assert.EqualValues(t, map[string]string{"key": "second"}, errorEntry.Context, 0.0) + assert.Empty(t, errorEntry.SpanID) + assert.Empty(t, errorEntry.TraceID) + + assert.NotEmpty(t, errorEntry.Timestamp) + _, err = time.Parse(time.RFC3339, errorEntry.Timestamp) + assert.NoError(t, err) + }) + + t.Run("with context should create new logger", func(t *testing.T) { + //GIVEN + log, err := logger.New(logger.TEXT, logger.INFO) + require.NoError(t, err) + //WHEN + firstLogger := log.WithContext() + secondLogger := log.WithContext() + + //THEN + assert.NotSame(t, firstLogger, secondLogger) + }) + + t.Run("with tracing should create new logger", func(t *testing.T) { + //GIVEN + log, err := logger.New(logger.TEXT, logger.INFO) + require.NoError(t, err) + ctx := fixContext(map[string]string{"traceid": "trace", "spanid": "span"}) + + //WHEN + firstLogger := log.WithTracing(ctx) + secondLogger := log.WithTracing(ctx) + + //THEN + assert.NotSame(t, firstLogger, secondLogger) + }) +} + +func fixContext(values map[string]string) context.Context { + ctx := context.TODO() + for k, v := range values { + //nolint:staticcheck + ctx = context.WithValue(ctx, k, v) + } + + return ctx +} + +func rollbackStderr(oldStdErr *os.File) { + os.Stderr = oldStdErr +} diff --git a/internal/logging/setup.go b/internal/logging/setup.go index 8da3191b..823b01c7 100644 --- a/internal/logging/setup.go +++ b/internal/logging/setup.go @@ -1,7 +1,7 @@ package logging import ( - "github.com/kyma-project/kyma/common/logging/logger" + "github.com/kyma-project/warden/internal/logging/logger" "github.com/pkg/errors" "go.uber.org/zap" ) diff --git a/internal/logging/tracing/helper.go b/internal/logging/tracing/helper.go new file mode 100644 index 00000000..a554f218 --- /dev/null +++ b/internal/logging/tracing/helper.go @@ -0,0 +1,17 @@ +package tracing + +import "context" + +func GetMetadata(ctx context.Context) map[string]string { + m := map[string]string{ + TRACE_KEY: UNKNOWN_VALUE, + SPAN_KEY: UNKNOWN_VALUE, + } + if val, ok := ctx.Value(TRACE_KEY).(string); ok { + m[TRACE_KEY] = val + } + if val, ok := ctx.Value(SPAN_KEY).(string); ok { + m[SPAN_KEY] = val + } + return m +} diff --git a/internal/logging/tracing/helper_test.go b/internal/logging/tracing/helper_test.go new file mode 100644 index 00000000..95441a7a --- /dev/null +++ b/internal/logging/tracing/helper_test.go @@ -0,0 +1,46 @@ +package tracing_test + +import ( + "context" + "testing" + + "github.com/kyma-project/warden/internal/logging/tracing" + + "github.com/bmizerany/assert" +) + +func TestGetMetadata(t *testing.T) { + t.Run("context with values", func(t *testing.T) { + //GIVEN + ctx := fixContext(map[string]string{tracing.TRACE_KEY: "mytrace", tracing.SPAN_KEY: "myspan"}) + + //WHEN + out := tracing.GetMetadata(ctx) + + //THEN + assert.Equal(t, "mytrace", out[tracing.TRACE_KEY]) + assert.Equal(t, "myspan", out[tracing.SPAN_KEY]) + }) + + t.Run("context without values", func(t *testing.T) { + ctx := context.TODO() + + //WHEN + out := tracing.GetMetadata(ctx) + + //THEN + assert.Equal(t, tracing.UNKNOWN_VALUE, out[tracing.TRACE_KEY]) + assert.Equal(t, tracing.UNKNOWN_VALUE, out[tracing.SPAN_KEY]) + }) + +} + +func fixContext(values map[string]string) context.Context { + ctx := context.TODO() + for k, v := range values { + //nolint:staticcheck + ctx = context.WithValue(ctx, k, v) + } + + return ctx +} diff --git a/internal/logging/tracing/middleware.go b/internal/logging/tracing/middleware.go new file mode 100644 index 00000000..e00ee4d3 --- /dev/null +++ b/internal/logging/tracing/middleware.go @@ -0,0 +1,42 @@ +package tracing + +import ( + "context" + "net/http" + "strings" +) + +const ( + SPAN_HEADER_KEY = "X-B3-Spanid" + TRACE_HEADER_KEY = "X-B3-Traceid" + TRACE_KEY = "traceid" + SPAN_KEY = "spanid" + UNKNOWN_VALUE = "unknown" +) + +type tracingMiddleware struct { + handler func(w http.ResponseWriter, r *http.Request) +} + +func NewTracingMiddleware(handler func(w http.ResponseWriter, r *http.Request)) http.Handler { + return &tracingMiddleware{ + handler: handler, + } +} + +func (m *tracingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + newCtx := addHeaderToCtx(r.Context(), r.Header, TRACE_HEADER_KEY, TRACE_KEY) + newCtx = addHeaderToCtx(newCtx, r.Header, SPAN_HEADER_KEY, SPAN_KEY) + + m.handler(w, r.WithContext(newCtx)) +} + +func addHeaderToCtx(ctx context.Context, headers http.Header, key string, desiredKey string) context.Context { + header, ok := headers[key] + if !ok { + return ctx + } + value := strings.Join(header, ";") + //nolint:staticcheck + return context.WithValue(ctx, desiredKey, value) +} diff --git a/internal/logging/tracing/middleware_test.go b/internal/logging/tracing/middleware_test.go new file mode 100644 index 00000000..ffacf6ff --- /dev/null +++ b/internal/logging/tracing/middleware_test.go @@ -0,0 +1,55 @@ +package tracing_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/kyma-project/warden/internal/logging/tracing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMiddleware(t *testing.T) { + t.Run("with trace and span in header, should put traceid and spanid to context", func(t *testing.T) { + //GIVEN + var outRequest *http.Request + middleware := tracing.NewTracingMiddleware(func(w http.ResponseWriter, r *http.Request) { + outRequest = r + }) + resp := httptest.NewRecorder() + + r, err := http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + r.Header[tracing.TRACE_HEADER_KEY] = []string{"mytrace"} + r.Header[tracing.SPAN_HEADER_KEY] = []string{"myspan"} + + //WHEN + middleware.ServeHTTP(resp, r) + + //THEN + ctx := outRequest.Context() + assert.Equal(t, "myspan", ctx.Value(tracing.SPAN_KEY)) + assert.Equal(t, "mytrace", ctx.Value(tracing.TRACE_KEY)) + }) + + t.Run("wihtout trace and span should not change the context", func(t *testing.T) { + //GIVEN + var enhancedRequest *http.Request + middleware := tracing.NewTracingMiddleware(func(w http.ResponseWriter, r *http.Request) { + enhancedRequest = r + }) + resp := httptest.NewRecorder() + + r, err := http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + //WHEN + middleware.ServeHTTP(resp, r) + + //THEN + ctx := enhancedRequest.Context() + assert.Equal(t, ctx, r.Context()) + }) +}