Skip to content

Commit

Permalink
New client API (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored Sep 15, 2023
1 parent 32ed3a6 commit 021644b
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 14 deletions.
58 changes: 56 additions & 2 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import (
"io"
"math/rand"
"net/http"
"sync/atomic"
"time"

"github.com/cristalhq/hedgedhttp"
)

func ExampleClient() {
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://google.com", http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cristalhq.dev", http.NoBody)
if err != nil {
panic(err)
}
Expand All @@ -39,9 +40,34 @@ func ExampleClient() {
// Output:
}

func Example_configNext() {
rt := &observableRoundTripper{
rt: http.DefaultTransport,
}

cfg := hedgedhttp.Config{
Transport: rt,
Upto: 3,
Delay: 50 * time.Millisecond,
Next: func() (upto int, delay time.Duration) {
return 3, rt.MaxLatency()
},
}
client, err := hedgedhttp.New(cfg)
if err != nil {
panic(err)
}

// or client.Do
resp, err := client.RoundTrip(&http.Request{})
_ = resp

// Output:
}

func ExampleRoundTripper() {
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://google.com", http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://cristalhq.dev", http.NoBody)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -173,3 +199,31 @@ func (t *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
return t.First.RoundTrip(req)
}

type observableRoundTripper struct {
rt http.RoundTripper
maxLatency atomic.Uint64
}

func (ort *observableRoundTripper) MaxLatency() time.Duration {
return time.Duration(ort.maxLatency.Load())
}

func (ort *observableRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
start := time.Now()
resp, err := ort.rt.RoundTrip(req)
if err != nil {
return resp, err
}

took := uint64(time.Since(start).Nanoseconds())
for {
max := ort.maxLatency.Load()
if max >= took {
return resp, err
}
if ort.maxLatency.CompareAndSwap(max, took) {
return resp, err
}
}
}
101 changes: 92 additions & 9 deletions hedged.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,79 @@ import (

const infiniteTimeout = 30 * 24 * time.Hour // domain specific infinite

// Client represents a hedged HTTP client.
type Client struct {
rt http.RoundTripper
stats *Stats
}

// Config for the [Client].
type Config struct {
// Transport of the [Client].
// Default is nil which results in [net/http.DefaultTransport].
Transport http.RoundTripper

// Upto says how much requests to make.
// Default is zero which means no hedged requests will be made.
Upto int

// Delay before 2 consequitive hedged requests.
Delay time.Duration

// Next returns the upto and delay for each HTTP that will be hedged.
// Default is nil which results in (Upto, Delay) result.
Next NextFn
}

// NextFn represents a function that is called for each HTTP request for retrieving hedging options.
type NextFn func() (upto int, delay time.Duration)

// New returns a new Client for the given config.
func New(cfg Config) (*Client, error) {
switch {
case cfg.Delay < 0:
return nil, errors.New("hedgedhttp: timeout cannot be negative")
case cfg.Upto < 0:
return nil, errors.New("hedgedhttp: upto cannot be negative")
}
if cfg.Transport == nil {
cfg.Transport = http.DefaultTransport
}

rt, stats, err := NewRoundTripperAndStats(cfg.Delay, cfg.Upto, cfg.Transport)
if err != nil {
return nil, err
}

// TODO(cristaloleg): this should be removed after internals cleanup.
rt2, ok := rt.(*hedgedTransport)
if !ok {
panic(fmt.Sprintf("want *hedgedTransport got %T", rt))
}
rt2.next = cfg.Next

c := &Client{
rt: rt2,
stats: stats,
}
return c, nil
}

// Stats returns statistics for the given client, see [Stats] methods.
func (c *Client) Stats() *Stats {
return c.stats
}

// Do does the same as [RoundTrip], this method is presented to align with [net/http.Client].
func (c *Client) Do(req *http.Request) (*http.Response, error) {
return c.rt.RoundTrip(req)
}

// RoundTrip implements [net/http.RoundTripper] interface.
func (c *Client) RoundTrip(req *http.Request) (*http.Response, error) {
return c.rt.RoundTrip(req)
}

// NewClient returns a new http.Client which implements hedged requests pattern.
// Given Client starts a new request after a timeout from previous request.
// Starts no more than upto requests.
Expand Down Expand Up @@ -63,8 +136,8 @@ func NewRoundTripperAndStats(timeout time.Duration, upto int, rt http.RoundTripp
switch {
case timeout < 0:
return nil, nil, errors.New("hedgedhttp: timeout cannot be negative")
case upto < 1:
return nil, nil, errors.New("hedgedhttp: upto must be greater than 0")
case upto < 0:
return nil, nil, errors.New("hedgedhttp: upto cannot be negative")
}

if rt == nil {
Expand All @@ -88,21 +161,31 @@ type hedgedTransport struct {
rt http.RoundTripper
timeout time.Duration
upto int
next NextFn
metrics *Stats
}

func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
mainCtx := req.Context()

timeout := ht.timeout
upto, timeout := ht.upto, ht.timeout
if ht.next != nil {
upto, timeout = ht.next()
}

// no hedged requests, just a regular one.
if upto == 0 {
return ht.rt.RoundTrip(req)
}

errOverall := &MultiError{}
resultCh := make(chan indexedResp, ht.upto)
errorCh := make(chan error, ht.upto)
resultCh := make(chan indexedResp, upto)
errorCh := make(chan error, upto)

ht.metrics.requestedRoundTripsInc()

resultIdx := -1
cancels := make([]func(), ht.upto)
cancels := make([]func(), upto)

defer runInPool(func() {
for i, cancel := range cancels {
Expand All @@ -113,8 +196,8 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error)
}
})

for sent := 0; len(errOverall.Errors) < ht.upto; sent++ {
if sent < ht.upto {
for sent := 0; len(errOverall.Errors) < upto; sent++ {
if sent < upto {
idx := sent
subReq, cancel := reqWithCtx(req, mainCtx, idx != 0)
cancels[idx] = cancel
Expand All @@ -132,7 +215,7 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error)
}

// all request sent - effectively disabling timeout between requests
if sent == ht.upto {
if sent == upto {
timeout = infiniteTimeout
}
resp, err := waitResult(mainCtx, resultCh, errorCh, timeout)
Expand Down
48 changes: 45 additions & 3 deletions hedged_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,59 @@ import (
"github.com/cristalhq/hedgedhttp"
)

func TestClient(t *testing.T) {
const handlerSleep = 100 * time.Millisecond
url := testServerURL(t, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(handlerSleep)
})

cfg := hedgedhttp.Config{
Transport: http.DefaultTransport,
Upto: 3,
Delay: 50 * time.Millisecond,
Next: func() (upto int, delay time.Duration) {
return 5, 10 * time.Millisecond
},
}
client, err := hedgedhttp.New(cfg)
mustOk(t, err)

start := time.Now()
resp, err := client.Do(newGetReq(url))
took := time.Since(start)
mustOk(t, err)
defer resp.Body.Close()
mustTrue(t, resp != nil)
mustEqual(t, resp.StatusCode, http.StatusOK)

stats := client.Stats()
mustEqual(t, stats.ActualRoundTrips(), uint64(5))
mustEqual(t, stats.OriginalRequestWins(), uint64(1))
mustTrue(t, took >= handlerSleep && took < (handlerSleep+10*time.Millisecond))
}

func TestValidateInput(t *testing.T) {
_, _, err := hedgedhttp.NewClientAndStats(-time.Second, 0, nil)
var err error
_, err = hedgedhttp.New(hedgedhttp.Config{
Delay: -time.Second,
})
mustFail(t, err)

_, err = hedgedhttp.New(hedgedhttp.Config{
Upto: -1,
})
mustFail(t, err)

_, _, err = hedgedhttp.NewClientAndStats(-time.Second, 0, nil)
mustFail(t, err)

_, _, err = hedgedhttp.NewClientAndStats(time.Second, -1, nil)
mustFail(t, err)

_, _, err = hedgedhttp.NewClientAndStats(time.Second, 0, nil)
_, _, err = hedgedhttp.NewClientAndStats(time.Second, -1, nil)
mustFail(t, err)

_, err = hedgedhttp.NewRoundTripper(time.Second, 0, nil)
_, err = hedgedhttp.NewRoundTripper(time.Second, -1, nil)
mustFail(t, err)
}

Expand Down

0 comments on commit 021644b

Please sign in to comment.