Skip to content

Commit

Permalink
Move config to middleware
Browse files Browse the repository at this point in the history
Signed-off-by: Dipack Panjabi <[email protected]>
  • Loading branch information
dipack95 committed Feb 23, 2025
1 parent 81cdf49 commit 2e074aa
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
34 changes: 0 additions & 34 deletions pkg/log/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package log

import (
"fmt"
"net/http"

"github.com/spf13/pflag"
)
Expand Down Expand Up @@ -60,49 +59,16 @@ type AccessLogHeaderConfig struct {
// Log only these headers.
// You can only define one of Allowlist or Blocklist.
Allowlist []string `json:"allowlist" yaml:"allowlist"`

// In-memory map that is composed right after arguments have
// been validated. This is used to speed up runtime lookups
// for headers.
allowList map[string]string
}

func (c *AccessLogHeaderConfig) Validate() error {
if len(c.Allowlist) > 0 && len(c.Blocklist) > 0 {
return fmt.Errorf("cannot define both allowlist and blocklist")
}

// Create the allow-list header map used to filter headers at runtime.
if len(c.Allowlist) > 0 {
c.allowList = make(map[string]string)
for _, el := range c.Allowlist {
c.allowList[el] = el
}
}
return nil
}

func (c *AccessLogHeaderConfig) Filter(h http.Header) http.Header {
if len(c.Allowlist) > 0 {
for name := range h {
// Use the map created during validation to hasten lookups.
if _, ok := c.allowList[name]; !ok {
h.Del(name)
}
}
return h
}

if len(c.Blocklist) > 0 {
for _, blocked := range c.Blocklist {
h.Del(blocked)
}
return h
}

return h
}

func (c *AccessLogHeaderConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) {
fs.StringSliceVar(
&c.Allowlist,
Expand Down
75 changes: 67 additions & 8 deletions pkg/middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"net/http"
"net/textproto"
"strings"
"time"

Expand All @@ -22,9 +23,21 @@ type loggedRequest struct {
Duration string `json:"duration"`
}

type logHeaderFilter struct {
allowList map[string]string
blockList map[string]string
}

type loggerConfig struct {
RequestHeader logHeaderFilter
ResponseHeader logHeaderFilter
}

// NewLogger creates logging middleware that logs every request.
func NewLogger(accessLogConfig log.AccessLogConfig, logger log.Logger) gin.HandlerFunc {
logger = logger.WithSubsystem(logger.Subsystem() + ".access")
func NewLogger(config log.AccessLogConfig, l log.Logger) gin.HandlerFunc {
l = l.WithSubsystem(l.Subsystem() + ".access")

lc := NewLoggerConfig(config)
return func(c *gin.Context) {
s := time.Now()

Expand All @@ -35,8 +48,8 @@ func NewLogger(accessLogConfig log.AccessLogConfig, logger log.Logger) gin.Handl
return
}

requestHeaders := accessLogConfig.RequestHeaders.Filter(c.Request.Header)
responseHeaders := accessLogConfig.ResponseHeaders.Filter(c.Writer.Header())
requestHeaders := lc.RequestHeader.Filter(c.Request.Header)
responseHeaders := lc.ResponseHeader.Filter(c.Writer.Header())

req := &loggedRequest{
Proto: c.Request.Proto,
Expand All @@ -49,11 +62,57 @@ func NewLogger(accessLogConfig log.AccessLogConfig, logger log.Logger) gin.Handl
Duration: time.Since(s).String(),
}
if c.Writer.Status() >= http.StatusInternalServerError {
logger.Warn("request", zap.Any("request", req))
} else if accessLogConfig.Disable {
logger.Debug("request", zap.Any("request", req))
l.Warn("request", zap.Any("request", req))
} else if config.Disable {
l.Debug("request", zap.Any("request", req))
} else {
logger.Info("request", zap.Any("request", req))
l.Info("request", zap.Any("request", req))
}
}
}

func (l *logHeaderFilter) New(allowList []string, blockList []string) {
if len(allowList) > 0 {
l.allowList = make(map[string]string)
for _, el := range allowList {
h := textproto.CanonicalMIMEHeaderKey(el)
l.allowList[h] = h
}
}

if len(blockList) > 0 {
l.blockList = make(map[string]string)
for _, el := range blockList {
h := textproto.CanonicalMIMEHeaderKey(el)
l.blockList[h] = h
}
}
}

func (l *logHeaderFilter) Filter(h http.Header) http.Header {
if len(l.allowList) > 0 {
for name := range h {
// Use the map created during validation to hasten lookups.
if _, ok := l.allowList[name]; !ok {
h.Del(name)
}
}
return h
}

if len(l.blockList) > 0 {
for _, blocked := range l.blockList {
h.Del(blocked)
}
return h
}

return h
}

func NewLoggerConfig(c log.AccessLogConfig) loggerConfig {

Check warning on line 113 in pkg/middleware/logger.go

View workflow job for this annotation

GitHub Actions / lint

unexported-return: exported func NewLoggerConfig returns unexported type middleware.loggerConfig, which can be annoying to use (revive)
l := loggerConfig{}
l.RequestHeader.New(c.RequestHeaders.Allowlist, c.RequestHeaders.Blocklist)
l.ResponseHeader.New(c.ResponseHeaders.Allowlist, c.ResponseHeaders.Blocklist)
return l
}

0 comments on commit 2e074aa

Please sign in to comment.