Skip to content

Commit

Permalink
Better middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
kalverra committed Jan 23, 2025
1 parent b523706 commit 42dda5f
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 78 deletions.
17 changes: 10 additions & 7 deletions parrot/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import (
)

var (
ErrNilRoute = errors.New("route is nil")
ErrNoMethod = errors.New("no method specified")
ErrInvalidPath = errors.New("invalid path")
ErrNoResponse = errors.New("route must have a handler or some response")
ErrOnlyOneResponse = errors.New("route can only have one response type")
ErrResponseMarshal = errors.New("unable to marshal response body to JSON")
ErrRouteNotFound = errors.New("route not found")
ErrNilRoute = errors.New("route is nil")
ErrNoMethod = errors.New("no method specified")
ErrInvalidPath = errors.New("invalid path")
ErrNoResponse = errors.New("route must have a handler or some response")
ErrOnlyOneResponse = errors.New("route can only have one response type")
ErrResponseMarshal = errors.New("unable to marshal response body to JSON")
ErrRouteNotFound = errors.New("route not found")

ErrNoRecorderURL = errors.New("no recorder URL specified")
ErrNilRecorder = errors.New("recorder is nil")
ErrRecorderNotFound = errors.New("recorder not found")
)

Expand Down
2 changes: 2 additions & 0 deletions parrot/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.23.4

require (
github.com/go-resty/resty/v2 v2.16.3
github.com/google/uuid v1.6.0
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
Expand All @@ -15,6 +16,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect
Expand Down
3 changes: 3 additions & 0 deletions parrot/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/go-resty/resty/v2 v2.16.3 h1:zacNT7lt4b8M/io2Ahj6yPypL7bqx9n1iprfQuodV+E=
github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
Expand All @@ -15,6 +17,7 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
Expand Down
160 changes: 108 additions & 52 deletions parrot/parrot.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (
"time"

"github.com/go-resty/resty/v2"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
)

// Route holds information about the mock route configuration
Expand Down Expand Up @@ -173,7 +175,7 @@ func Wake(options ...ServerOption) (*Server, error) {
if p.jsonLogs {
writers = append(writers, os.Stderr)
} else {
consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339Nano}
consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02T15:04:05.000"}
writers = append(writers, consoleOut)
}

Expand Down Expand Up @@ -208,7 +210,7 @@ func Wake(options ...ServerOption) (*Server, error) {
p.server = &http.Server{
ReadHeaderTimeout: 5 * time.Second,
Addr: listener.Addr().String(),
Handler: mux,
Handler: p.loggingMiddleware(mux),
}

if err = p.load(); err != nil {
Expand Down Expand Up @@ -290,69 +292,71 @@ func (p *Server) Register(route *Route) error {
p.routesMu.Lock()
defer p.routesMu.Unlock()
p.routes[route.ID()] = route
p.log.Info().
Str("Route ID", route.ID()).
Str("Path", route.Path).
Str("Method", route.Method).
Msg("Route registered")

return nil
}

// registerRouteHandler handles the dynamic route registration.
func (p *Server) registerRouteHandler(w http.ResponseWriter, r *http.Request) {
const parrotPath = "/register"
registerLogger := zerolog.Ctx(r.Context())
if r.Method == http.MethodDelete {
var routeRequest *RouteRequest
if err := json.NewDecoder(r.Body).Decode(&routeRequest); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
registerLogger.Debug().Err(err).Msg("Failed to decode request body")
return
}
defer r.Body.Close()

if routeRequest.ID == "" {
err := errors.New("ID required")
http.Error(w, err.Error(), http.StatusBadRequest)
http.Error(w, "Route ID required", http.StatusBadRequest)
registerLogger.Debug().Msg("No Route ID provided")
return
}

err := p.Unregister(routeRequest.ID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
p.log.Trace().Err(err).Str("Path", parrotPath).Msg("Failed to unregister route")
registerLogger.Debug().Err(err).Msg("Failed to unregister route")
return
}

w.WriteHeader(http.StatusNoContent)
p.log.Info().
registerLogger.Info().
Str("Route ID", routeRequest.ID).
Str("Parrot Path", parrotPath).
Msg("Route unregistered")
} else if r.Method == http.MethodPost {
var route *Route
if err := json.NewDecoder(r.Body).Decode(&route); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
registerLogger.Debug().Err(err).Msg("Failed to decode request body")
return
}
defer r.Body.Close()

if route.Method == "" || route.Path == "" {
err := errors.New("Method and path are required")
http.Error(w, err.Error(), http.StatusBadRequest)
registerLogger.Debug().Err(err).Msg("Method and path are required")
return
}

err := p.Register(route)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
p.log.Trace().Err(err).Msg("Failed to register route")
registerLogger.Debug().Err(err).Msg("Failed to register route")
return
}

w.WriteHeader(http.StatusCreated)
p.log.Info().
Str("Parrot Path", parrotPath).
Str("Route Path", route.Path).
Str("Method", route.Method).
Msg("Route registered")
} else {
http.Error(w, "Invalid method, only use POST or DELETE", http.StatusMethodNotAllowed)
p.log.Trace().Str("Method", r.Method).Msg("Invalid method")
registerLogger.Debug().Msg("Invalid method")
return
}
}
Expand All @@ -361,34 +365,44 @@ func (p *Server) registerRouteHandler(w http.ResponseWriter, r *http.Request) {
func (p *Server) Record(recorder *Recorder) error {
p.recordersMu.Lock()
defer p.recordersMu.Unlock()
if recorder == nil {
return ErrNilRecorder
}
if recorder.URL == "" {
return ErrNoRecorderURL
}
_, err := url.Parse(recorder.URL)
if err != nil {
return fmt.Errorf("failed to parse recorder URL: %w", err)
}
p.recorderHooks = append(p.recorderHooks, recorder.URL)
return nil
}

func (p *Server) recordHandler(w http.ResponseWriter, r *http.Request) {
const parrotPath = "/record"
recordLogger := zerolog.Ctx(r.Context())
if r.Method != http.MethodPost {
http.Error(w, "Invalid method, only use POST or DELETE", http.StatusMethodNotAllowed)
p.log.Trace().Str("Method", r.Method).Msg("Invalid method")
recordLogger.Debug().Msg("Invalid method")
return
}

var recorder *Recorder
if err := json.NewDecoder(r.Body).Decode(&recorder); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
p.log.Trace().Err(err).Str("Parrot Path", parrotPath).Msg("Failed to decode request body")
recordLogger.Err(err).Msg("Failed to decode request body")
return
}

err := p.Record(recorder)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
p.log.Trace().Err(err).Str("Parrot Path", parrotPath).Msg("Failed to add recorder")
recordLogger.Debug().Err(err).Msg("Failed to add recorder")
return
}

