From 9e3645f0edb814b220d74b3b97dc659e1eef724f Mon Sep 17 00:00:00 2001 From: Engin Date: Sat, 5 Oct 2024 23:07:34 +0300 Subject: [PATCH 1/2] Change error control mechanism to can check wrapped errors --- echo_test.go | 2 +- middleware/csrf.go | 12 +++++++----- middleware/jwt_test.go | 2 +- middleware/key_auth.go | 14 +++++++------- middleware/proxy.go | 3 ++- middleware/timeout.go | 3 ++- middleware/timeout_test.go | 2 +- 7 files changed, 21 insertions(+), 17 deletions(-) diff --git a/echo_test.go b/echo_test.go index b7f32017a..fd50ceb39 100644 --- a/echo_test.go +++ b/echo_test.go @@ -915,7 +915,7 @@ func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { return nil // was started } case err := <-errChan: - if err == http.ErrServerClosed { + if errors.Is(err, http.ErrServerClosed) { return nil } return err diff --git a/middleware/csrf.go b/middleware/csrf.go index 92f4019dc..62cb7f071 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -5,6 +5,7 @@ package middleware import ( "crypto/subtle" + "errors" "net/http" "time" @@ -163,16 +164,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if lastTokenErr != nil { finalErr = lastTokenErr } else if lastExtractorErr != nil { - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { + switch { + case errors.Is(lastExtractorErr, errQueryExtractorValueMissing): lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") - } else if lastExtractorErr == errFormExtractorValueMissing { + case errors.Is(lastExtractorErr, errFormExtractorValueMissing): lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") - } else if lastExtractorErr == errHeaderExtractorValueMissing { + case errors.Is(lastExtractorErr, errHeaderExtractorValueMissing): lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") - } else { + default: lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) } + finalErr = lastExtractorErr } diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index bbe4b8808..aea5e8152 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -757,7 +757,7 @@ func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, SigningKey: []byte("secret"), ErrorHandlerWithContext: func(err error, c echo.Context) error { - if err == ErrJWTMissing { + if errors.Is(err, ErrJWTMissing) { c.Set("test", "public-token") return nil } diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 79bee207c..6819376ed 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -142,18 +142,18 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { // we are here only when we did not successfully extract and validate any of keys err := lastValidatorErr if err == nil { // prioritize validator errors over extracting errors - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { + switch { + case errors.Is(err, errQueryExtractorValueMissing): err = errors.New("missing key in the query string") - } else if lastExtractorErr == errCookieExtractorValueMissing { + case errors.Is(err, errCookieExtractorValueMissing): err = errors.New("missing key in cookies") - } else if lastExtractorErr == errFormExtractorValueMissing { + case errors.Is(err, errFormExtractorValueMissing): err = errors.New("missing key in the form") - } else if lastExtractorErr == errHeaderExtractorValueMissing { + case errors.Is(err, errHeaderExtractorValueMissing): err = errors.New("missing key in request header") - } else if lastExtractorErr == errHeaderExtractorValueInvalid { + case errors.Is(err, errHeaderExtractorValueInvalid): err = errors.New("invalid key in the request header") - } else { + default: err = lastExtractorErr } err = &ErrKeyAuthMissing{Err: err} diff --git a/middleware/proxy.go b/middleware/proxy.go index 495970aca..f836319c4 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -5,6 +5,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "math/rand" @@ -405,7 +406,7 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle // The Go standard library (at of late 2020) wraps the exported, standard // context.Canceled error with unexported garbage value requiring a substring check, see // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 - if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { + if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") { httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) httpError.Internal = err c.Set("_error", httpError) diff --git a/middleware/timeout.go b/middleware/timeout.go index c2aebef30..04d018a30 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -5,6 +5,7 @@ package middleware import ( "context" + "errors" "github.com/labstack/echo/v4" "net/http" "sync" @@ -163,7 +164,7 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques }() err := t.handler(t.ctx) - if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { + if ctxErr := r.Context().Err(); errors.Is(ctxErr, context.DeadlineExceeded) { if err != nil && t.errHandler != nil { t.errHandler(err, t.ctx) } diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index e8415d636..a17d3fc8c 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -478,7 +478,7 @@ func startServer(e *echo.Echo) (*http.Server, string, error) { errCh := make(chan error) go func() { - if err := s.Serve(l); err != http.ErrServerClosed { + if err := s.Serve(l); !errors.Is(err, http.ErrServerClosed) { errCh <- err } }() From d90129008919f3818cfc199dc14e464654ab76fa Mon Sep 17 00:00:00 2001 From: Engin Date: Sun, 6 Oct 2024 16:11:38 +0300 Subject: [PATCH 2/2] Set limit for arrays --- middleware/extractor.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/middleware/extractor.go b/middleware/extractor.go index 3f2741407..1dca63257 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -52,7 +52,7 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err return nil, nil } sources := strings.Split(lookups, ",") - var extractors = make([]ValuesExtractor, 0) + var extractors = make([]ValuesExtractor, 0, len(sources)) for _, source := range sources { parts := strings.Split(source, ":") if len(parts) < 2 { @@ -104,7 +104,7 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { return nil, errHeaderExtractorValueMissing } - result := make([]string, 0) + result := make([]string, 0, len(values)) for i, value := range values { if prefixLen == 0 { result = append(result, value) @@ -147,7 +147,7 @@ func valuesFromQuery(param string) ValuesExtractor { // valuesFromParam returns a function that extracts values from the url param string. func valuesFromParam(param string) ValuesExtractor { return func(c echo.Context) ([]string, error) { - result := make([]string, 0) + result := make([]string, 0, len(c.ParamNames())) paramVales := c.ParamValues() for i, p := range c.ParamNames() { if param == p { @@ -172,7 +172,7 @@ func valuesFromCookie(name string) ValuesExtractor { return nil, errCookieExtractorValueMissing } - result := make([]string, 0) + result := make([]string, 0, len(cookies)) for i, cookie := range cookies { if name == cookie.Name { result = append(result, cookie.Value)