Skip to content

Commit

Permalink
Move caveat loading into a shared runner to reduce overhead in dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
josephschorr committed Dec 20, 2024
1 parent 68a6ebc commit 3ab909a
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 101 deletions.
232 changes: 140 additions & 92 deletions internal/caveats/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"fmt"
"maps"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"

"github.com/authzed/spicedb/pkg/caveats"
caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
"github.com/authzed/spicedb/pkg/datastore"
Expand All @@ -14,6 +17,8 @@ import (
"github.com/authzed/spicedb/pkg/spiceerrors"
)

var tracer = otel.Tracer("spicedb/internal/caveats/run")

// RunCaveatExpressionDebugOption are the options for running caveat expression evaluation
// with debugging enabled or disabled.
type RunCaveatExpressionDebugOption int
Expand All @@ -26,125 +31,101 @@ const (
RunCaveatExpressionWithDebugInformation RunCaveatExpressionDebugOption = 1
)

// RunCaveatExpression runs a caveat expression over the given context and returns the result.
func RunCaveatExpression(
// RunSingleCaveatExpression runs a caveat expression over the given context and returns the result.
// This instantiates its own CaveatRunner, and should therefore only be used in one-off situations.
func RunSingleCaveatExpression(
ctx context.Context,
expr *core.CaveatExpression,
context map[string]any,
reader datastore.CaveatReader,
debugOption RunCaveatExpressionDebugOption,
) (ExpressionResult, error) {
env := caveats.NewEnvironment()
return runExpression(ctx, env, expr, context, reader, debugOption)
}

// ExpressionResult is the result of a caveat expression being run.
type ExpressionResult interface {
// Value is the resolved value for the expression. For partially applied expressions, this value will be false.
Value() bool

// IsPartial returns whether the expression was only partially applied.
IsPartial() bool

// MissingVarNames returns the names of the parameters missing from the context.
MissingVarNames() ([]string, error)
runner := NewCaveatRunner()
return runner.RunCaveatExpression(ctx, expr, context, reader, debugOption)
}

type syntheticResult struct {
value bool
isPartialResult bool

op core.CaveatOperation_Operation
exprResultsForDebug []ExpressionResult
missingVarNames *mapz.Set[string]
}

func (sr syntheticResult) Value() bool {
return sr.value
}

func (sr syntheticResult) IsPartial() bool {
return sr.isPartialResult
// CaveatRunner is a helper for running caveats, providing a cache for deserialized caveats.
type CaveatRunner struct {
caveatDefs map[string]*core.CaveatDefinition
deserializedCaveats map[string]*caveats.CompiledCaveat
}

func (sr syntheticResult) MissingVarNames() ([]string, error) {
if sr.isPartialResult {
if sr.missingVarNames != nil {
return sr.missingVarNames.AsSlice(), nil
}

missingVarNames := mapz.NewSet[string]()
for _, exprResult := range sr.exprResultsForDebug {
if exprResult.IsPartial() {
found, err := exprResult.MissingVarNames()
if err != nil {
return nil, err
}

missingVarNames.Extend(found)
}
}

return missingVarNames.AsSlice(), nil
// NewCaveatRunner creates a new CaveatRunner.
func NewCaveatRunner() *CaveatRunner {
return &CaveatRunner{
caveatDefs: map[string]*core.CaveatDefinition{},
deserializedCaveats: map[string]*caveats.CompiledCaveat{},
}

return nil, fmt.Errorf("not a partial value")
}

func isFalseResult(result ExpressionResult) bool {
return !result.Value() && !result.IsPartial()
}

func isTrueResult(result ExpressionResult) bool {
return result.Value() && !result.IsPartial()
}

func runExpression(
// RunCaveatExpression runs a caveat expression over the given context and returns the result.
func (cr *CaveatRunner) RunCaveatExpression(
ctx context.Context,
env *caveats.Environment,
expr *core.CaveatExpression,
context map[string]any,
reader datastore.CaveatReader,
debugOption RunCaveatExpressionDebugOption,
) (ExpressionResult, error) {
ctx, span := tracer.Start(ctx, "RunCaveatExpression")
defer span.End()

if err := cr.PopulateCaveatDefinitionsForExpr(ctx, expr, reader); err != nil {
return nil, err
}

env := caveats.NewEnvironment()
return cr.runExpressionWithCaveats(ctx, env, expr, context, debugOption)
}

// PopulateCaveatDefinitionsForExpr populates the CaveatRunner's cache with the definitions
// referenced in the given caveat expression.
func (cr *CaveatRunner) PopulateCaveatDefinitionsForExpr(ctx context.Context, expr *core.CaveatExpression, reader datastore.CaveatReader) error {
ctx, span := tracer.Start(ctx, "PopulateCaveatDefinitions")
defer span.End()

// Collect all referenced caveat definitions in the expression.
caveatNames := mapz.NewSet[string]()
collectCaveatNames(expr, caveatNames)

span.AddEvent("collected caveat names")
span.SetAttributes(attribute.StringSlice("caveat-names", caveatNames.AsSlice()))

if caveatNames.IsEmpty() {
return nil, fmt.Errorf("received empty caveat expression")
return fmt.Errorf("received empty caveat expression")
}

// Remove any caveats already loaded.
for name := range cr.caveatDefs {
caveatNames.Delete(name)
}

if caveatNames.IsEmpty() {
return nil
}

// Bulk lookup all of the referenced caveat definitions.
caveatDefs, err := reader.LookupCaveatsWithNames(ctx, caveatNames.AsSlice())
if err != nil {
return nil, err
}

lc := loadedCaveats{
caveatDefs: map[string]*core.CaveatDefinition{},
deserializedCaveats: map[string]*caveats.CompiledCaveat{},
return err
}
span.AddEvent("looked up caveats")

for _, cd := range caveatDefs {
lc.caveatDefs[cd.Definition.GetName()] = cd.Definition
cr.caveatDefs[cd.Definition.GetName()] = cd.Definition
}

return runExpressionWithCaveats(ctx, env, expr, context, lc, debugOption)
return nil
}

type loadedCaveats struct {
caveatDefs map[string]*core.CaveatDefinition
deserializedCaveats map[string]*caveats.CompiledCaveat
}

func (lc loadedCaveats) Get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) {
caveat, ok := lc.caveatDefs[caveatDefName]
// get retrieves a caveat definition and its deserialized form. The caveat name must be
// present in the CaveatRunner's cache.
func (cr *CaveatRunner) get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) {
caveat, ok := cr.caveatDefs[caveatDefName]
if !ok {
return nil, nil, datastore.NewCaveatNameNotFoundErr(caveatDefName)
}

deserialized, ok := lc.deserializedCaveats[caveatDefName]
deserialized, ok := cr.deserializedCaveats[caveatDefName]
if ok {
return caveat, deserialized, nil
}
Expand All @@ -159,20 +140,36 @@ func (lc loadedCaveats) Get(caveatDefName string) (*core.CaveatDefinition, *cave
return caveat, nil, err
}

lc.deserializedCaveats[caveatDefName] = justDeserialized
cr.deserializedCaveats[caveatDefName] = justDeserialized
return caveat, justDeserialized, nil
}

func runExpressionWithCaveats(
func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) {
if expr.GetCaveat() != nil {
caveatNames.Add(expr.GetCaveat().CaveatName)
return
}

cop := expr.GetOperation()
for _, child := range cop.Children {
collectCaveatNames(child, caveatNames)
}
}

func (cr *CaveatRunner) runExpressionWithCaveats(
ctx context.Context,
env *caveats.Environment,
expr *core.CaveatExpression,
context map[string]any,
loadedCaveats loadedCaveats,
debugOption RunCaveatExpressionDebugOption,
) (ExpressionResult, error) {
ctx, span := tracer.Start(ctx, "runExpressionWithCaveats")
defer span.End()

if expr.GetCaveat() != nil {
caveat, compiled, err := loadedCaveats.Get(expr.GetCaveat().CaveatName)
span.SetAttributes(attribute.String("caveat-name", expr.GetCaveat().CaveatName))

caveat, compiled, err := cr.get(expr.GetCaveat().CaveatName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -210,6 +207,8 @@ func runExpressionWithCaveats(
}

cop := expr.GetOperation()
span.SetAttributes(attribute.String("caveat-operation", cop.Op.String()))

var currentResult ExpressionResult = syntheticResult{
value: cop.Op == core.CaveatOperation_AND,
isPartialResult: false,
Expand Down Expand Up @@ -311,7 +310,7 @@ func runExpressionWithCaveats(
}

for _, child := range cop.Children {
childResult, err := runExpressionWithCaveats(ctx, env, child, context, loadedCaveats, debugOption)
childResult, err := cr.runExpressionWithCaveats(ctx, env, child, context, debugOption)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -361,14 +360,63 @@ func runExpressionWithCaveats(
return currentResult, nil
}

func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) {
if expr.GetCaveat() != nil {
caveatNames.Add(expr.GetCaveat().CaveatName)
return
}
// ExpressionResult is the result of a caveat expression being run.
type ExpressionResult interface {
// Value is the resolved value for the expression. For partially applied expressions, this value will be false.
Value() bool

cop := expr.GetOperation()
for _, child := range cop.Children {
collectCaveatNames(child, caveatNames)
// IsPartial returns whether the expression was only partially applied.
IsPartial() bool

// MissingVarNames returns the names of the parameters missing from the context.
MissingVarNames() ([]string, error)
}

type syntheticResult struct {
value bool
isPartialResult bool

op core.CaveatOperation_Operation
exprResultsForDebug []ExpressionResult
missingVarNames *mapz.Set[string]
}

func (sr syntheticResult) Value() bool {
return sr.value
}

func (sr syntheticResult) IsPartial() bool {
return sr.isPartialResult
}

func (sr syntheticResult) MissingVarNames() ([]string, error) {
if sr.isPartialResult {
if sr.missingVarNames != nil {
return sr.missingVarNames.AsSlice(), nil
}

missingVarNames := mapz.NewSet[string]()
for _, exprResult := range sr.exprResultsForDebug {
if exprResult.IsPartial() {
found, err := exprResult.MissingVarNames()
if err != nil {
return nil, err
}

missingVarNames.Extend(found)
}
}

return missingVarNames.AsSlice(), nil
}

return nil, fmt.Errorf("not a partial value")
}

func isFalseResult(result ExpressionResult) bool {
return !result.Value() && !result.IsPartial()
}

func isTrueResult(result ExpressionResult) bool {
return result.Value() && !result.IsPartial()
}
6 changes: 3 additions & 3 deletions internal/caveats/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ func TestRunCaveatExpressions(t *testing.T) {
t.Run(fmt.Sprintf("%v", debugOption), func(t *testing.T) {
req := require.New(t)

result, err := caveats.RunCaveatExpression(context.Background(), tc.expression, tc.context, reader, debugOption)
result, err := caveats.RunSingleCaveatExpression(context.Background(), tc.expression, tc.context, reader, debugOption)
req.NoError(err)
req.Equal(tc.expectedValue, result.Value())

Expand Down Expand Up @@ -520,7 +520,7 @@ func TestRunCaveatWithMissingMap(t *testing.T) {

reader := ds.SnapshotReader(headRevision)

result, err := caveats.RunCaveatExpression(
result, err := caveats.RunSingleCaveatExpression(
context.Background(),
caveatexpr("some_caveat"),
map[string]any{},
Expand Down Expand Up @@ -549,7 +549,7 @@ func TestRunCaveatWithEmptyMap(t *testing.T) {

reader := ds.SnapshotReader(headRevision)

_, err = caveats.RunCaveatExpression(
_, err = caveats.RunSingleCaveatExpression(
context.Background(),
caveatexpr("some_caveat"),
map[string]any{
Expand Down
8 changes: 5 additions & 3 deletions internal/graph/computed/computecheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ func computeCheck(ctx context.Context,
return nil, nil, spiceerrors.MustBugf("failed to create new traversal bloom filter")
}

caveatRunner := cexpr.NewCaveatRunner()

// TODO(jschorr): Should we make this run in parallel via the preloadedTaskRunner?
_, err = slicez.ForEachChunkUntil(resourceIDs, dispatchChunkSize, func(resourceIDsToCheck []string) (bool, error) {
checkResult, err := d.DispatchCheck(ctx, &v1.DispatchCheckRequest{
Expand Down Expand Up @@ -141,7 +143,7 @@ func computeCheck(ctx context.Context,
}

for _, resourceID := range resourceIDsToCheck {
computed, err := computeCaveatedCheckResult(ctx, params, resourceID, checkResult)
computed, err := computeCaveatedCheckResult(ctx, caveatRunner, params, resourceID, checkResult)
if err != nil {
return false, err
}
Expand All @@ -153,7 +155,7 @@ func computeCheck(ctx context.Context,
return results, metadata, err
}

func computeCaveatedCheckResult(ctx context.Context, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) {
func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) {
result, ok := checkResult.ResultsByResourceId[resourceID]
if !ok {
return &v1.ResourceCheckResult{
Expand All @@ -168,7 +170,7 @@ func computeCaveatedCheckResult(ctx context.Context, params CheckParameters, res
ds := datastoremw.MustFromContext(ctx)
reader := ds.SnapshotReader(params.AtRevision)

caveatResult, err := cexpr.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging)
caveatResult, err := runner.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 3ab909a

Please sign in to comment.