diff --git a/examples_test.go b/examples_test.go index 07ebb37..dc8046f 100644 --- a/examples_test.go +++ b/examples_test.go @@ -7,6 +7,7 @@ import ( "io" "math/rand" "net/http" + "sync/atomic" "time" "github.com/cristalhq/hedgedhttp" @@ -14,7 +15,7 @@ import ( func ExampleClient() { ctx := context.Background() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://google.com", http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cristalhq.dev", http.NoBody) if err != nil { panic(err) } @@ -39,9 +40,34 @@ func ExampleClient() { // Output: } +func Example_configNext() { + rt := &observableRoundTripper{ + rt: http.DefaultTransport, + } + + cfg := hedgedhttp.Config{ + Transport: rt, + Upto: 3, + Delay: 50 * time.Millisecond, + Next: func() (upto int, delay time.Duration) { + return 3, rt.MaxLatency() + }, + } + client, err := hedgedhttp.New(cfg) + if err != nil { + panic(err) + } + + // or client.Do + resp, err := client.RoundTrip(&http.Request{}) + _ = resp + + // Output: +} + func ExampleRoundTripper() { ctx := context.Background() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://google.com", http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cristalhq.dev", http.NoBody) if err != nil { panic(err) } @@ -173,3 +199,31 @@ func (t *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) { } return t.First.RoundTrip(req) } + +type observableRoundTripper struct { + rt http.RoundTripper + maxLatency atomic.Uint64 +} + +func (ort *observableRoundTripper) MaxLatency() time.Duration { + return time.Duration(ort.maxLatency.Load()) +} + +func (ort *observableRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + start := time.Now() + resp, err := ort.rt.RoundTrip(req) + if err != nil { + return resp, err + } + + took := uint64(time.Since(start).Nanoseconds()) + for { + max := ort.maxLatency.Load() + if max >= took { + return resp, err + } + if ort.maxLatency.CompareAndSwap(max, took) { + return resp, err + } + } +} diff --git a/hedged.go b/hedged.go index 24e03f5..0852224 100644 --- a/hedged.go +++ b/hedged.go @@ -12,6 +12,79 @@ import ( const infiniteTimeout = 30 * 24 * time.Hour // domain specific infinite +// Client represents a hedged HTTP client. +type Client struct { + rt http.RoundTripper + stats *Stats +} + +// Config for the [Client]. +type Config struct { + // Transport of the [Client]. + // Default is nil which results in [net/http.DefaultTransport]. + Transport http.RoundTripper + + // Upto says how much requests to make. + // Default is zero which means no hedged requests will be made. + Upto int + + // Delay before 2 consequitive hedged requests. + Delay time.Duration + + // Next returns the upto and delay for each HTTP that will be hedged. + // Default is nil which results in (Upto, Delay) result. + Next NextFn +} + +// NextFn represents a function that is called for each HTTP request for retrieving hedging options. +type NextFn func() (upto int, delay time.Duration) + +// New returns a new Client for the given config. +func New(cfg Config) (*Client, error) { + switch { + case cfg.Delay < 0: + return nil, errors.New("hedgedhttp: timeout cannot be negative") + case cfg.Upto < 0: + return nil, errors.New("hedgedhttp: upto cannot be negative") + } + if cfg.Transport == nil { + cfg.Transport = http.DefaultTransport + } + + rt, stats, err := NewRoundTripperAndStats(cfg.Delay, cfg.Upto, cfg.Transport) + if err != nil { + return nil, err + } + + // TODO(cristaloleg): this should be removed after internals cleanup. + rt2, ok := rt.(*hedgedTransport) + if !ok { + panic(fmt.Sprintf("want *hedgedTransport got %T", rt)) + } + rt2.next = cfg.Next + + c := &Client{ + rt: rt2, + stats: stats, + } + return c, nil +} + +// Stats returns statistics for the given client, see [Stats] methods. +func (c *Client) Stats() *Stats { + return c.stats +} + +// Do does the same as [RoundTrip], this method is presented to align with [net/http.Client]. +func (c *Client) Do(req *http.Request) (*http.Response, error) { + return c.rt.RoundTrip(req) +} + +// RoundTrip implements [net/http.RoundTripper] interface. +func (c *Client) RoundTrip(req *http.Request) (*http.Response, error) { + return c.rt.RoundTrip(req) +} + // NewClient returns a new http.Client which implements hedged requests pattern. // Given Client starts a new request after a timeout from previous request. // Starts no more than upto requests. @@ -63,8 +136,8 @@ func NewRoundTripperAndStats(timeout time.Duration, upto int, rt http.RoundTripp switch { case timeout < 0: return nil, nil, errors.New("hedgedhttp: timeout cannot be negative") - case upto < 1: - return nil, nil, errors.New("hedgedhttp: upto must be greater than 0") + case upto < 0: + return nil, nil, errors.New("hedgedhttp: upto cannot be negative") } if rt == nil { @@ -88,21 +161,31 @@ type hedgedTransport struct { rt http.RoundTripper timeout time.Duration upto int + next NextFn metrics *Stats } func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) { mainCtx := req.Context() - timeout := ht.timeout + upto, timeout := ht.upto, ht.timeout + if ht.next != nil { + upto, timeout = ht.next() + } + + // no hedged requests, just a regular one. + if upto == 0 { + return ht.rt.RoundTrip(req) + } + errOverall := &MultiError{} - resultCh := make(chan indexedResp, ht.upto) - errorCh := make(chan error, ht.upto) + resultCh := make(chan indexedResp, upto) + errorCh := make(chan error, upto) ht.metrics.requestedRoundTripsInc() resultIdx := -1 - cancels := make([]func(), ht.upto) + cancels := make([]func(), upto) defer runInPool(func() { for i, cancel := range cancels { @@ -113,8 +196,8 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) } }) - for sent := 0; len(errOverall.Errors) < ht.upto; sent++ { - if sent < ht.upto { + for sent := 0; len(errOverall.Errors) < upto; sent++ { + if sent < upto { idx := sent subReq, cancel := reqWithCtx(req, mainCtx, idx != 0) cancels[idx] = cancel @@ -132,7 +215,7 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) } // all request sent - effectively disabling timeout between requests - if sent == ht.upto { + if sent == upto { timeout = infiniteTimeout } resp, err := waitResult(mainCtx, resultCh, errorCh, timeout) diff --git a/hedged_test.go b/hedged_test.go index a931683..c5ae939 100644 --- a/hedged_test.go +++ b/hedged_test.go @@ -15,17 +15,59 @@ import ( "github.com/cristalhq/hedgedhttp" ) +func TestClient(t *testing.T) { + const handlerSleep = 100 * time.Millisecond + url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) { + time.Sleep(handlerSleep) + }) + + cfg := hedgedhttp.Config{ + Transport: http.DefaultTransport, + Upto: 3, + Delay: 50 * time.Millisecond, + Next: func() (upto int, delay time.Duration) { + return 5, 10 * time.Millisecond + }, + } + client, err := hedgedhttp.New(cfg) + mustOk(t, err) + + start := time.Now() + resp, err := client.Do(newGetReq(url)) + took := time.Since(start) + mustOk(t, err) + defer resp.Body.Close() + mustTrue(t, resp != nil) + mustEqual(t, resp.StatusCode, http.StatusOK) + + stats := client.Stats() + mustEqual(t, stats.ActualRoundTrips(), uint64(5)) + mustEqual(t, stats.OriginalRequestWins(), uint64(1)) + mustTrue(t, took >= handlerSleep && took < (handlerSleep+10*time.Millisecond)) +} + func TestValidateInput(t *testing.T) { - _, _, err := hedgedhttp.NewClientAndStats(-time.Second, 0, nil) + var err error + _, err = hedgedhttp.New(hedgedhttp.Config{ + Delay: -time.Second, + }) + mustFail(t, err) + + _, err = hedgedhttp.New(hedgedhttp.Config{ + Upto: -1, + }) + mustFail(t, err) + + _, _, err = hedgedhttp.NewClientAndStats(-time.Second, 0, nil) mustFail(t, err) _, _, err = hedgedhttp.NewClientAndStats(time.Second, -1, nil) mustFail(t, err) - _, _, err = hedgedhttp.NewClientAndStats(time.Second, 0, nil) + _, _, err = hedgedhttp.NewClientAndStats(time.Second, -1, nil) mustFail(t, err) - _, err = hedgedhttp.NewRoundTripper(time.Second, 0, nil) + _, err = hedgedhttp.NewRoundTripper(time.Second, -1, nil) mustFail(t, err) }