Skip to content

Commit

Permalink
Merge pull request #3 from mailgun/thrawn/develop
Browse files Browse the repository at this point in the history
PIP-407: Added support for key removal within a group
  • Loading branch information
thrawn01 authored May 15, 2019
2 parents d6e54d2 + 4f7e5ec commit 82209d3
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 14 deletions.
94 changes: 88 additions & 6 deletions groupcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,12 @@ func newGroup(name string, cacheBytes int64, getter Getter, peers PeerPicker) *G
panic("duplicate registration of group " + name)
}
g := &Group{
name: name,
getter: getter,
peers: peers,
cacheBytes: cacheBytes,
loadGroup: &singleflight.Group{},
name: name,
getter: getter,
peers: peers,
cacheBytes: cacheBytes,
loadGroup: &singleflight.Group{},
removeGroup: &singleflight.Group{},
}
if fn := newGroupHook; fn != nil {
fn(g)
Expand Down Expand Up @@ -167,6 +168,10 @@ type Group struct {
// concurrent callers.
loadGroup flightGroup

// removeGroup ensures that each removed key is only removed
// remotely once regardless of the number of concurrent callers.
removeGroup flightGroup

_ int32 // force Stats to be 8-byte aligned on 32-bit platforms

// Stats are statistics on the group.
Expand All @@ -177,8 +182,8 @@ type Group struct {
// satisfies. We define this so that we may test with an alternate
// implementation.
type flightGroup interface {
// Done is called when Do is done.
Do(key string, fn func() (interface{}, error)) (interface{}, error)
Lock(fn func())
}

// Stats are per-group statistics.
Expand Down Expand Up @@ -233,6 +238,53 @@ func (g *Group) Get(ctx Context, key string, dest Sink) error {
return setSinkView(dest, value)
}

// Remove clears the key from our cache then forwards the remove
// request to all peers.
func (g *Group) Remove(ctx Context, key string) error {
_, err := g.removeGroup.Do(key, func() (interface{}, error) {

// Remove from key owner first
owner, ok := g.peers.PickPeer(key)
if ok {
if err := g.removeFromPeer(ctx, owner, key); err != nil {
return nil, err
}
}
// Remove from our cache first in case we are owner
g.localRemove(key)
wg := sync.WaitGroup{}
errs := make(chan error)

// Asynchronously clear the key from all hot and main caches of peers
for _, peer := range g.peers.GetAll() {
// avoid deleting from owner a second time
if peer == owner {
continue
}

wg.Add(1)
go func() {
errs <- g.removeFromPeer(ctx, peer, key)
wg.Done()
}()
}
go func() {
wg.Wait()
close(errs)
}()

// TODO(thrawn01): Should we report all errors? Reporting context
// cancelled error for each peer doesn't make much sense.
var err error
for e := range errs {
err = e
}

return nil, err
})
return err
}

// load loads key either by invoking the getter locally or by sending it to another machine.
func (g *Group) load(ctx Context, key string, dest Sink) (value ByteView, destPopulated bool, err error) {
g.Stats.Loads.Add(1)
Expand Down Expand Up @@ -330,6 +382,14 @@ func (g *Group) getFromPeer(ctx Context, peer ProtoGetter, key string) (ByteView
return value, nil
}

func (g *Group) removeFromPeer(ctx Context, peer ProtoGetter, key string) error {
req := &pb.GetRequest{
Group: &g.name,
Key: &key,
}
return peer.Remove(ctx, req)
}

func (g *Group) lookupCache(key string) (value ByteView, ok bool) {
if g.cacheBytes <= 0 {
return
Expand All @@ -342,6 +402,19 @@ func (g *Group) lookupCache(key string) (value ByteView, ok bool) {
return
}

func (g *Group) localRemove(key string) {
// Clear key from our local cache
if g.cacheBytes <= 0 {
return
}

// Ensure no requests are in flight
g.loadGroup.Lock(func() {
g.hotCache.remove(key)
g.mainCache.remove(key)
})
}

func (g *Group) populateCache(key string, value ByteView, cache *cache) {
if g.cacheBytes <= 0 {
return
Expand Down Expand Up @@ -447,6 +520,15 @@ func (c *cache) get(key string) (value ByteView, ok bool) {
return vi.(ByteView), true
}

func (c *cache) remove(key string) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lru == nil {
return
}
c.lru.Remove(key)
}

func (c *cache) removeOldest() {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down
16 changes: 16 additions & 0 deletions groupcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ func (p *fakePeer) Get(_ Context, in *pb.GetRequest, out *pb.GetResponse) error
return nil
}

func (p *fakePeer) Remove(_ Context, in *pb.GetRequest) error {
p.hits++
if p.fail {
return errors.New("simulated error from peer")
}
return nil
}

type fakePeers []ProtoGetter

func (p fakePeers) PickPeer(key string) (peer ProtoGetter, ok bool) {
Expand All @@ -273,6 +281,10 @@ func (p fakePeers) PickPeer(key string) (peer ProtoGetter, ok bool) {
return p[n], p[n] != nil
}

func (p fakePeers) GetAll() []ProtoGetter {
return p
}

// tests that peers (virtual, in-process) are hit, and how much.
func TestPeers(t *testing.T) {
once.Do(testSetup)
Expand Down Expand Up @@ -406,6 +418,10 @@ func (g *orderedFlightGroup) Do(key string, fn func() (interface{}, error)) (int
return g.orig.Do(key, fn)
}

func (g *orderedFlightGroup) Lock(fn func()) {
fn()
}

// TestNoDedup tests invariants on the cache size when singleflight is
// unable to dedup calls.
func TestNoDedup(t *testing.T) {
Expand Down
48 changes: 45 additions & 3 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,20 @@ func (p *HTTPPool) Set(peers ...string) {
}
}

// GetAll returns all the peers in the pool
func (p *HTTPPool) GetAll() []ProtoGetter {
p.mu.Lock()
defer p.mu.Unlock()

var i int
res := make([]ProtoGetter, len(p.httpGetters))
for _, v := range p.httpGetters {
res[i] = v
i++
}
return res
}

func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) {
p.mu.Lock()
defer p.mu.Unlock()
Expand Down Expand Up @@ -163,6 +177,13 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

group.Stats.ServerRequests.Add(1)

// Delete the key and return 200
if r.Method == http.MethodDelete {
group.localRemove(key)
return
}

var b []byte

value := AllocatingByteSliceSink(&b)
Expand Down Expand Up @@ -201,14 +222,14 @@ var bufferPool = sync.Pool{
New: func() interface{} { return new(bytes.Buffer) },
}

func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error {
func (h *httpGetter) makeRequest(context Context, method string, in *pb.GetRequest, out *http.Response) error {
u := fmt.Sprintf(
"%v%v/%v",
h.baseURL,
url.QueryEscape(in.GetGroup()),
url.QueryEscape(in.GetKey()),
)
req, err := http.NewRequest("GET", u, nil)
req, err := http.NewRequest(method, u, nil)
if err != nil {
return err
}
Expand All @@ -220,14 +241,23 @@ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse
if err != nil {
return err
}
*out = *res
return nil
}

func (h *httpGetter) Get(ctx Context, in *pb.GetRequest, out *pb.GetResponse) error {
var res http.Response
if err := h.makeRequest(ctx, http.MethodGet, in, &res); err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("server returned: %v", res.Status)
}
b := bufferPool.Get().(*bytes.Buffer)
b.Reset()
defer bufferPool.Put(b)
_, err = io.Copy(b, res.Body)
_, err := io.Copy(b, res.Body)
if err != nil {
return fmt.Errorf("reading response body: %v", err)
}
Expand All @@ -237,3 +267,15 @@ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse
}
return nil
}

func (h *httpGetter) Remove(ctx Context, in *pb.GetRequest) error {
var res http.Response
if err := h.makeRequest(ctx, http.MethodDelete, in, &res); err != nil {
return err
}
res.Body.Close()
if res.StatusCode != http.StatusOK {
return fmt.Errorf("server returned: %v", res.Status)
}
return nil
}
60 changes: 55 additions & 5 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package groupcache
import (
"errors"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"strconv"
Expand All @@ -32,14 +34,15 @@ import (
)

var (
peerAddrs = flag.String("test_peer_addrs", "", "Comma-separated list of peer addresses; used by TestHTTPPool")
peerIndex = flag.Int("test_peer_index", -1, "Index of which peer this child is; used by TestHTTPPool")
peerChild = flag.Bool("test_peer_child", false, "True if running as a child process; used by TestHTTPPool")
peerAddrs = flag.String("test_peer_addrs", "", "Comma-separated list of peer addresses; used by TestHTTPPool")
peerIndex = flag.Int("test_peer_index", -1, "Index of which peer this child is; used by TestHTTPPool")
peerChild = flag.Bool("test_peer_child", false, "True if running as a child process; used by TestHTTPPool")
serverAddr = flag.String("test_server_addr", "", "Address of the server Child Getters will hit ; used by TestHTTPPool")
)

func TestHTTPPool(t *testing.T) {
if *peerChild {
beChildForTestHTTPPool()
beChildForTestHTTPPool(t)
os.Exit(0)
}

Expand All @@ -48,6 +51,13 @@ func TestHTTPPool(t *testing.T) {
nGets = 100
)

var serverHits int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello")
serverHits++
}))
defer ts.Close()

var childAddr []string
for i := 0; i < nChild; i++ {
childAddr = append(childAddr, pickFreeAddr(t))
Expand All @@ -61,6 +71,7 @@ func TestHTTPPool(t *testing.T) {
"--test_peer_child",
"--test_peer_addrs="+strings.Join(childAddr, ","),
"--test_peer_index="+strconv.Itoa(i),
"--test_server_addr="+ts.URL,
)
cmds = append(cmds, cmd)
wg.Add(1)
Expand Down Expand Up @@ -100,6 +111,41 @@ func TestHTTPPool(t *testing.T) {
}
t.Logf("Get key=%q, value=%q (peer:key)", key, value)
}

if serverHits != nGets {
t.Error("expected serverHits to equal nGets")
}
serverHits = 0

var value string
var key = "removeTestKey"

// Multiple gets on the same key
for i := 0; i < 2; i++ {
if err := g.Get(nil, key, StringSink(&value)); err != nil {
t.Fatal(err)
}
}

// Should result in only 1 server get
if serverHits != 1 {
t.Error("expected serverHits to be '1'")
}

// Remove the key from the cache and we should see another server hit
if err := g.Remove(nil, key); err != nil {
t.Fatal(err)
}

// Get the key again
if err := g.Get(nil, key, StringSink(&value)); err != nil {
t.Fatal(err)
}

// Should register another server get
if serverHits != 2 {
t.Error("expected serverHits to be '2'")
}
}

func testKeys(n int) (keys []string) {
Expand All @@ -110,13 +156,17 @@ func testKeys(n int) (keys []string) {
return
}

func beChildForTestHTTPPool() {
func beChildForTestHTTPPool(t *testing.T) {
addrs := strings.Split(*peerAddrs, ",")

p := NewHTTPPool("http://" + addrs[*peerIndex])
p.Set(addrToURL(addrs)...)

getter := GetterFunc(func(ctx Context, key string, dest Sink) error {
if _, err := http.Get(*serverAddr); err != nil {
t.Logf("HTTP request from getter failed with '%s'", err)
}

dest.SetString(strconv.Itoa(*peerIndex)+":"+key, time.Time{})
return nil
})
Expand Down
Loading

0 comments on commit 82209d3

Please sign in to comment.