Skip to content

Commit

Permalink
Merge pull request #31 from mailgun/neill/develop
Browse files Browse the repository at this point in the history
Adding Set to Groupcache
  • Loading branch information
thrawn01 authored Jan 6, 2022
2 parents 144bfc1 + bd86e3c commit 075b815
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 21 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ require (
github.com/sirupsen/logrus v1.6.0
)

go 1.13
go 1.15
63 changes: 63 additions & 0 deletions groupcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func newGroup(name string, cacheBytes int64, getter Getter, peers PeerPicker) *G
peers: peers,
cacheBytes: cacheBytes,
loadGroup: &singleflight.Group{},
setGroup: &singleflight.Group{},
removeGroup: &singleflight.Group{},
}
if fn := newGroupHook; fn != nil {
Expand Down Expand Up @@ -182,6 +183,10 @@ type Group struct {
// concurrent callers.
loadGroup flightGroup

// setGroup ensures that each added key is only added
// remotely once regardless of the number of concurrent callers.
setGroup flightGroup

// removeGroup ensures that each removed key is only removed
// remotely once regardless of the number of concurrent callers.
removeGroup flightGroup
Expand Down Expand Up @@ -253,6 +258,34 @@ func (g *Group) Get(ctx context.Context, key string, dest Sink) error {
return setSinkView(dest, value)
}

func (g *Group) Set(ctx context.Context, key string, value []byte, expire time.Time, hotCache bool) error {
g.peersOnce.Do(g.initPeers)

if key == "" {
return errors.New("empty Set() key not allowed")
}

_, err := g.setGroup.Do(key, func() (interface{}, error) {
// If remote peer owns this key
owner, ok := g.peers.PickPeer(key)
if ok {
if err := g.setFromPeer(ctx, owner, key, value, expire); err != nil {
return nil, err
}
// TODO(thrawn01): Not sure if this is useful outside of tests...
// maybe we should ALWAYS update the local cache?
if hotCache {
g.localSet(key, value, expire, &g.hotCache)
}
return nil, nil
}
// We own this key
g.localSet(key, value, expire, &g.mainCache)
return nil, nil
})
return err
}

// Remove clears the key from our cache then forwards the remove
// request to all peers.
func (g *Group) Remove(ctx context.Context, key string) error {
Expand Down Expand Up @@ -425,6 +458,20 @@ func (g *Group) getFromPeer(ctx context.Context, peer ProtoGetter, key string) (
return value, nil
}

func (g *Group) setFromPeer(ctx context.Context, peer ProtoGetter, k string, v []byte, e time.Time) error {
var expire int64
if !e.IsZero() {
expire = e.UnixNano()
}
req := &pb.SetRequest{
Expire: &expire,
Group: &g.name,
Key: &k,
Value: v,
}
return peer.Set(ctx, req)
}

func (g *Group) removeFromPeer(ctx context.Context, peer ProtoGetter, key string) error {
req := &pb.GetRequest{
Group: &g.name,
Expand All @@ -445,6 +492,22 @@ func (g *Group) lookupCache(key string) (value ByteView, ok bool) {
return
}

func (g *Group) localSet(key string, value []byte, expire time.Time, cache *cache) {
if g.cacheBytes <= 0 {
return
}

bv := ByteView{
b: value,
e: expire,
}

// Ensure no requests are in flight
g.loadGroup.Lock(func() {
g.populateCache(key, bv, cache)
})
}

func (g *Group) localRemove(key string) {
// Clear key from our local cache
if g.cacheBytes <= 0 {
Expand Down
8 changes: 8 additions & 0 deletions groupcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ func (p *fakePeer) Get(_ context.Context, in *pb.GetRequest, out *pb.GetResponse
return nil
}

func (p *fakePeer) Set(_ context.Context, in *pb.SetRequest) error {
p.hits++
if p.fail {
return errors.New("simulated error from peer")
}
return nil
}

func (p *fakePeer) Remove(_ context.Context, in *pb.GetRequest) error {
p.hits++
if p.fail {
Expand Down
72 changes: 58 additions & 14 deletions groupcachepb/groupcache.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions groupcachepb/groupcache.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ message GetResponse {
optional int64 expire = 3;
}

message SetRequest {
required string group = 1;
required string key = 2;
optional bytes value = 3;
optional int64 expire = 4;
}

service GroupCache {
rpc Get(GetRequest) returns (GetResponse) {
};
Expand Down
64 changes: 59 additions & 5 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/url"
"strings"
"sync"
"time"

"github.com/golang/protobuf/proto"
"github.com/mailgun/groupcache/v2/consistenthash"
Expand Down Expand Up @@ -191,6 +192,34 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// The read the body and set the key value
if r.Method == http.MethodPut {
defer r.Body.Close()
b := bufferPool.Get().(*bytes.Buffer)
b.Reset()
defer bufferPool.Put(b)
_, err := io.Copy(b, r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

var out pb.SetRequest
err = proto.Unmarshal(b.Bytes(), &out)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

var expire time.Time
if out.Expire != nil && *out.Expire != 0 {
expire = time.Unix(*out.Expire/int64(time.Second), *out.Expire%int64(time.Second))
}

group.localSet(*out.Key, out.Value, expire, &group.mainCache)
return
}

var b []byte

value := AllocatingByteSliceSink(&b)
Expand Down Expand Up @@ -225,7 +254,6 @@ type httpGetter struct {
baseURL string
}

// GetURL
func (p *httpGetter) GetURL() string {
return p.baseURL
}
Expand All @@ -234,14 +262,19 @@ var bufferPool = sync.Pool{
New: func() interface{} { return new(bytes.Buffer) },
}

func (h *httpGetter) makeRequest(ctx context.Context, method string, in *pb.GetRequest, out *http.Response) error {
type request interface {
GetGroup() string
GetKey() string
}

func (h *httpGetter) makeRequest(ctx context.Context, m string, in request, b io.Reader, out *http.Response) error {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
url.QueryEscape(in.GetGroup()),
url.QueryEscape(in.GetKey()),
)
req, err := http.NewRequestWithContext(ctx, method, u, nil)
req, err := http.NewRequestWithContext(ctx, m, u, b)
if err != nil {
return err
}
Expand All @@ -261,7 +294,7 @@ func (h *httpGetter) makeRequest(ctx context.Context, method string, in *pb.GetR

func (h *httpGetter) Get(ctx context.Context, in *pb.GetRequest, out *pb.GetResponse) error {
var res http.Response
if err := h.makeRequest(ctx, http.MethodGet, in, &res); err != nil {
if err := h.makeRequest(ctx, http.MethodGet, in, nil, &res); err != nil {
return err
}
defer res.Body.Close()
Expand All @@ -282,9 +315,30 @@ func (h *httpGetter) Get(ctx context.Context, in *pb.GetRequest, out *pb.GetResp
return nil
}

func (h *httpGetter) Set(ctx context.Context, in *pb.SetRequest) error {
body, err := proto.Marshal(in)
if err != nil {
return fmt.Errorf("while marshaling SetRequest body: %w", err)
}
var res http.Response
if err := h.makeRequest(ctx, http.MethodPut, in, bytes.NewReader(body), &res); err != nil {
return err
}
defer res.Body.Close()

if res.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("while reading body response: %v", res.Status)
}
return fmt.Errorf("server returned status %d: %s", res.StatusCode, body)
}
return nil
}

func (h *httpGetter) Remove(ctx context.Context, in *pb.GetRequest) error {
var res http.Response
if err := h.makeRequest(ctx, http.MethodDelete, in, &res); err != nil {
if err := h.makeRequest(ctx, http.MethodDelete, in, nil, &res); err != nil {
return err
}
defer res.Body.Close()
Expand Down
23 changes: 23 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package groupcache

import (
"bytes"
"context"
"errors"
"flag"
Expand Down Expand Up @@ -75,6 +76,7 @@ func TestHTTPPool(t *testing.T) {
"--test_server_addr="+ts.URL,
)
cmds = append(cmds, cmd)
cmd.Stdout = os.Stdout
wg.Add(1)
if err := cmd.Start(); err != nil {
t.Fatal("failed to start child process: ", err)
Expand Down Expand Up @@ -151,6 +153,27 @@ func TestHTTPPool(t *testing.T) {
if serverHits != 2 {
t.Error("expected serverHits to be '2'")
}

key = "setMyTestKey"
setValue := []byte("test set")
// Add the key to the cache, optionally updating our local hot cache
if err := g.Set(ctx, key, setValue, time.Time{}, false); err != nil {
t.Fatal(err)
}

// Get the key
var getValue ByteView
if err := g.Get(ctx, key, ByteViewSink(&getValue)); err != nil {
t.Fatal(err)
}

if serverHits != 2 {
t.Errorf("expected serverHits to be '3' got '%d'", serverHits)
}

if !bytes.Equal(setValue, getValue.ByteSlice()) {
t.Fatal(errors.New(fmt.Sprintf("incorrect value retrieved after set: %s", getValue)))
}
}

func testKeys(n int) (keys []string) {
Expand Down
1 change: 1 addition & 0 deletions peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
type ProtoGetter interface {
Get(context context.Context, in *pb.GetRequest, out *pb.GetResponse) error
Remove(context context.Context, in *pb.GetRequest) error
Set(context context.Context, in *pb.SetRequest) error
// GetURL returns the peer URL
GetURL() string
}
Expand Down
1 change: 0 additions & 1 deletion sinks.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ func ProtoSink(m proto.Message) Sink {
type protoSink struct {
dst proto.Message // authoritative value
typ string
ttl time.Duration

v ByteView // encoded
}
Expand Down

0 comments on commit 075b815

Please sign in to comment.