diff --git a/internal/caveats/run.go b/internal/caveats/run.go index 030afa6de7..c0afa8a1a0 100644 --- a/internal/caveats/run.go +++ b/internal/caveats/run.go @@ -6,6 +6,9 @@ import ( "fmt" "maps" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "github.com/authzed/spicedb/pkg/caveats" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/datastore" @@ -14,6 +17,8 @@ import ( "github.com/authzed/spicedb/pkg/spiceerrors" ) +var tracer = otel.Tracer("spicedb/internal/caveats/run") + // RunCaveatExpressionDebugOption are the options for running caveat expression evaluation // with debugging enabled or disabled. type RunCaveatExpressionDebugOption int @@ -26,125 +31,101 @@ const ( RunCaveatExpressionWithDebugInformation RunCaveatExpressionDebugOption = 1 ) -// RunCaveatExpression runs a caveat expression over the given context and returns the result. -func RunCaveatExpression( +// RunSingleCaveatExpression runs a caveat expression over the given context and returns the result. +// This instantiates its own CaveatRunner, and should therefore only be used in one-off situations. +func RunSingleCaveatExpression( ctx context.Context, expr *core.CaveatExpression, context map[string]any, reader datastore.CaveatReader, debugOption RunCaveatExpressionDebugOption, ) (ExpressionResult, error) { - env := caveats.NewEnvironment() - return runExpression(ctx, env, expr, context, reader, debugOption) -} - -// ExpressionResult is the result of a caveat expression being run. -type ExpressionResult interface { - // Value is the resolved value for the expression. For partially applied expressions, this value will be false. - Value() bool - - // IsPartial returns whether the expression was only partially applied. - IsPartial() bool - - // MissingVarNames returns the names of the parameters missing from the context. - MissingVarNames() ([]string, error) + runner := NewCaveatRunner() + return runner.RunCaveatExpression(ctx, expr, context, reader, debugOption) } -type syntheticResult struct { - value bool - isPartialResult bool - - op core.CaveatOperation_Operation - exprResultsForDebug []ExpressionResult - missingVarNames *mapz.Set[string] -} - -func (sr syntheticResult) Value() bool { - return sr.value -} - -func (sr syntheticResult) IsPartial() bool { - return sr.isPartialResult +// CaveatRunner is a helper for running caveats, providing a cache for deserialized caveats. +type CaveatRunner struct { + caveatDefs map[string]*core.CaveatDefinition + deserializedCaveats map[string]*caveats.CompiledCaveat } -func (sr syntheticResult) MissingVarNames() ([]string, error) { - if sr.isPartialResult { - if sr.missingVarNames != nil { - return sr.missingVarNames.AsSlice(), nil - } - - missingVarNames := mapz.NewSet[string]() - for _, exprResult := range sr.exprResultsForDebug { - if exprResult.IsPartial() { - found, err := exprResult.MissingVarNames() - if err != nil { - return nil, err - } - - missingVarNames.Extend(found) - } - } - - return missingVarNames.AsSlice(), nil +// NewCaveatRunner creates a new CaveatRunner. +func NewCaveatRunner() *CaveatRunner { + return &CaveatRunner{ + caveatDefs: map[string]*core.CaveatDefinition{}, + deserializedCaveats: map[string]*caveats.CompiledCaveat{}, } - - return nil, fmt.Errorf("not a partial value") -} - -func isFalseResult(result ExpressionResult) bool { - return !result.Value() && !result.IsPartial() -} - -func isTrueResult(result ExpressionResult) bool { - return result.Value() && !result.IsPartial() } -func runExpression( +// RunCaveatExpression runs a caveat expression over the given context and returns the result. +func (cr *CaveatRunner) RunCaveatExpression( ctx context.Context, - env *caveats.Environment, expr *core.CaveatExpression, context map[string]any, reader datastore.CaveatReader, debugOption RunCaveatExpressionDebugOption, ) (ExpressionResult, error) { + ctx, span := tracer.Start(ctx, "RunCaveatExpression") + defer span.End() + + if err := cr.PopulateCaveatDefinitionsForExpr(ctx, expr, reader); err != nil { + return nil, err + } + + env := caveats.NewEnvironment() + return cr.runExpressionWithCaveats(ctx, env, expr, context, debugOption) +} + +// PopulateCaveatDefinitionsForExpr populates the CaveatRunner's cache with the definitions +// referenced in the given caveat expression. +func (cr *CaveatRunner) PopulateCaveatDefinitionsForExpr(ctx context.Context, expr *core.CaveatExpression, reader datastore.CaveatReader) error { + ctx, span := tracer.Start(ctx, "PopulateCaveatDefinitions") + defer span.End() + // Collect all referenced caveat definitions in the expression. caveatNames := mapz.NewSet[string]() collectCaveatNames(expr, caveatNames) + span.AddEvent("collected caveat names") + span.SetAttributes(attribute.StringSlice("caveat-names", caveatNames.AsSlice())) + if caveatNames.IsEmpty() { - return nil, fmt.Errorf("received empty caveat expression") + return fmt.Errorf("received empty caveat expression") + } + + // Remove any caveats already loaded. + for name := range cr.caveatDefs { + caveatNames.Delete(name) + } + + if caveatNames.IsEmpty() { + return nil } // Bulk lookup all of the referenced caveat definitions. caveatDefs, err := reader.LookupCaveatsWithNames(ctx, caveatNames.AsSlice()) if err != nil { - return nil, err - } - - lc := loadedCaveats{ - caveatDefs: map[string]*core.CaveatDefinition{}, - deserializedCaveats: map[string]*caveats.CompiledCaveat{}, + return err } + span.AddEvent("looked up caveats") for _, cd := range caveatDefs { - lc.caveatDefs[cd.Definition.GetName()] = cd.Definition + cr.caveatDefs[cd.Definition.GetName()] = cd.Definition } - return runExpressionWithCaveats(ctx, env, expr, context, lc, debugOption) + return nil } -type loadedCaveats struct { - caveatDefs map[string]*core.CaveatDefinition - deserializedCaveats map[string]*caveats.CompiledCaveat -} - -func (lc loadedCaveats) Get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) { - caveat, ok := lc.caveatDefs[caveatDefName] +// get retrieves a caveat definition and its deserialized form. The caveat name must be +// present in the CaveatRunner's cache. +func (cr *CaveatRunner) get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) { + caveat, ok := cr.caveatDefs[caveatDefName] if !ok { return nil, nil, datastore.NewCaveatNameNotFoundErr(caveatDefName) } - deserialized, ok := lc.deserializedCaveats[caveatDefName] + deserialized, ok := cr.deserializedCaveats[caveatDefName] if ok { return caveat, deserialized, nil } @@ -159,20 +140,36 @@ func (lc loadedCaveats) Get(caveatDefName string) (*core.CaveatDefinition, *cave return caveat, nil, err } - lc.deserializedCaveats[caveatDefName] = justDeserialized + cr.deserializedCaveats[caveatDefName] = justDeserialized return caveat, justDeserialized, nil } -func runExpressionWithCaveats( +func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) { + if expr.GetCaveat() != nil { + caveatNames.Add(expr.GetCaveat().CaveatName) + return + } + + cop := expr.GetOperation() + for _, child := range cop.Children { + collectCaveatNames(child, caveatNames) + } +} + +func (cr *CaveatRunner) runExpressionWithCaveats( ctx context.Context, env *caveats.Environment, expr *core.CaveatExpression, context map[string]any, - loadedCaveats loadedCaveats, debugOption RunCaveatExpressionDebugOption, ) (ExpressionResult, error) { + ctx, span := tracer.Start(ctx, "runExpressionWithCaveats") + defer span.End() + if expr.GetCaveat() != nil { - caveat, compiled, err := loadedCaveats.Get(expr.GetCaveat().CaveatName) + span.SetAttributes(attribute.String("caveat-name", expr.GetCaveat().CaveatName)) + + caveat, compiled, err := cr.get(expr.GetCaveat().CaveatName) if err != nil { return nil, err } @@ -210,6 +207,8 @@ func runExpressionWithCaveats( } cop := expr.GetOperation() + span.SetAttributes(attribute.String("caveat-operation", cop.Op.String())) + var currentResult ExpressionResult = syntheticResult{ value: cop.Op == core.CaveatOperation_AND, isPartialResult: false, @@ -311,7 +310,7 @@ func runExpressionWithCaveats( } for _, child := range cop.Children { - childResult, err := runExpressionWithCaveats(ctx, env, child, context, loadedCaveats, debugOption) + childResult, err := cr.runExpressionWithCaveats(ctx, env, child, context, debugOption) if err != nil { return nil, err } @@ -361,14 +360,63 @@ func runExpressionWithCaveats( return currentResult, nil } -func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) { - if expr.GetCaveat() != nil { - caveatNames.Add(expr.GetCaveat().CaveatName) - return - } +// ExpressionResult is the result of a caveat expression being run. +type ExpressionResult interface { + // Value is the resolved value for the expression. For partially applied expressions, this value will be false. + Value() bool - cop := expr.GetOperation() - for _, child := range cop.Children { - collectCaveatNames(child, caveatNames) + // IsPartial returns whether the expression was only partially applied. + IsPartial() bool + + // MissingVarNames returns the names of the parameters missing from the context. + MissingVarNames() ([]string, error) +} + +type syntheticResult struct { + value bool + isPartialResult bool + + op core.CaveatOperation_Operation + exprResultsForDebug []ExpressionResult + missingVarNames *mapz.Set[string] +} + +func (sr syntheticResult) Value() bool { + return sr.value +} + +func (sr syntheticResult) IsPartial() bool { + return sr.isPartialResult +} + +func (sr syntheticResult) MissingVarNames() ([]string, error) { + if sr.isPartialResult { + if sr.missingVarNames != nil { + return sr.missingVarNames.AsSlice(), nil + } + + missingVarNames := mapz.NewSet[string]() + for _, exprResult := range sr.exprResultsForDebug { + if exprResult.IsPartial() { + found, err := exprResult.MissingVarNames() + if err != nil { + return nil, err + } + + missingVarNames.Extend(found) + } + } + + return missingVarNames.AsSlice(), nil } + + return nil, fmt.Errorf("not a partial value") +} + +func isFalseResult(result ExpressionResult) bool { + return !result.Value() && !result.IsPartial() +} + +func isTrueResult(result ExpressionResult) bool { + return result.Value() && !result.IsPartial() } diff --git a/internal/caveats/run_test.go b/internal/caveats/run_test.go index f4938741fa..e2d80acc51 100644 --- a/internal/caveats/run_test.go +++ b/internal/caveats/run_test.go @@ -476,7 +476,7 @@ func TestRunCaveatExpressions(t *testing.T) { t.Run(fmt.Sprintf("%v", debugOption), func(t *testing.T) { req := require.New(t) - result, err := caveats.RunCaveatExpression(context.Background(), tc.expression, tc.context, reader, debugOption) + result, err := caveats.RunSingleCaveatExpression(context.Background(), tc.expression, tc.context, reader, debugOption) req.NoError(err) req.Equal(tc.expectedValue, result.Value()) @@ -520,7 +520,7 @@ func TestRunCaveatWithMissingMap(t *testing.T) { reader := ds.SnapshotReader(headRevision) - result, err := caveats.RunCaveatExpression( + result, err := caveats.RunSingleCaveatExpression( context.Background(), caveatexpr("some_caveat"), map[string]any{}, @@ -549,7 +549,7 @@ func TestRunCaveatWithEmptyMap(t *testing.T) { reader := ds.SnapshotReader(headRevision) - _, err = caveats.RunCaveatExpression( + _, err = caveats.RunSingleCaveatExpression( context.Background(), caveatexpr("some_caveat"), map[string]any{ diff --git a/internal/graph/computed/computecheck.go b/internal/graph/computed/computecheck.go index a3d86cb916..44c4b4b738 100644 --- a/internal/graph/computed/computecheck.go +++ b/internal/graph/computed/computecheck.go @@ -110,6 +110,8 @@ func computeCheck(ctx context.Context, return nil, nil, spiceerrors.MustBugf("failed to create new traversal bloom filter") } + caveatRunner := cexpr.NewCaveatRunner() + // TODO(jschorr): Should we make this run in parallel via the preloadedTaskRunner? _, err = slicez.ForEachChunkUntil(resourceIDs, dispatchChunkSize, func(resourceIDsToCheck []string) (bool, error) { checkResult, err := d.DispatchCheck(ctx, &v1.DispatchCheckRequest{ @@ -141,7 +143,7 @@ func computeCheck(ctx context.Context, } for _, resourceID := range resourceIDsToCheck { - computed, err := computeCaveatedCheckResult(ctx, params, resourceID, checkResult) + computed, err := computeCaveatedCheckResult(ctx, caveatRunner, params, resourceID, checkResult) if err != nil { return false, err } @@ -153,7 +155,7 @@ func computeCheck(ctx context.Context, return results, metadata, err } -func computeCaveatedCheckResult(ctx context.Context, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) { +func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) { result, ok := checkResult.ResultsByResourceId[resourceID] if !ok { return &v1.ResourceCheckResult{ @@ -168,7 +170,7 @@ func computeCaveatedCheckResult(ctx context.Context, params CheckParameters, res ds := datastoremw.MustFromContext(ctx) reader := ds.SnapshotReader(params.AtRevision) - caveatResult, err := cexpr.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging) + caveatResult, err := runner.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging) if err != nil { return nil, err } diff --git a/internal/graph/lookupresources2.go b/internal/graph/lookupresources2.go index 9e60a6ebd0..e1fca79560 100644 --- a/internal/graph/lookupresources2.go +++ b/internal/graph/lookupresources2.go @@ -312,6 +312,7 @@ func (crr *CursoredLookupResources2) redispatchOrReportOverDatabaseQuery( rsm := newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize)) toBeHandled := make([]itemAndPostCursor[dispatchableResourcesSubjectMap2], 0) currentCursor := queryCursor + caveatRunner := caveats.NewCaveatRunner() for rel, err := range it { if err != nil { @@ -323,7 +324,7 @@ func (crr *CursoredLookupResources2) redispatchOrReportOverDatabaseQuery( // If a caveat exists on the relationship, run it and filter the results, marking those that have missing context. if rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "" { caveatExpr := caveats.CaveatAsExpr(rel.OptionalCaveat) - runResult, err := caveats.RunCaveatExpression(ctx, caveatExpr, config.parentRequest.Context.AsMap(), config.reader, caveats.RunCaveatExpressionNoDebugging) + runResult, err := caveatRunner.RunCaveatExpression(ctx, caveatExpr, config.parentRequest.Context.AsMap(), config.reader, caveats.RunCaveatExpressionNoDebugging) if err != nil { return nil, err } diff --git a/internal/services/v1/debug.go b/internal/services/v1/debug.go index 3a940c78a7..156d9bdea9 100644 --- a/internal/services/v1/debug.go +++ b/internal/services/v1/debug.go @@ -92,7 +92,7 @@ func convertCheckTrace(ctx context.Context, caveatContext map[string]any, ct *di var caveatEvalInfo *v1.CaveatEvalInfo if permissionship == v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION && len(partialResults) == 1 { partialCheckResult := partialResults[0] - computedResult, err := cexpr.RunCaveatExpression(ctx, partialCheckResult.Expression, caveatContext, reader, cexpr.RunCaveatExpressionWithDebugInformation) + computedResult, err := cexpr.RunSingleCaveatExpression(ctx, partialCheckResult.Expression, caveatContext, reader, cexpr.RunCaveatExpressionWithDebugInformation) if err != nil { return nil, err } diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go index 8105cd62b8..485b38c015 100644 --- a/internal/services/v1/permissions.go +++ b/internal/services/v1/permissions.go @@ -668,7 +668,7 @@ func foundSubjectToResolvedSubject(ctx context.Context, foundSubject *dispatch.F if foundSubject.GetCaveatExpression() != nil { permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION - cr, err := cexpr.RunCaveatExpression(ctx, foundSubject.GetCaveatExpression(), caveatContext, ds, cexpr.RunCaveatExpressionNoDebugging) + cr, err := cexpr.RunSingleCaveatExpression(ctx, foundSubject.GetCaveatExpression(), caveatContext, ds, cexpr.RunCaveatExpressionNoDebugging) if err != nil { return nil, err }