diff --git a/parrot/cmd/main.go b/parrot/cmd/main.go index 1a0afb72f..45caa26e1 100644 --- a/parrot/cmd/main.go +++ b/parrot/cmd/main.go @@ -14,11 +14,12 @@ import ( func main() { var ( - port int - debug bool - trace bool - silent bool - json bool + port int + debug bool + trace bool + silent bool + json bool + recorders []string ) rootCmd := &cobra.Command{ @@ -49,6 +50,13 @@ func main() { return err } + for _, r := range recorders { + err = p.Record(r) + if err != nil { + return err + } + } + c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) <-c @@ -65,6 +73,7 @@ func main() { rootCmd.Flags().BoolVarP(&trace, "trace", "t", false, "Enable trace and debug output") rootCmd.Flags().BoolVarP(&silent, "silent", "s", false, "Disable all output") rootCmd.Flags().BoolVarP(&json, "json", "j", false, "Output logs in JSON format") + rootCmd.Flags().StringSliceVarP(&recorders, "recorders", "r", nil, "Existing recorders to use") if err := rootCmd.Execute(); err != nil { log.Error().Err(err).Msg("error executing command") diff --git a/parrot/errors.go b/parrot/errors.go index 767cd5a50..5643fb2a7 100644 --- a/parrot/errors.go +++ b/parrot/errors.go @@ -14,9 +14,9 @@ var ( ErrResponseMarshal = errors.New("unable to marshal response body to JSON") ErrRouteNotFound = errors.New("route not found") - ErrNoRecorderURL = errors.New("no recorder URL specified") - ErrNilRecorder = errors.New("recorder is nil") - ErrRecorderNotFound = errors.New("recorder not found") + ErrNoRecorderURL = errors.New("no recorder URL specified") + ErrInvalidRecorderURL = errors.New("invalid recorder URL") + ErrRecorderNotFound = errors.New("recorder not found") ) // Custom error type to help add more detail to base errors diff --git a/parrot/parrot.go b/parrot/parrot.go index 2772a2a96..2a0ad23e0 100644 --- a/parrot/parrot.go +++ b/parrot/parrot.go @@ -362,20 +362,17 @@ func (p *Server) registerRouteHandler(w http.ResponseWriter, r *http.Request) { } // Record registers a new recorder with the parrot. All incoming requests to the parrot will be sent to the recorder. -func (p *Server) Record(recorder *Recorder) error { +func (p *Server) Record(recorderURL string) error { p.recordersMu.Lock() defer p.recordersMu.Unlock() - if recorder == nil { - return ErrNilRecorder - } - if recorder.URL == "" { + if recorderURL == "" { return ErrNoRecorderURL } - _, err := url.Parse(recorder.URL) + _, err := url.Parse(recorderURL) if err != nil { - return fmt.Errorf("failed to parse recorder URL: %w", err) + return ErrInvalidRecorderURL } - p.recorderHooks = append(p.recorderHooks, recorder.URL) + p.recorderHooks = append(p.recorderHooks, recorderURL) return nil } @@ -394,7 +391,7 @@ func (p *Server) recordHandler(w http.ResponseWriter, r *http.Request) { return } - err := p.Record(recorder) + err := p.Record(recorder.URL) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) recordLogger.Debug().Err(err).Msg("Failed to add recorder") diff --git a/parrot/recorder_test.go b/parrot/recorder_test.go index 091c39794..0872e53b4 100644 --- a/parrot/recorder_test.go +++ b/parrot/recorder_test.go @@ -85,7 +85,7 @@ func TestRecorder(t *testing.T) { recorder, err := NewRecorder() require.NoError(t, err, "error creating recorder") - err = p.Record(recorder) + err = p.Record(recorder.URL) require.NoError(t, err, "error recording parrot") t.Cleanup(func() { require.NoError(t, recorder.Close())