From 573f147c6eded960da3f8d1c7d45751d4799d2f0 Mon Sep 17 00:00:00 2001 From: Shulin Jia Date: Thu, 6 Jun 2024 23:07:38 -0700 Subject: [PATCH] remove 1 level of recursion from the context metadata --- pkg/utils/gwlog/gwlog.go | 24 ++++++++++++------------ pkg/utils/gwlog/metadata.go | 27 +++++++++++++++------------ pkg/utils/gwlog/metadata_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 24 deletions(-) create mode 100644 pkg/utils/gwlog/metadata_test.go diff --git a/pkg/utils/gwlog/gwlog.go b/pkg/utils/gwlog/gwlog.go index 4e1146a8..9e500793 100644 --- a/pkg/utils/gwlog/gwlog.go +++ b/pkg/utils/gwlog/gwlog.go @@ -20,14 +20,14 @@ func (s *TracedLogger) Infoln(args ...interface{}) { func (t *TracedLogger) Infow(ctx context.Context, msg string, keysAndValues ...interface{}) { if GetTrace(ctx) != "" { - keysAndValues = append(keysAndValues, string(traceID), GetTrace(ctx)) + keysAndValues = append(keysAndValues, traceID, GetTrace(ctx)) } t.InnerLogger.Infow(msg, keysAndValues...) } func (t *TracedLogger) Infof(ctx context.Context, template string, args ...interface{}) { if GetTrace(ctx) != "" { - t.InnerLogger.Infow(fmt.Sprintf(template, args...), string(traceID), GetTrace(ctx)) + t.InnerLogger.Infow(fmt.Sprintf(template, args...), traceID, GetTrace(ctx)) return } t.InnerLogger.Infof(template, args...) @@ -35,7 +35,7 @@ func (t *TracedLogger) Infof(ctx context.Context, template string, args ...inter func (t *TracedLogger) Info(ctx context.Context, msg string) { if GetTrace(ctx) != "" { - t.InnerLogger.Infow(msg, string(traceID), GetTrace(ctx)) + t.InnerLogger.Infow(msg, traceID, GetTrace(ctx)) return } t.InnerLogger.Info(msg) @@ -43,14 +43,14 @@ func (t *TracedLogger) Info(ctx context.Context, msg string) { func (t *TracedLogger) Errorw(ctx context.Context, msg string, keysAndValues ...interface{}) { if GetTrace(ctx) != "" { - keysAndValues = append(keysAndValues, string(traceID), GetTrace(ctx)) + keysAndValues = append(keysAndValues, traceID, GetTrace(ctx)) } t.InnerLogger.Errorw(msg, keysAndValues) } func (t *TracedLogger) Errorf(ctx context.Context, template string, args ...interface{}) { if GetTrace(ctx) != "" { - t.InnerLogger.Errorw(fmt.Sprintf(template, args...), string(traceID), GetTrace(ctx)) + t.InnerLogger.Errorw(fmt.Sprintf(template, args...), traceID, GetTrace(ctx)) return } t.InnerLogger.Errorf(template, args...) @@ -58,7 +58,7 @@ func (t *TracedLogger) Errorf(ctx context.Context, template string, args ...inte func (t *TracedLogger) Error(ctx context.Context, msg string) { if GetTrace(ctx) != "" { - t.InnerLogger.Errorw(msg, string(traceID), GetTrace(ctx)) + t.InnerLogger.Errorw(msg, traceID, GetTrace(ctx)) return } t.InnerLogger.Error(msg) @@ -66,14 +66,14 @@ func (t *TracedLogger) Error(ctx context.Context, msg string) { func (t *TracedLogger) Debugw(ctx context.Context, msg string, keysAndValues ...interface{}) { if GetTrace(ctx) != "" { - keysAndValues = append(keysAndValues, string(traceID), GetTrace(ctx)) + keysAndValues = append(keysAndValues, traceID, GetTrace(ctx)) } t.InnerLogger.Debugw(msg, keysAndValues...) } func (t *TracedLogger) Debugf(ctx context.Context, template string, args ...interface{}) { if GetTrace(ctx) != "" { - t.InnerLogger.Debugw(fmt.Sprintf(template, args...), string(traceID), GetTrace(ctx)) + t.InnerLogger.Debugw(fmt.Sprintf(template, args...), traceID, GetTrace(ctx)) return } t.InnerLogger.Debugf(template, args...) @@ -81,7 +81,7 @@ func (t *TracedLogger) Debugf(ctx context.Context, template string, args ...inte func (t *TracedLogger) Debug(ctx context.Context, msg string) { if GetTrace(ctx) != "" { - t.InnerLogger.Debugw(msg, string(traceID), GetTrace(ctx)) + t.InnerLogger.Debugw(msg, traceID, GetTrace(ctx)) return } t.InnerLogger.Debug(msg) @@ -89,14 +89,14 @@ func (t *TracedLogger) Debug(ctx context.Context, msg string) { func (t *TracedLogger) Warnw(ctx context.Context, msg string, keysAndValues ...interface{}) { if GetTrace(ctx) != "" { - keysAndValues = append(keysAndValues, string(traceID), GetTrace(ctx)) + keysAndValues = append(keysAndValues, traceID, GetTrace(ctx)) } t.InnerLogger.Warnw(msg, keysAndValues...) } func (t *TracedLogger) Warnf(ctx context.Context, template string, args ...interface{}) { if GetTrace(ctx) != "" { - t.InnerLogger.Warnw(fmt.Sprintf(template, args...), string(traceID), GetTrace(ctx)) + t.InnerLogger.Warnw(fmt.Sprintf(template, args...), traceID, GetTrace(ctx)) return } t.InnerLogger.Warnf(template, args...) @@ -104,7 +104,7 @@ func (t *TracedLogger) Warnf(ctx context.Context, template string, args ...inter func (t *TracedLogger) Warn(ctx context.Context, msg string) { if GetTrace(ctx) != "" { - t.InnerLogger.Warnw(msg, string(traceID), GetTrace(ctx)) + t.InnerLogger.Warnw(msg, traceID, GetTrace(ctx)) return } t.InnerLogger.Warn(msg) diff --git a/pkg/utils/gwlog/metadata.go b/pkg/utils/gwlog/metadata.go index fb2dba46..5798a868 100644 --- a/pkg/utils/gwlog/metadata.go +++ b/pkg/utils/gwlog/metadata.go @@ -7,9 +7,10 @@ import ( type key string -const traceID key = "trace_id" const metadata key = "metadata" +const traceID string = "trace_id" + type metadataValue struct { m map[string]string } @@ -30,7 +31,11 @@ func newMetadata() *metadataValue { func NewTrace(ctx context.Context) context.Context { currID := uuid.New() - return context.WithValue(context.WithValue(ctx, traceID, currID.String()), metadata, newMetadata()) + + newCtx := context.WithValue(ctx, metadata, newMetadata()) + AddMetadata(newCtx, traceID, currID.String()) + + return newCtx } func AddMetadata(ctx context.Context, key, value string) { @@ -41,12 +46,7 @@ func AddMetadata(ctx context.Context, key, value string) { func GetMetadata(ctx context.Context) []interface{} { var fields []interface{} - /* - if ctx.Value(traceID) != nil { - fields = append(fields, string(traceID)) - fields = append(fields, ctx.Value(traceID)) - } - */ + if ctx.Value(metadata) != nil { for k, v := range ctx.Value(metadata).(*metadataValue).m { fields = append(fields, k) @@ -57,9 +57,12 @@ func GetMetadata(ctx context.Context) []interface{} { } func GetTrace(ctx context.Context) string { - t := ctx.Value(traceID) - if t == nil { - return "" + if ctx.Value(metadata) != nil { + m := ctx.Value(metadata).(*metadataValue).m + if m == nil { + return "" + } + return ctx.Value(metadata).(*metadataValue).m[traceID] } - return t.(string) + return "" } diff --git a/pkg/utils/gwlog/metadata_test.go b/pkg/utils/gwlog/metadata_test.go new file mode 100644 index 00000000..7d33ae49 --- /dev/null +++ b/pkg/utils/gwlog/metadata_test.go @@ -0,0 +1,32 @@ +package gwlog + +import ( + "context" + "fmt" + "testing" +) + +func TestGetTrace(t *testing.T) { + if GetTrace(context.TODO()) != "" { + t.Errorf("expected context with no trace_id to return empty string") + } + + if GetTrace(NewTrace(context.TODO())) == "" { + t.Errorf("expected context with trace_id to return non-empty string") + } +} + +func TestMetadata(t *testing.T) { + ctx := NewTrace(context.TODO()) + AddMetadata(ctx, "foo", "bar") + + md := GetMetadata(ctx) + mdMap := map[string]bool{} + for _, m := range md { + mdMap[fmt.Sprint(m)] = true + } + + if !mdMap["foo"] || !mdMap["bar"] { + t.Errorf("expected context to have metadata with key foo and val bar, got %s", md) + } +}