32
32
33
33
#include " cel/expr/syntax.pb.h"
34
34
#include " absl/base/macros.h"
35
+ #include " absl/base/nullability.h"
35
36
#include " absl/base/optimization.h"
36
37
#include " absl/container/btree_map.h"
37
38
#include " absl/container/flat_hash_map.h"
@@ -601,23 +602,151 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) {
601
602
return factory_.NewCall (ops_[mid], function_, std::move (arguments));
602
603
}
603
604
605
+ // Lightweight overlay for a registry.
606
+ // Adds stateful macros that are relevant per Parse call.
607
+ class AugmentedMacroRegistry {
608
+ public:
609
+ explicit AugmentedMacroRegistry (const cel::MacroRegistry& registry)
610
+ : base_(registry) {}
611
+
612
+ cel::MacroRegistry& overlay () { return overlay_; }
613
+
614
+ absl::optional<Macro> FindMacro (absl::string_view name, size_t arg_count,
615
+ bool receiver_style) const ;
616
+
617
+ private:
618
+ const cel::MacroRegistry& base_;
619
+ cel::MacroRegistry overlay_;
620
+ };
621
+
622
+ absl::optional<Macro> AugmentedMacroRegistry::FindMacro (
623
+ absl::string_view name, size_t arg_count, bool receiver_style) const {
624
+ auto result = overlay_.FindMacro (name, arg_count, receiver_style);
625
+ if (result.has_value ()) {
626
+ return result;
627
+ }
628
+
629
+ return base_.FindMacro (name, arg_count, receiver_style);
630
+ }
631
+
632
+ bool IsSupportedAnnotation (const Expr& e) {
633
+ if (e.has_const_expr () && e.const_expr ().has_string_value ()) {
634
+ return true ;
635
+ } else if (e.has_struct_expr () &&
636
+ e.struct_expr ().name () == " cel.Annotation" ) {
637
+ for (const auto & field : e.struct_expr ().fields ()) {
638
+ if (field.name () != " name" && field.name () != " inspect_only" &&
639
+ field.name () != " value" ) {
640
+ return false ;
641
+ }
642
+ }
643
+ return true ;
644
+ }
645
+ return false ;
646
+ }
647
+
648
+ class AnnotationCollector {
649
+ private:
650
+ struct AnnotationRep {
651
+ Expr expr;
652
+ };
653
+
654
+ struct MacroImpl {
655
+ absl::Nonnull<AnnotationCollector*> parent;
656
+
657
+ // Record a single annotation. Returns a non-empty optional if
658
+ // an error is encountered.
659
+ absl::optional<Expr> RecordAnnotation (cel::MacroExprFactory& mef,
660
+ int64_t id, Expr e) const ;
661
+
662
+ // MacroExpander for "cel.annotate"
663
+ absl::optional<Expr> operator ()(cel::MacroExprFactory& mef, Expr& target,
664
+ absl::Span<Expr> args) const ;
665
+ };
666
+
667
+ void Add (int64_t annotated_expr, Expr value);
668
+
669
+ public:
670
+ const absl::btree_map<int64_t , std::vector<AnnotationRep>>& annotations () {
671
+ return annotations_;
672
+ }
673
+
674
+ absl::btree_map<int64_t , std::vector<AnnotationRep>> consume_annotations () {
675
+ using std::swap;
676
+ absl::btree_map<int64_t , std::vector<AnnotationRep>> result;
677
+ swap (result, annotations_);
678
+ return result;
679
+ }
680
+
681
+ Macro MakeAnnotationImpl () {
682
+ auto impl = Macro::Receiver (" annotate" , 2 , MacroImpl{this });
683
+ ABSL_CHECK_OK (impl.status ());
684
+ return std::move (impl).value ();
685
+ }
686
+
687
+ private:
688
+ absl::btree_map<int64_t , std::vector<AnnotationRep>> annotations_;
689
+ };
690
+
691
+ absl::optional<Expr> AnnotationCollector::MacroImpl::RecordAnnotation (
692
+ cel::MacroExprFactory& mef, int64_t id, Expr e) const {
693
+ if (IsSupportedAnnotation (e)) {
694
+ parent->Add (id, std::move (e));
695
+ return absl::nullopt;
696
+ }
697
+
698
+ return mef.ReportErrorAt (
699
+ e,
700
+ " cel.annotate argument is not a cel.Annotation{} or string expression" );
701
+ }
702
+
703
+ absl::optional<Expr> AnnotationCollector::MacroImpl::operator ()(
704
+ cel::MacroExprFactory& mef, Expr& target, absl::Span<Expr> args) const {
705
+ if (!target.has_ident_expr () || target.ident_expr ().name () != " cel" ) {
706
+ return absl::nullopt;
707
+ }
708
+
709
+ if (args.size () != 2 ) {
710
+ return mef.ReportErrorAt (
711
+ target, " wrong number of arguments for cel.annotate macro" );
712
+ }
713
+
714
+ // arg0 (the annotated expression) is the expansion result. The remainder are
715
+ // annotations to record.
716
+ int64_t id = args[0 ].id ();
717
+
718
+ absl::optional<Expr> result;
719
+ if (args[1 ].has_list_expr ()) {
720
+ auto list = args[1 ].release_list_expr ();
721
+ for (auto & e : list.mutable_elements ()) {
722
+ result = RecordAnnotation (mef, id, e.release_expr ());
723
+ if (result) {
724
+ break ;
725
+ }
726
+ }
727
+ } else {
728
+ result = RecordAnnotation (mef, id, std::move (args[1 ]));
729
+ }
730
+
731
+ if (result) {
732
+ return result;
733
+ }
734
+
735
+ return std::move (args[0 ]);
736
+ }
737
+
738
+ void AnnotationCollector::Add (int64_t annotated_expr, Expr value) {
739
+ annotations_[annotated_expr].push_back ({std::move (value)});
740
+ }
741
+
604
742
class ParserVisitor final : public CelBaseVisitor,
605
743
public antlr4::BaseErrorListener {
606
744
public:
607
745
ParserVisitor (const cel::Source& source, int max_recursion_depth,
608
746
absl::string_view accu_var,
609
- const cel::MacroRegistry& macro_registry,
610
- bool add_macro_calls = false ,
611
- bool enable_optional_syntax = false ,
612
- bool enable_quoted_identifiers = false )
613
- : source_(source),
614
- factory_ (source_, accu_var),
615
- macro_registry_(macro_registry),
616
- recursion_depth_(0 ),
617
- max_recursion_depth_(max_recursion_depth),
618
- add_macro_calls_(add_macro_calls),
619
- enable_optional_syntax_(enable_optional_syntax),
620
- enable_quoted_identifiers_(enable_quoted_identifiers) {}
747
+ const cel::MacroRegistry& macro_registry, bool add_macro_calls,
748
+ bool enable_optional_syntax, bool enable_quoted_identifiers,
749
+ bool enable_annotations);
621
750
622
751
~ParserVisitor () override = default ;
623
752
@@ -675,6 +804,8 @@ class ParserVisitor final : public CelBaseVisitor,
675
804
676
805
std::string ErrorMessage ();
677
806
807
+ Expr PackAnnotations (Expr ast);
808
+
678
809
private:
679
810
template <typename ... Args>
680
811
Expr GlobalCallOrMacro (int64_t expr_id, absl::string_view function,
@@ -702,14 +833,38 @@ class ParserVisitor final : public CelBaseVisitor,
702
833
private:
703
834
const cel::Source& source_;
704
835
cel::ParserMacroExprFactory factory_;
705
- const cel::MacroRegistry& macro_registry_;
836
+ AugmentedMacroRegistry macro_registry_;
837
+ AnnotationCollector annotations_;
706
838
int recursion_depth_;
707
839
const int max_recursion_depth_;
708
840
const bool add_macro_calls_;
709
841
const bool enable_optional_syntax_;
710
842
const bool enable_quoted_identifiers_;
843
+ const bool enable_annotations_;
711
844
};
712
845
846
+ ParserVisitor::ParserVisitor (const cel::Source& source, int max_recursion_depth,
847
+ absl::string_view accu_var,
848
+ const cel::MacroRegistry& macro_registry,
849
+ bool add_macro_calls, bool enable_optional_syntax,
850
+ bool enable_quoted_identifiers,
851
+ bool enable_annotations)
852
+ : source_(source),
853
+ factory_(source_, accu_var),
854
+ macro_registry_(macro_registry),
855
+ recursion_depth_(0 ),
856
+ max_recursion_depth_(max_recursion_depth),
857
+ add_macro_calls_(add_macro_calls),
858
+ enable_optional_syntax_(enable_optional_syntax),
859
+ enable_quoted_identifiers_(enable_quoted_identifiers),
860
+ enable_annotations_(enable_annotations) {
861
+ if (enable_annotations_) {
862
+ macro_registry_.overlay ()
863
+ .RegisterMacro (annotations_.MakeAnnotationImpl ())
864
+ .IgnoreError ();
865
+ }
866
+ }
867
+
713
868
template <typename T, typename = std::enable_if_t <
714
869
std::is_base_of<antlr4::tree::ParseTree, T>::value>>
715
870
T* tree_as (antlr4::tree::ParseTree* tree) {
@@ -1638,6 +1793,61 @@ struct ParseResult {
1638
1793
EnrichedSourceInfo enriched_source_info;
1639
1794
};
1640
1795
1796
+ Expr NormalizeAnnotation (cel::ParserMacroExprFactory& mef, Expr expr) {
1797
+ if (expr.has_struct_expr ()) {
1798
+ return expr;
1799
+ }
1800
+
1801
+ if (expr.has_const_expr ()) {
1802
+ std::vector<cel::StructExprField> fields;
1803
+ fields.reserve (2 );
1804
+ fields.push_back (
1805
+ mef.NewStructField (mef.NextId ({}), " name" , std::move (expr)));
1806
+ auto bool_const = mef.NewBoolConst (mef.NextId ({}), true );
1807
+ fields.push_back (mef.NewStructField (mef.NextId ({}), " inspect_only" ,
1808
+ std::move (bool_const)));
1809
+ return mef.NewStruct (mef.NextId ({}), " cel.Annotation" , std::move (fields));
1810
+ }
1811
+
1812
+ return mef.ReportError (" invalid annotation encountered finalizing AST" );
1813
+ }
1814
+
1815
+ Expr ParserVisitor::PackAnnotations (Expr ast) {
1816
+ if (annotations_.annotations ().empty ()) {
1817
+ return ast;
1818
+ }
1819
+
1820
+ auto annotations = annotations_.consume_annotations ();
1821
+ std::vector<MapExprEntry> entries;
1822
+ entries.reserve (annotations.size ());
1823
+
1824
+ for (auto & annotation : annotations) {
1825
+ std::vector<cel::ListExprElement> annotation_values;
1826
+ annotation_values.reserve (annotation.second .size ());
1827
+
1828
+ for (auto & annotation_value : annotation.second ) {
1829
+ auto annotation =
1830
+ NormalizeAnnotation (factory_, std::move (annotation_value.expr ));
1831
+ annotation_values.push_back (
1832
+ factory_.NewListElement (std::move (annotation)));
1833
+ }
1834
+ auto id = factory_.NewIntConst (factory_.NextId ({}), annotation.first );
1835
+ auto annotation_list =
1836
+ factory_.NewList (factory_.NextId ({}), std::move (annotation_values));
1837
+ entries.push_back (factory_.NewMapEntry (factory_.NextId ({}), std::move (id),
1838
+ std::move (annotation_list)));
1839
+ }
1840
+
1841
+ std::vector<Expr> args;
1842
+ args.push_back (std::move (ast));
1843
+ args.push_back (factory_.NewMap (factory_.NextId ({}), std::move (entries)));
1844
+
1845
+ auto result =
1846
+ factory_.NewCall (factory_.NextId ({}), " cel.@annotated" , std::move (args));
1847
+
1848
+ return result;
1849
+ }
1850
+
1641
1851
absl::StatusOr<ParseResult> ParseImpl (const cel::Source& source,
1642
1852
const cel::MacroRegistry& registry,
1643
1853
const ParserOptions& options) {
@@ -1656,10 +1866,10 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
1656
1866
if (options.enable_hidden_accumulator_var ) {
1657
1867
accu_var = cel::kHiddenAccumulatorVariableName ;
1658
1868
}
1659
- ParserVisitor visitor (source, options. max_recursion_depth , accu_var,
1660
- registry , options.add_macro_calls ,
1661
- options.enable_optional_syntax ,
1662
- options.enable_quoted_identifiers );
1869
+ ParserVisitor visitor (
1870
+ source , options.max_recursion_depth , accu_var, registry ,
1871
+ options. add_macro_calls , options.enable_optional_syntax ,
1872
+ options.enable_quoted_identifiers , options. enable_annotations );
1663
1873
1664
1874
lexer.removeErrorListeners ();
1665
1875
parser.removeErrorListeners ();
@@ -1686,7 +1896,9 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
1686
1896
if (visitor.HasErrored ()) {
1687
1897
return absl::InvalidArgumentError (visitor.ErrorMessage ());
1688
1898
}
1689
-
1899
+ if (options.enable_annotations ) {
1900
+ expr = visitor.PackAnnotations (std::move (expr));
1901
+ }
1690
1902
return {
1691
1903
ParseResult{.expr = std::move (expr),
1692
1904
.source_info = visitor.GetSourceInfo (),
0 commit comments