Skip to content

Prototyping: parser support for annotations. #1220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions parser/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ cc_library(
"//parser/internal:cel_cc_parser",
"@antlr4_runtimes//:cpp",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:overload",
Expand Down
25 changes: 25 additions & 0 deletions parser/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,31 @@ struct ParserOptions final {
//
// Limited to field specifiers in select and message creation.
bool enable_quoted_identifiers = false;

// Enables support for the cel.annotate macro.
//
// Annotations are normally injected by higher level CEL tools to provide
// additional metadata about how to interpret or analyze the expression. This
// macro is intended for adding annotations in the source expression, using
// the same internal mechanisms as annotations added by tools.
//
// The macro takes two arguments:
//
// 1. The expression to annotate.
// 2. A list of annotations to apply to the expression.
//
// example:
// cel.annotate(foo.bar in baz,
// [cel.Annotation{name: "com.example.Explain",
// inspect_only: true,
// value: "check if foo.bar is in baz"}]
// )
//
// Permits the short hand if the annotation has no value:
// cel.annotate(foo.bar in baz, "com.example.MyAnnotation")
//
// The annotation is recorded in the source_info of the parsed expression.
bool enable_annotations = false;
};

} // namespace cel
Expand Down
248 changes: 230 additions & 18 deletions parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include "cel/expr/syntax.pb.h"
#include "absl/base/macros.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
Expand Down Expand Up @@ -601,23 +602,151 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) {
return factory_.NewCall(ops_[mid], function_, std::move(arguments));
}

// Lightweight overlay for a registry.
// Adds stateful macros that are relevant per Parse call.
class AugmentedMacroRegistry {
public:
explicit AugmentedMacroRegistry(const cel::MacroRegistry& registry)
: base_(registry) {}

cel::MacroRegistry& overlay() { return overlay_; }

absl::optional<Macro> FindMacro(absl::string_view name, size_t arg_count,
bool receiver_style) const;

private:
const cel::MacroRegistry& base_;
cel::MacroRegistry overlay_;
};

absl::optional<Macro> AugmentedMacroRegistry::FindMacro(
absl::string_view name, size_t arg_count, bool receiver_style) const {
auto result = overlay_.FindMacro(name, arg_count, receiver_style);
if (result.has_value()) {
return result;
}

return base_.FindMacro(name, arg_count, receiver_style);
}

bool IsSupportedAnnotation(const Expr& e) {
if (e.has_const_expr() && e.const_expr().has_string_value()) {
return true;
} else if (e.has_struct_expr() &&
e.struct_expr().name() == "cel.Annotation") {
for (const auto& field : e.struct_expr().fields()) {
if (field.name() != "name" && field.name() != "inspect_only" &&
field.name() != "value") {
return false;
}
}
return true;
}
return false;
}

class AnnotationCollector {
private:
struct AnnotationRep {
Expr expr;
};

struct MacroImpl {
absl::Nonnull<AnnotationCollector*> parent;

// Record a single annotation. Returns a non-empty optional if
// an error is encountered.
absl::optional<Expr> RecordAnnotation(cel::MacroExprFactory& mef,
int64_t id, Expr e) const;

// MacroExpander for "cel.annotate"
absl::optional<Expr> operator()(cel::MacroExprFactory& mef, Expr& target,
absl::Span<Expr> args) const;
};

void Add(int64_t annotated_expr, Expr value);

public:
const absl::btree_map<int64_t, std::vector<AnnotationRep>>& annotations() {
return annotations_;
}

absl::btree_map<int64_t, std::vector<AnnotationRep>> consume_annotations() {
using std::swap;
absl::btree_map<int64_t, std::vector<AnnotationRep>> result;
swap(result, annotations_);
return result;
}

Macro MakeAnnotationImpl() {
auto impl = Macro::Receiver("annotate", 2, MacroImpl{this});
ABSL_CHECK_OK(impl.status());
return std::move(impl).value();
}

private:
absl::btree_map<int64_t, std::vector<AnnotationRep>> annotations_;
};

absl::optional<Expr> AnnotationCollector::MacroImpl::RecordAnnotation(
cel::MacroExprFactory& mef, int64_t id, Expr e) const {
if (IsSupportedAnnotation(e)) {
parent->Add(id, std::move(e));
return absl::nullopt;
}

return mef.ReportErrorAt(
e,
"cel.annotate argument is not a cel.Annotation{} or string expression");
}

absl::optional<Expr> AnnotationCollector::MacroImpl::operator()(
cel::MacroExprFactory& mef, Expr& target, absl::Span<Expr> args) const {
if (!target.has_ident_expr() || target.ident_expr().name() != "cel") {
return absl::nullopt;
}

if (args.size() != 2) {
return mef.ReportErrorAt(
target, "wrong number of arguments for cel.annotate macro");
}

// arg0 (the annotated expression) is the expansion result. The remainder are
// annotations to record.
int64_t id = args[0].id();

absl::optional<Expr> result;
if (args[1].has_list_expr()) {
auto list = args[1].release_list_expr();
for (auto& e : list.mutable_elements()) {
result = RecordAnnotation(mef, id, e.release_expr());
if (result) {
break;
}
}
} else {
result = RecordAnnotation(mef, id, std::move(args[1]));
}

if (result) {
return result;
}

return std::move(args[0]);
}

void AnnotationCollector::Add(int64_t annotated_expr, Expr value) {
annotations_[annotated_expr].push_back({std::move(value)});
}

