From b369e7295efab4448f084e604b38e53c62c62d62 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Thu, 27 Jun 2024 15:08:14 +1000 Subject: [PATCH] Add a dispatchhttp package The package wraps http.Request and http.Response so that requests and responses are serializable, and so that when a function returns a response, its status is set correctly based on the HTTP status code. --- dispatchhttp/client.go | 42 ++++++++++ dispatchhttp/header.go | 18 +++++ dispatchhttp/http_test.go | 159 ++++++++++++++++++++++++++++++++++++++ dispatchhttp/request.go | 43 +++++++++++ dispatchhttp/response.go | 104 +++++++++++++++++++++++++ 5 files changed, 366 insertions(+) create mode 100644 dispatchhttp/client.go create mode 100644 dispatchhttp/header.go create mode 100644 dispatchhttp/http_test.go create mode 100644 dispatchhttp/request.go create mode 100644 dispatchhttp/response.go diff --git a/dispatchhttp/client.go b/dispatchhttp/client.go new file mode 100644 index 0000000..55de191 --- /dev/null +++ b/dispatchhttp/client.go @@ -0,0 +1,42 @@ +package dispatchhttp + +import ( + "bytes" + "context" + "net/http" +) + +// Client wraps an http.Client to accept Request instances +// and return Response instances. +type Client struct{ Client *http.Client } + +// DefaultClient is the default client. +var DefaultClient = &Client{Client: http.DefaultClient} + +// Get makes an HTTP GET request to the specified URL and returns +// its Response. +func (c *Client) Get(ctx context.Context, url string) (*Response, error) { + req := &Request{Method: "GET", URL: url} + return c.Do(ctx, req) +} + +// Get makes an HTTP GET request to the specified URL and returns +// its Response. +func Get(ctx context.Context, url string) (*Response, error) { + return DefaultClient.Get(ctx, url) +} + +// Do makes a HTTP Request and returns its Response. +func (c *Client) Do(ctx context.Context, r *Request) (*Response, error) { + httpReq, err := http.NewRequestWithContext(ctx, r.Method, r.URL, bytes.NewReader(r.Body)) + if err != nil { + return nil, err + } + copyHeader(httpReq.Header, r.Header) + + httpRes, err := c.Client.Do(httpReq) + if err != nil { + return nil, err + } + return FromResponse(httpRes) +} diff --git a/dispatchhttp/header.go b/dispatchhttp/header.go new file mode 100644 index 0000000..1aa4069 --- /dev/null +++ b/dispatchhttp/header.go @@ -0,0 +1,18 @@ +package dispatchhttp + +import ( + "net/http" + "slices" +) + +func cloneHeader(h http.Header) http.Header { + c := make(http.Header, len(h)) + copyHeader(c, h) + return c +} + +func copyHeader(dst, src http.Header) { + for name, values := range src { + dst[name] = slices.Clone(values) + } +} diff --git a/dispatchhttp/http_test.go b/dispatchhttp/http_test.go new file mode 100644 index 0000000..d3d7b6e --- /dev/null +++ b/dispatchhttp/http_test.go @@ -0,0 +1,159 @@ +package dispatchhttp_test + +import ( + "net/http" + "strconv" + "testing" + + "github.com/dispatchrun/dispatch-go/dispatchhttp" + "github.com/dispatchrun/dispatch-go/dispatchproto" + "github.com/google/go-cmp/cmp" +) + +func TestSerializable(t *testing.T) { + t.Run("request", func(t *testing.T) { + req := &dispatchhttp.Request{ + Method: "GET", + URL: "http://example.com", + Header: http.Header{"X-Foo": []string{"bar"}}, + Body: []byte("abc"), + } + boxed, err := dispatchproto.Marshal(req) + if err != nil { + t.Fatal(err) + } + var req2 *dispatchhttp.Request + if err := boxed.Unmarshal(&req2); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, req2); diff != "" { + t.Errorf("invalid request: %v", diff) + } + }) + + t.Run("response", func(t *testing.T) { + res := &dispatchhttp.Response{ + StatusCode: 200, + Header: http.Header{"X-Foo": []string{"bar"}}, + Body: []byte("abc"), + } + boxed, err := dispatchproto.Marshal(res) + if err != nil { + t.Fatal(err) + } + var res2 *dispatchhttp.Response + if err := boxed.Unmarshal(&res2); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(res, res2); diff != "" { + t.Errorf("invalid response: %v", diff) + } + }) +} + +func TestStatusCodeStatus(t *testing.T) { + for _, test := range []struct { + code int + want dispatchproto.Status + }{ + // 1xx + { + code: http.StatusContinue, + want: dispatchproto.PermanentErrorStatus, + }, + + // 2xx + { + code: http.StatusOK, + want: dispatchproto.OKStatus, + }, + { + code: http.StatusAccepted, + want: dispatchproto.OKStatus, + }, + { + code: http.StatusCreated, + want: dispatchproto.OKStatus, + }, + + // 3xx + { + code: http.StatusTemporaryRedirect, + want: dispatchproto.PermanentErrorStatus, + }, + { + code: http.StatusPermanentRedirect, + want: dispatchproto.PermanentErrorStatus, + }, + + // 4xx + { + code: http.StatusBadRequest, + want: dispatchproto.InvalidArgumentStatus, + }, + { + code: http.StatusUnauthorized, + want: dispatchproto.UnauthenticatedStatus, + }, + { + code: http.StatusForbidden, + want: dispatchproto.PermissionDeniedStatus, + }, + { + code: http.StatusNotFound, + want: dispatchproto.NotFoundStatus, + }, + { + code: http.StatusMethodNotAllowed, + want: dispatchproto.PermanentErrorStatus, + }, + { + code: http.StatusRequestTimeout, + want: dispatchproto.TimeoutStatus, + }, + { + code: http.StatusTooManyRequests, + want: dispatchproto.ThrottledStatus, + }, + + // 5xx + { + code: http.StatusInternalServerError, + want: dispatchproto.TemporaryErrorStatus, + }, + { + code: http.StatusNotImplemented, + want: dispatchproto.PermanentErrorStatus, + }, + { + code: http.StatusBadGateway, + want: dispatchproto.TemporaryErrorStatus, + }, + { + code: http.StatusServiceUnavailable, + want: dispatchproto.TemporaryErrorStatus, + }, + { + code: http.StatusGatewayTimeout, + want: dispatchproto.TemporaryErrorStatus, + }, + + // invalid + { + code: 0, + want: dispatchproto.UnspecifiedStatus, + }, + { + code: 9999, + want: dispatchproto.UnspecifiedStatus, + }, + } { + t.Run(strconv.Itoa(test.code), func(t *testing.T) { + res := &dispatchhttp.Response{StatusCode: test.code} + got := dispatchproto.StatusOf(res) + if got != test.want { + t.Errorf("unexpected status for code %d: got %v, want %v", test.code, got, test.want) + } + }) + } +} diff --git a/dispatchhttp/request.go b/dispatchhttp/request.go new file mode 100644 index 0000000..247871c --- /dev/null +++ b/dispatchhttp/request.go @@ -0,0 +1,43 @@ +package dispatchhttp + +import ( + "encoding/json" + "net/http" +) + +// Request is an HTTP request. +type Request struct { + Method string + URL string + Header http.Header + Body []byte +} + +func (r *Request) MarshalJSON() ([]byte, error) { + // Indirection is required to avoid an infinite loop. + return json.Marshal(jsonRequest{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Body: r.Body, + }) +} + +func (r *Request) UnmarshalJSON(b []byte) error { + var jr jsonRequest + if err := json.Unmarshal(b, &jr); err != nil { + return err + } + r.Method = jr.Method + r.URL = jr.URL + r.Header = jr.Header + r.Body = jr.Body + return nil +} + +type jsonRequest struct { + Method string `json:"method,omitempty"` + URL string `json:"url,omitempty"` + Header http.Header `json:"header,omitempty"` + Body []byte `json:"body,omitempty"` +} diff --git a/dispatchhttp/response.go b/dispatchhttp/response.go new file mode 100644 index 0000000..4f5fe7b --- /dev/null +++ b/dispatchhttp/response.go @@ -0,0 +1,104 @@ +package dispatchhttp + +import ( + "encoding/json" + "io" + "net/http" + + "github.com/dispatchrun/dispatch-go/dispatchproto" +) + +// Response is an HTTP response. +type Response struct { + StatusCode int + Header http.Header + Body []byte +} + +// FromResponse creates a Response from an http.Response. +// +// The http.Response.Body is consumed and closed by this +// operation. +func FromResponse(r *http.Response) (*Response, error) { + if r == nil { + return nil, nil + } + + defer r.Body.Close() + b, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + + return &Response{ + StatusCode: r.StatusCode, + Header: cloneHeader(r.Header), + Body: b, + }, nil +} + +func (r *Response) MarshalJSON() ([]byte, error) { + // Indirection is required to avoid an infinite loop. + return json.Marshal(jsonResponse{ + StatusCode: r.StatusCode, + Header: r.Header, + Body: r.Body, + }) +} + +func (r *Response) UnmarshalJSON(b []byte) error { + var jr jsonResponse + if err := json.Unmarshal(b, &jr); err != nil { + return err + } + r.StatusCode = jr.StatusCode + r.Header = jr.Header + r.Body = jr.Body + return nil +} + +type jsonResponse struct { + StatusCode int `json:"status_code,omitempty"` + Header http.Header `json:"header,omitempty"` + Body []byte `json:"body,omitempty"` +} + +// Status is the status for the response. +func (r *Response) Status() dispatchproto.Status { + return statusCodeStatus(r.StatusCode) +} + +func statusCodeStatus(statusCode int) dispatchproto.Status { + // Keep in sync with https://github.com/dispatchrun/dispatch-py/blob/main/src/dispatch/integrations/http.py + switch statusCode { + case http.StatusBadRequest: // 400 + return dispatchproto.InvalidArgumentStatus + case http.StatusUnauthorized: // 401 + return dispatchproto.UnauthenticatedStatus + case http.StatusForbidden: // 403 + return dispatchproto.PermissionDeniedStatus + case http.StatusNotFound: // 404 + return dispatchproto.NotFoundStatus + case http.StatusRequestTimeout: // 408 + return dispatchproto.TimeoutStatus + case http.StatusTooManyRequests: // 429 + return dispatchproto.ThrottledStatus + case http.StatusNotImplemented: // 501 + return dispatchproto.PermanentErrorStatus + } + + switch statusCode / 100 { + case 1: // 1xx informational + return dispatchproto.PermanentErrorStatus + case 2: // 2xx success + return dispatchproto.OKStatus + case 3: // 3xx redirect + return dispatchproto.PermanentErrorStatus + case 4: // 4xx client error + return dispatchproto.PermanentErrorStatus + case 5: // 5xx server error + return dispatchproto.TemporaryErrorStatus + } + + return dispatchproto.UnspecifiedStatus +}