w.WriteHeader(http.StatusCreated)
p.log.Info().Str("Recorder URL", recorder.URL).Str("Parrot Path", parrotPath).Msg("Recorder added")
recordLogger.Info().Str("Recorder URL", recorder.URL).Msg("Recorder added")
}

// Unregister removes a route from the parrot
Expand Down Expand Up @@ -423,78 +437,98 @@ func (p *Server) dynamicHandler(w http.ResponseWriter, r *http.Request) {
route, exists := p.routes[r.Method+":"+r.URL.Path]
p.routesMu.RUnlock()

dynamicLogger := zerolog.Ctx(r.Context())
if !exists {
http.NotFound(w, r)
dynamicLogger.Debug().Msg("Route not found")
return
}

requestID := uuid.New().String()[0:8]
dynamicLogger.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("Request ID", requestID).Str("Route ID", route.ID())
})

requestBody, err := io.ReadAll(r.Body)
if err != nil {
dynamicLogger.Debug().
Err(err).
Msg("Failed to read request body")
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}

routeCall := &RouteCall{
RouteID: r.Method + ":" + r.URL.Path,
Request: r,
Request: &RouteCallRequest{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Body: requestBody,
},
}
recordingWriter := newResponseWriterRecorder(w)

defer func() {
routeCall.Response = recordingWriter.Result()
res := recordingWriter.Result()
resBody, err := io.ReadAll(res.Body)
if err != nil {
dynamicLogger.Debug().Err(err).Msg("Failed to read response body")
http.Error(w, "Failed to read response body", http.StatusInternalServerError)
return
}

routeCall.Response = &RouteCallResponse{
StatusCode: res.StatusCode,
Header: res.Header,
Body: resBody,
}
p.sendToRecorders(routeCall)
}()

if !exists { // Route not found
http.NotFound(recordingWriter, r)
p.log.Trace().Str("Remote Addr", r.RemoteAddr).Str("Path", r.URL.Path).Str("Method", r.Method).Msg("Route not found")
return
}

// Let the custom handler take over if it exists
if route.Handler != nil {
p.log.Trace().Str("Remote Addr", r.RemoteAddr).Str("Path", r.URL.Path).Str("Method", r.Method).Msg("Calling route handler")
dynamicLogger.Debug().Msg("Calling route handler")
route.Handler(recordingWriter, r)
return
}

