Skip to content

Commit

Permalink
Merge pull request #1640 from josephschorr/singleflight-recursion
Browse files Browse the repository at this point in the history
Fix handling of recursive calls via singleflight dispatch
  • Loading branch information
josephschorr authored Nov 6, 2023
2 parents 2591079 + 2e392db commit 0114177
Show file tree
Hide file tree
Showing 17 changed files with 394 additions and 139 deletions.
40 changes: 37 additions & 3 deletions internal/dispatch/singleflight/singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/authzed/spicedb/internal/dispatch"
"github.com/authzed/spicedb/internal/dispatch/keys"
"github.com/authzed/spicedb/pkg/genutil/mapz"
v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
)

Expand All @@ -24,14 +25,23 @@ var singleFlightCount = promauto.NewCounterVec(prometheus.CounterOpts{
}, []string{"method", "shared"})

func New(delegate dispatch.Dispatcher, handler keys.Handler) dispatch.Dispatcher {
return &Dispatcher{delegate: delegate, keyHandler: handler}
return &Dispatcher{
delegate: delegate,
keyHandler: handler,
checkByDispatchKey: mapz.NewCountingMultiMap[string, string](),
expandByDispatchKey: mapz.NewCountingMultiMap[string, string](),
}
}

type Dispatcher struct {
delegate dispatch.Dispatcher
keyHandler keys.Handler
delegate dispatch.Dispatcher
keyHandler keys.Handler

checkGroup singleflight.Group[string, *v1.DispatchCheckResponse]
expandGroup singleflight.Group[string, *v1.DispatchExpandResponse]

checkByDispatchKey *mapz.CountingMultiMap[string, string]
expandByDispatchKey *mapz.CountingMultiMap[string, string]
}

func (d *Dispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) {
Expand All @@ -42,6 +52,18 @@ func (d *Dispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckReq
}

keyString := hex.EncodeToString(key)

// Check if the key has already been part of a dispatch, for the *same* request ID. If so, this represents a
// likely recursive call, so we dispatch it to the delegate to avoid the singleflight from blocking it.
requestID := req.Metadata.RequestId
existed := d.checkByDispatchKey.Add(keyString, requestID)
defer d.checkByDispatchKey.Remove(keyString, requestID)

if existed {
// Likely a recursive call.
return d.delegate.DispatchCheck(ctx, req)
}

v, isShared, err := d.checkGroup.Do(ctx, keyString, func(innerCtx context.Context) (*v1.DispatchCheckResponse, error) {
return d.delegate.DispatchCheck(innerCtx, req)
})
Expand All @@ -62,6 +84,18 @@ func (d *Dispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandR
}

keyString := hex.EncodeToString(key)

// Check if the key has already been part of a dispatch, for the *same* request ID. If so, this represents a
// likely recursive call, so we dispatch it to the delegate to avoid the singleflight from blocking it.
requestID := req.Metadata.RequestId
existed := d.expandByDispatchKey.Add(keyString, requestID)
defer d.expandByDispatchKey.Remove(keyString, requestID)

if existed {
// Likely a recursive call.
return d.delegate.DispatchExpand(ctx, req)
}

v, isShared, err := d.expandGroup.Do(ctx, keyString, func(ictx context.Context) (*v1.DispatchExpandResponse, error) {
return d.delegate.DispatchExpand(ictx, req)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,42 @@ func TestSingleFlightDispatcher(t *testing.T) {
}
disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})

req := &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
},
}

wg := sync.WaitGroup{}
wg.Add(4)
go func() {
_, _ = disp.DispatchCheck(context.Background(), req)
_, _ = disp.DispatchCheck(context.Background(), &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
RequestId: "first",
},
})
wg.Done()
}()
go func() {
_, _ = disp.DispatchCheck(context.Background(), req)
_, _ = disp.DispatchCheck(context.Background(), &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
RequestId: "second",
},
})
wg.Done()
}()
go func() {
_, _ = disp.DispatchCheck(context.Background(), req)
_, _ = disp.DispatchCheck(context.Background(), &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
RequestId: "third",
},
})

wg.Done()
}()
Expand Down Expand Up @@ -74,35 +89,50 @@ func TestSingleFlightDispatcherCancelation(t *testing.T) {
}
disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})

req := &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
},
}