class ParserVisitor final : public CelBaseVisitor,
public antlr4::BaseErrorListener {
public:
ParserVisitor(const cel::Source& source, int max_recursion_depth,
absl::string_view accu_var,
const cel::MacroRegistry& macro_registry,
bool add_macro_calls = false,
bool enable_optional_syntax = false,
bool enable_quoted_identifiers = false)
: source_(source),
factory_(source_, accu_var),
macro_registry_(macro_registry),
recursion_depth_(0),
max_recursion_depth_(max_recursion_depth),
add_macro_calls_(add_macro_calls),
enable_optional_syntax_(enable_optional_syntax),
enable_quoted_identifiers_(enable_quoted_identifiers) {}
const cel::MacroRegistry& macro_registry, bool add_macro_calls,
bool enable_optional_syntax, bool enable_quoted_identifiers,
bool enable_annotations);

~ParserVisitor() override = default;

Expand Down Expand Up @@ -675,6 +804,8 @@ class ParserVisitor final : public CelBaseVisitor,

std::string ErrorMessage();

Expr PackAnnotations(Expr ast);

private:
template <typename... Args>
Expr GlobalCallOrMacro(int64_t expr_id, absl::string_view function,
Expand Down Expand Up @@ -702,14 +833,38 @@ class ParserVisitor final : public CelBaseVisitor,
private:
const cel::Source& source_;
cel::ParserMacroExprFactory factory_;
const cel::MacroRegistry& macro_registry_;
AugmentedMacroRegistry macro_registry_;
AnnotationCollector annotations_;
int recursion_depth_;
const int max_recursion_depth_;
const bool add_macro_calls_;
const bool enable_optional_syntax_;
const bool enable_quoted_identifiers_;
const bool enable_annotations_;
};

ParserVisitor::ParserVisitor(const cel::Source& source, int max_recursion_depth,
absl::string_view accu_var,
const cel::MacroRegistry& macro_registry,
bool add_macro_calls, bool enable_optional_syntax,
bool enable_quoted_identifiers,
bool enable_annotations)
: source_(source),
factory_(source_, accu_var),
macro_registry_(macro_registry),
recursion_depth_(0),
max_recursion_depth_(max_recursion_depth),
add_macro_calls_(add_macro_calls),
enable_optional_syntax_(enable_optional_syntax),
enable_quoted_identifiers_(enable_quoted_identifiers),
enable_annotations_(enable_annotations) {
if (enable_annotations_) {
macro_registry_.overlay()
.RegisterMacro(annotations_.MakeAnnotationImpl())
.IgnoreError();
}
}

template <typename T, typename = std::enable_if_t<
std::is_base_of<antlr4::tree::ParseTree, T>::value>>
T* tree_as(antlr4::tree::ParseTree* tree) {
Expand Down Expand Up @@ -1638,6 +1793,61 @@ struct ParseResult {
EnrichedSourceInfo enriched_source_info;
};

Expr NormalizeAnnotation(cel::ParserMacroExprFactory& mef, Expr expr) {
if (expr.has_struct_expr()) {
return expr;
}

if (expr.has_const_expr()) {
std::vector<cel::StructExprField> fields;
fields.reserve(2);
fields.push_back(
mef.NewStructField(mef.NextId({}), "name", std::move(expr)));
auto bool_const = mef.NewBoolConst(mef.NextId({}), true);
fields.push_back(mef.NewStructField(mef.NextId({}), "inspect_only",
std::move(bool_const)));
return mef.NewStruct(mef.NextId({}), "cel.Annotation", std::move(fields));
}

return mef.ReportError("invalid annotation encountered finalizing AST");
}

Expr ParserVisitor::PackAnnotations(Expr ast) {
if (annotations_.annotations().empty()) {
return ast;
}

auto annotations = annotations_.consume_annotations();
std::vector<MapExprEntry> entries;
entries.reserve(annotations.size());

for (auto& annotation : annotations) {
std::vector<cel::ListExprElement> annotation_values;
annotation_values.reserve(annotation.second.size());

for (auto& annotation_value : annotation.second) {
auto annotation =
NormalizeAnnotation(factory_, std::move(annotation_value.expr));
annotation_values.push_back(
factory_.NewListElement(std::move(annotation)));
}
auto id = factory_.NewIntConst(factory_.NextId({}), annotation.first);
auto annotation_list =
factory_.NewList(factory_.NextId({}), std::move(annotation_values));
entries.push_back(factory_.NewMapEntry(factory_.NextId({}), std::move(id),
std::move(annotation_list)));
}

std::vector<Expr> args;
args.push_back(std::move(ast));
args.push_back(factory_.NewMap(factory_.NextId({}), std::move(entries)));

auto result =
factory_.NewCall(factory_.NextId({}), "cel.@annotated", std::move(args));

return result;
}

absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
const cel::MacroRegistry& registry,
const ParserOptions& options) {
Expand All @@ -1656,10 +1866,10 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
if (options.enable_hidden_accumulator_var) {
accu_var = cel::kHiddenAccumulatorVariableName;
}
ParserVisitor visitor(source, options.max_recursion_depth, accu_var,
registry, options.add_macro_calls,
options.enable_optional_syntax,
options.enable_quoted_identifiers);
ParserVisitor visitor(
source, options.max_recursion_depth, accu_var, registry,
options.add_macro_calls, options.enable_optional_syntax,
options.enable_quoted_identifiers, options.enable_annotations);

lexer.removeErrorListeners();
parser.removeErrorListeners();
Expand All @@ -1686,7 +1896,9 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
if (visitor.HasErrored()) {
return absl::InvalidArgumentError(visitor.ErrorMessage());
}

if (options.enable_annotations) {
expr = visitor.PackAnnotations(std::move(expr));
}
return {
ParseResult{.expr = std::move(expr),
.source_info = visitor.GetSourceInfo(),
Expand Down
Loading