recordingWriter.WriteHeader(route.ResponseStatusCode)

if route.RawResponseBody != "" {
if _, err := w.Write([]byte(route.RawResponseBody)); err != nil {
p.log.Trace().Err(err).Str("Remote Addr", r.RemoteAddr).Str("Path", r.URL.Path).Str("Method", r.Method).Msg("Failed to write response")
dynamicLogger.Debug().Err(err).Msg("Failed to write response")
http.Error(recordingWriter, "Failed to write response", http.StatusInternalServerError)
return
}
p.log.Trace().
Str("Remote Addr", r.RemoteAddr).
dynamicLogger.Debug().
Str("Response", route.RawResponseBody).
Str("Path", r.URL.Path).
Str("Method", r.Method).
Msg("Returned raw response")
recordingWriter.WriteHeader(route.ResponseStatusCode)
return
}

if route.ResponseBody != nil {
rawJSON, err := json.Marshal(route.ResponseBody)
if err != nil {
p.log.Trace().Err(err).
Str("Remote Addr", r.RemoteAddr).
Str("Path", r.URL.Path).
Str("Method", r.Method).
Msg("Failed to marshal JSON response")
dynamicLogger.Debug().Err(err).Msg("Failed to marshal JSON response")
http.Error(recordingWriter, "Failed to marshal response into json", http.StatusInternalServerError)
return
}
if _, err = w.Write(rawJSON); err != nil {
p.log.Trace().Err(err).
dynamicLogger.Debug().Err(err).
RawJSON("Response", rawJSON).
Str("Remote Addr", r.RemoteAddr).
Str("Path", r.URL.Path).
Str("Method", r.Method).
Msg("Failed to write response")
http.Error(recordingWriter, "Failed to write JSON response", http.StatusInternalServerError)
return
}
p.log.Trace().
Str("Remote Addr", r.RemoteAddr).
dynamicLogger.Debug().
RawJSON("Response", rawJSON).
Str("Path", r.URL.Path).
Str("Method", r.Method).
Msg("Returned JSON response")
recordingWriter.WriteHeader(route.ResponseStatusCode)
return
}

p.log.Error().Str("Remote Addr", r.RemoteAddr).Str("Path", r.URL.Path).Str("Method", r.Method).Msg("Route has no response")
dynamicLogger.Error().Msg("Route has no response")
http.Error(recordingWriter, "Route has no response", http.StatusInternalServerError)
}

// load loads all registered routes from a file.
Expand Down Expand Up @@ -554,12 +588,15 @@ func (p *Server) save() error {
func (p *Server) sendToRecorders(routeCall *RouteCall) {
p.recordersMu.RLock()
defer p.recordersMu.RUnlock()
if len(p.recorderHooks) == 0 {
return
}

client := resty.New()
p.log.Trace().Strs("Recorders", p.recorderHooks).Str("Route ID", routeCall.RouteID).Msg("Sending route call to recorders")

for _, hook := range p.recorderHooks {
go func(hook string) {
p.log.Trace().Str("Recorder Hook", hook).Msg("Sending route call to recorder")
resp, err := client.R().SetBody(routeCall).Post(hook)
if err != nil {
p.log.Error().Err(err).Str("Recorder Hook", hook).Msg("Failed to send route call to recorder")
Expand All @@ -573,11 +610,30 @@ func (p *Server) sendToRecorders(routeCall *RouteCall) {
Msg("Failed to send route call to recorder")
return
}
p.log.Debug().Str("Recorder Hook", hook).Msg("Route call sent to recorder")
p.log.Trace().Str("Route ID", routeCall.RouteID).Str("Recorder Hook", hook).Msg("Route call sent to recorder")
}(hook)
}
}

func (p *Server) loggingMiddleware(next http.Handler) http.Handler {
h := hlog.NewHandler(p.log)

accessHandler := hlog.AccessHandler(
func(r *http.Request, status, size int, duration time.Duration) {
hlog.FromRequest(r).Trace().
Str("Method", r.Method).
Stringer("URL", r.URL).
Int("Status Code", status).
Int("Response Size Bytes", size).
Str("Duration", duration.String()).
Str("Remote Addr", r.RemoteAddr).
Msg("Handled request")
},
)

return h(accessHandler(next))
}

var pathRegex = regexp.MustCompile(`^\/[a-zA-Z0-9\-._~%!$&'()*+,;=:@\/]*$`)

func isValidPath(path string) bool {
Expand Down
Loading

0 comments on commit 42dda5f

Please sign in to comment.