wg := sync.WaitGroup{}
wg.Add(3)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err := disp.DispatchCheck(ctx, req)
_, err := disp.DispatchCheck(ctx, &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
RequestId: "first",
},
})
wg.Done()
require.ErrorIs(t, err, context.DeadlineExceeded)
}()
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err := disp.DispatchCheck(ctx, req)
_, err := disp.DispatchCheck(ctx, &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
RequestId: "second",
},
})
wg.Done()
require.ErrorIs(t, err, context.DeadlineExceeded)
}()
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err := disp.DispatchCheck(ctx, req)
_, err := disp.DispatchCheck(ctx, &v1.DispatchCheckRequest{
ResourceRelation: tuple.RelationReference("document", "view"),
ResourceIds: []string{"foo", "bar"},
Subject: tuple.ObjectAndRelation("user", "tom", "..."),
Metadata: &v1.ResolverMeta{
AtRevision: "1234",
RequestId: "third",
},
})
wg.Done()
require.ErrorIs(t, err, context.DeadlineExceeded)
}()
Expand Down
4 changes: 4 additions & 0 deletions internal/graph/computed/computecheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
"github.com/authzed/spicedb/pkg/datastore"
"github.com/authzed/spicedb/pkg/genutil/slicez"
"github.com/authzed/spicedb/pkg/middleware/requestid"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
"github.com/authzed/spicedb/pkg/spiceerrors"
Expand Down Expand Up @@ -102,6 +103,8 @@ func computeCheck(ctx context.Context,
metadata := &v1.ResponseMeta{}

// TODO(jschorr): Should we make this run in parallel via the preloadedTaskRunner?
requestID, ctx := requestid.GetOrGenerateRequestID(ctx)

_, err := slicez.ForEachChunkUntil(resourceIDs, datastore.FilterMaximumIDCount, func(resourceIDsToCheck []string) (bool, error) {
checkResult, err := d.DispatchCheck(ctx, &v1.DispatchCheckRequest{
ResourceRelation: params.ResourceType,
Expand All @@ -111,6 +114,7 @@ func computeCheck(ctx context.Context,
Metadata: &v1.ResolverMeta{
AtRevision: params.AtRevision.String(),
DepthRemaining: params.MaximumDepth,
RequestId: requestID,
},
Debug: debugging,
})
Expand Down
1 change: 1 addition & 0 deletions internal/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func decrementDepth(md *v1.ResolverMeta) *v1.ResolverMeta {
return &v1.ResolverMeta{
AtRevision: md.AtRevision,
DepthRemaining: md.DepthRemaining - 1,
RequestId: md.RequestId,
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/graph/lookupsubjects.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaComputed(
Metadata: &v1.ResolverMeta{
AtRevision: parentRequest.Revision.String(),
DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
RequestId: parentRequest.Metadata.RequestId,
},
}, stream)
}
Expand Down Expand Up @@ -426,6 +427,7 @@ func (cl *ConcurrentLookupSubjects) dispatchTo(
Metadata: &v1.ResolverMeta{
AtRevision: parentRequest.Revision.String(),
DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
RequestId: parentRequest.Metadata.RequestId,
},
}, stream)
})
Expand Down
1 change: 1 addition & 0 deletions internal/graph/reachableresources.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ func (crr *CursoredReachableResources) redispatchOrReport(
Metadata: &v1.ResolverMeta{
AtRevision: parentRequest.Revision.String(),
DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
RequestId: parentRequest.Metadata.RequestId,
},
OptionalCursor: ci.currentCursor,
OptionalLimit: ci.limits.currentLimit,
Expand Down
1 change: 1 addition & 0 deletions internal/services/integrationtesting/consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ func validateExpansionSubjects(t *testing.T, vctx validationContext) {
Metadata: &dispatchv1.ResolverMeta{
AtRevision: vctx.revision.String(),
DepthRemaining: 100,
RequestId: "somerequestid",
},
ExpansionMode: dispatchv1.DispatchExpandRequest_RECURSIVE,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ func BuildAccessibilitySet(t *testing.T, ccd ConsistencyClusterAndData) *Accessi
Metadata: &dispatchv1.ResolverMeta{
AtRevision: headRevision.String(),
DepthRemaining: 50,
RequestId: "somerequestid",
},
})
require.NoError(t, err)
Expand Down Expand Up @@ -359,6 +360,7 @@ func isAccessibleViaWildcardOnly(
Metadata: &dispatchv1.ResolverMeta{
AtRevision: revision.String(),
DepthRemaining: 100,
RequestId: "somerequestid",
},
ExpansionMode: dispatchv1.DispatchExpandRequest_RECURSIVE,
})
Expand Down
12 changes: 10 additions & 2 deletions internal/services/v1/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/authzed/spicedb/pkg/cursor"
"github.com/authzed/spicedb/pkg/datastore"
"github.com/authzed/spicedb/pkg/middleware/consistency"
"github.com/authzed/spicedb/pkg/middleware/requestid"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
"github.com/authzed/spicedb/pkg/tuple"
Expand Down Expand Up @@ -155,10 +156,13 @@ func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.Ex
return nil, ps.rewriteError(ctx, err)
}

requestID, ctx := requestid.GetOrGenerateRequestID(ctx)

resp, err := ps.dispatch.DispatchExpand(ctx, &dispatch.DispatchExpandRequest{
Metadata: &dispatch.ResolverMeta{
AtRevision: atRevision.String(),
DepthRemaining: ps.config.MaximumAPIDepth,
RequestId: requestID,
},
ResourceAndRelation: &core.ObjectAndRelation{
Namespace: req.Resource.ObjectType,
Expand Down Expand Up @@ -326,7 +330,8 @@ func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRela
}

func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error {
ctx := resp.Context()
requestID, ctx := requestid.GetOrGenerateRequestID(resp.Context())

atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
if err != nil {
return ps.rewriteError(ctx, err)
Expand Down Expand Up @@ -420,6 +425,7 @@ func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp
Metadata: &dispatch.ResolverMeta{
AtRevision: atRevision.String(),
DepthRemaining: ps.config.MaximumAPIDepth,
RequestId: requestID,
},
ObjectRelation: &core.RelationReference{
Namespace: req.ResourceObjectType,
Expand All @@ -444,7 +450,8 @@ func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp
}

func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v1.PermissionsService_LookupSubjectsServer) error {
ctx := resp.Context()
requestID, ctx := requestid.GetOrGenerateRequestID(resp.Context())

atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
if err != nil {
return ps.rewriteError(ctx, err)
Expand Down Expand Up @@ -538,6 +545,7 @@ func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v
Metadata: &dispatch.ResolverMeta{
AtRevision: atRevision.String(),
DepthRemaining: ps.config.MaximumAPIDepth,
RequestId: requestID,
},
ResourceRelation: &core.RelationReference{
Namespace: req.Resource.ObjectType,
Expand Down
1 change: 1 addition & 0 deletions pkg/development/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func RunValidation(devContext *DevContext, validation *blocks.ParsedExpectedRela
Metadata: &v1.ResolverMeta{
AtRevision: devContext.Revision.String(),
DepthRemaining: maxDispatchDepth,
RequestId: "validation",
},
ExpansionMode: v1.DispatchExpandRequest_RECURSIVE,
})
Expand Down
50 changes: 50 additions & 0 deletions pkg/genutil/mapz/countingmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package mapz

import "sync"

// CountingMultiMap is a multimap that counts the number of distinct values for each
// key, removing the key from the map when the count reaches zero. Safe for concurrent
// use.
type CountingMultiMap[T comparable, Q comparable] struct {
valuesByKey map[T]*Set[Q]
lock sync.Mutex
}

// NewCountingMultiMap constructs a new counting multimap.
func NewCountingMultiMap[T comparable, Q comparable]() *CountingMultiMap[T, Q] {
return &CountingMultiMap[T, Q]{
valuesByKey: map[T]*Set[Q]{},
lock: sync.Mutex{},
}
}

// Add adds the given value to the map at the given key. Returns true if the value
// already existed in the map for the given key.
func (cmm *CountingMultiMap[T, Q]) Add(key T, value Q) bool {
cmm.lock.Lock()
defer cmm.lock.Unlock()

values, ok := cmm.valuesByKey[key]
if !ok {
values = NewSet[Q]()
cmm.valuesByKey[key] = values
}
return !values.Add(value)
}

// Remove removes the given value for the given key from the map. If, after this removal,
// the key has no additional values, it is removed entirely from the map.
func (cmm *CountingMultiMap[T, Q]) Remove(key T, value Q) {
cmm.lock.Lock()
defer cmm.lock.Unlock()

values, ok := cmm.valuesByKey[key]
if !ok {
return
}

values.Remove(value)
if values.IsEmpty() {
delete(cmm.valuesByKey, key)
}
}
Loading

0 comments on commit 0114177

Please sign in to comment.