From 2f20d82b8de23beadf57adeea72acc2f86c8ca22 Mon Sep 17 00:00:00 2001 From: Thomas Wies Date: Sat, 16 Nov 2024 17:12:32 -0500 Subject: [PATCH] clean up AST wrt. assignment statement rep; other minor signature changes --- lib/ast/astDef.ml | 44 +- lib/ast/rewriter.ml | 29 +- lib/backend/smt_solver.ml | 2 +- lib/frontend/parser.mly | 17 +- lib/frontend/rewrites/heapsExplicitTrnsl.ml | 127 +- lib/frontend/rewrites/rewrites.ml | 132 +- lib/frontend/typing.ml | 1478 +++++++++---------- 7 files changed, 830 insertions(+), 999 deletions(-) diff --git a/lib/ast/astDef.ml b/lib/ast/astDef.ml index e396091..a846569 100644 --- a/lib/ast/astDef.ml +++ b/lib/ast/astDef.ml @@ -365,7 +365,7 @@ module Type = struct let mk_fld loc tpf = App (Fld, [tpf], mk_attr loc) let mk_perm loc = App (Perm, [], mk_attr loc) let mk_data id decls loc = App (Data (id, decls), [], mk_attr loc) - let mk_var loc qid = App (Var qid, [], mk_attr loc) + let mk_var qid = App (Var qid, [], mk_attr (QualIdent.to_loc qid)) let mk_atomic_token loc = App (AtomicToken, [], mk_attr loc) let mk_prod loc tp_list = match tp_list with @@ -387,7 +387,7 @@ module Type = struct let map = mk_map Loc.dummy let perm = mk_perm Loc.dummy let data id decls = mk_data id decls Loc.dummy - let var qid = mk_var Loc.dummy qid + let var qid = mk_var qid let atomic_token = mk_atomic_token Loc.dummy (** Equality and Subtyping *) @@ -443,7 +443,7 @@ module Type = struct (** Auxiliary utility functions *) - let mk_var_decl ?(loc = Loc.dummy) ?(const = false) ?(ghost = false) ?(implicit = false) name tp = + let mk_var_decl ?(const = false) ?(ghost = false) ?(implicit = false) name ?(loc = Ident.to_loc name) tp = { var_name = name; var_loc = loc; var_type = tp; var_const = const; var_ghost = ghost; var_implicit = implicit } let is_num tp = @@ -764,8 +764,8 @@ module Expr = struct let mk_app ?(loc = Loc.dummy) ~typ c es = App (c, es, mk_attr loc typ) - let mk_var ?(loc = Loc.dummy) ~typ (qual_ident: qual_ident) = - mk_app ~loc ~typ (Var qual_ident) [] + let mk_var ~typ (qual_ident: qual_ident) = + mk_app ~loc:(QualIdent.to_loc qual_ident) ~typ (Var qual_ident) [] let mk_binder ?(loc = Loc.dummy) ?(typ = Type.bool) ?(trigs = []) b vs e = match vs with @@ -866,7 +866,7 @@ module Expr = struct App (MapLookUp, [ e1; e2 ], mk_attr loc t) let from_var_decl (var_decl:var_decl) = - mk_var ~loc:var_decl.var_loc ~typ:var_decl.var_type (QualIdent.from_ident var_decl.var_name) + mk_var ~typ:var_decl.var_type (QualIdent.from_ident var_decl.var_name) (** Auxiliary functions *) @@ -1169,7 +1169,7 @@ module Stmt = struct new_args : (qual_ident * expr option) list; } - type assign_desc = { assign_lhs : expr list; assign_rhs : expr; assign_is_init : bool } + type assign_desc = { assign_lhs : qual_ident list; assign_rhs : expr; assign_is_init : bool } type bind_desc = { bind_lhs : expr list; bind_rhs : expr } type field_read_desc = { @@ -1322,8 +1322,8 @@ module Stmt = struct | Assign astm -> ( match astm.assign_lhs with | [] -> Expr.pr ppf astm.assign_rhs - | es -> - fprintf ppf "@[<2>%a@ :=@ %a@]" Expr.pr_list es Expr.pr + | vs -> + fprintf ppf "@[<2>%a@ :=@ %a@]" QualIdent.pr_list vs Expr.pr astm.assign_rhs) | Bind bstm -> ( match bstm.bind_lhs with @@ -1497,6 +1497,11 @@ module Stmt = struct let mk_assign ~loc lhs rhs = { stmt_desc = Basic (Assign { assign_lhs = lhs; assign_rhs = rhs; assign_is_init = false }); stmt_loc = loc } + let mk_field_write ~loc ref field v = + { stmt_desc = Basic (FieldWrite {field_write_ref = ref; field_write_field = field; field_write_val = v}); + stmt_loc = loc + } + let mk_return ~loc e = { stmt_desc = Basic (Return e); stmt_loc = loc } let mk_bind ~loc lhs rhs = @@ -1544,7 +1549,10 @@ module Stmt = struct Set.add accesses new_desc.new_lhs | Assign assign_desc -> - scan_expr_list accesses (assign_desc.assign_rhs :: assign_desc.assign_lhs) + let accesses = + List.fold assign_desc.assign_lhs ~init:accesses ~f:Set.add + in + scan_expr_list accesses [assign_desc.assign_rhs] | Bind bind_desc -> scan_expr_list accesses (bind_desc.bind_rhs :: bind_desc.bind_lhs) @@ -1641,15 +1649,8 @@ module Stmt = struct [] | Assign assign_desc -> - List.filter_map assign_desc.assign_lhs ~f:(fun e -> - match e with - | App (Var qi, _, _) -> - if List.is_empty qi.qual_path then - Some qi.qual_base - else - None - | _ -> None - ) + List.map assign_desc.assign_lhs ~f:QualIdent.unqualify + | Bind bind_desc -> List.filter_map bind_desc.bind_lhs ~f:(fun e -> @@ -1748,7 +1749,7 @@ module Stmt = struct [] | Assign assign_desc -> - List.concat_map (assign_desc.assign_rhs :: assign_desc.assign_lhs) ~f:(fun e -> Expr.expr_fields_accessed e) + Expr.expr_fields_accessed assign_desc.assign_rhs | Bind bind_desc -> List.concat_map (bind_desc.bind_rhs :: bind_desc.bind_lhs) ~f:(fun e -> Expr.expr_fields_accessed e) @@ -2489,8 +2490,7 @@ module ProgStats = struct | Assign assign_desc -> let is_ghost = try (List.exists assign_desc.assign_lhs - ~f:(fun e -> - let qi = Expr.to_qual_ident e in + ~f:(fun qi -> List.exists proc_decl.Callable.call_decl_locals ~f:(fun vd -> Ident.(vd.var_name = (QualIdent.to_ident qi)) ) ) ) diff --git a/lib/ast/rewriter.ml b/lib/ast/rewriter.ml index 207ad91..256e5fb 100644 --- a/lib/ast/rewriter.ml +++ b/lib/ast/rewriter.ml @@ -641,11 +641,10 @@ module Stmt = struct Basic (Spec (sk, { spec with spec_form = new_spec_form })); } | Assign assign -> - let* assign_lhs = List.map assign.assign_lhs ~f in let+ assign_rhs = f assign.assign_rhs in { stmt with - stmt_desc = Basic (Assign { assign with assign_lhs; assign_rhs }); + stmt_desc = Basic (Assign { assign with assign_rhs }); } | Bind bind_desc -> let* bind_lhs = List.map bind_desc.bind_lhs ~f in @@ -781,10 +780,8 @@ module Stmt = struct in { stmt with stmt_desc = Basic (VarDef { var_decl; var_init }) } | Assign assign -> - let+ assign_rhs = Expr.rewrite_qual_idents ~f assign.assign_rhs - and+ assign_lhs = - List.map assign.assign_lhs ~f:(Expr.rewrite_qual_idents ~f) - in + let+ assign_rhs = Expr.rewrite_qual_idents ~f assign.assign_rhs in + let assign_lhs = Base.List.map assign.assign_lhs ~f in { stmt with stmt_desc = Basic (Assign { assign with assign_lhs; assign_rhs }); @@ -1279,8 +1276,7 @@ module Symbol = struct Some tp_expr | ModDef { mod_decl = { mod_decl_rep = Some rep_id; _ }; _ } -> let+ tp_expr = - AstDef.Type.mk_var (QualIdent.to_loc name) - (QualIdent.append name rep_id) + AstDef.Type.mk_var (QualIdent.append name rep_id) |> Type.rewrite_qual_idents ~f:(QualIdent.requalify subst) in Some tp_expr @@ -1351,8 +1347,8 @@ let find_and_reify loc name : (AstDef.Module.symbol, 'a) t_ext = m "Rewriter.find_and_reify: symbol = %a" Symbol.pr symbol);*) Symbol.reify symbol -let is_local loc qual_ident s = - let s, qual_ident = resolve loc qual_ident s in +let is_local qual_ident s = + let s, qual_ident = resolve (QualIdent.to_loc qual_ident) qual_ident s in (s, Base.List.is_empty qual_ident.qual_path) module ProgUtils = struct @@ -1640,7 +1636,6 @@ module ProgUtils = struct let get_ra_rep_type (ra_qual_iden : qual_ident) : type_expr = AstDef.Type.mk_var - (QualIdent.to_loc ra_qual_iden) (QualIdent.append ra_qual_iden (Ident.make (QualIdent.to_loc ra_qual_iden) "T" 0)) @@ -1842,7 +1837,7 @@ module ProgUtils = struct QualIdent.append field_utils_module (heap_utils_id_ident loc) in - return @@ AstDef.Expr.mk_var ~loc id_qual_ident ~typ:field_elem_type + return @@ AstDef.Expr.mk_var id_qual_ident ~typ:field_elem_type let get_pred_utils_id loc pred_name : expr t = let open Syntax in @@ -1852,13 +1847,13 @@ module ProgUtils = struct let* pred_elem_type_name = get_pred_utils_rep_type loc pred_name in - let pred_elem_type = AstDef.Type.mk_var loc pred_elem_type_name in + let pred_elem_type = AstDef.Type.mk_var pred_elem_type_name in let id_qual_ident = QualIdent.append pred_utils_module (heap_utils_id_ident loc) in - return @@ AstDef.Expr.mk_var ~loc id_qual_ident ~typ:pred_elem_type + return @@ AstDef.Expr.mk_var id_qual_ident ~typ:pred_elem_type let get_au_utils_id loc call_name : expr t = let open Syntax in @@ -1868,13 +1863,13 @@ module ProgUtils = struct let* call_elem_type_name = get_au_utils_rep_type loc call_name in - let call_elem_type = AstDef.Type.mk_var loc call_elem_type_name in + let call_elem_type = AstDef.Type.mk_var call_elem_type_name in let id_qual_ident = QualIdent.append call_utils_module (heap_utils_id_ident loc) in - return @@ AstDef.Expr.mk_var ~loc id_qual_ident ~typ:call_elem_type + return @@ AstDef.Expr.mk_var id_qual_ident ~typ:call_elem_type (* ======================== *) @@ -1987,7 +1982,7 @@ module ProgUtils = struct AstDef.Type.mk_map (QualIdent.to_loc pred_name) (AstDef.Type.mk_prod (QualIdent.to_loc pred_name) pred_in_types) - (AstDef.Type.mk_var (QualIdent.to_loc pred_name) pred_rep_type) + (AstDef.Type.mk_var pred_rep_type) let rec is_expr_pure (expr : expr) : (bool, 'a) t_ext = let open Syntax in diff --git a/lib/backend/smt_solver.ml b/lib/backend/smt_solver.ml index ef440b8..096eca0 100644 --- a/lib/backend/smt_solver.ml +++ b/lib/backend/smt_solver.ml @@ -316,7 +316,7 @@ let declare_tuple_sort (arity : int) : command = let destrs_sorts = Base.List.map2_exn destrs params ~f:(fun destr param -> - (destr, Ast.Type.mk_var Util.Loc.dummy param)) + (destr, Ast.Type.mk_var param)) in mk_declare_datatype (tuple_sort_name, params, [ (constr, destrs_sorts) ]) diff --git a/lib/frontend/parser.mly b/lib/frontend/parser.mly index 77f702b..9306439 100644 --- a/lib/frontend/parser.mly +++ b/lib/frontend/parser.mly @@ -425,7 +425,7 @@ stmt_wo_trailing_substmt: Stmt.(Basic (Assign assign)) } (* assignment / allocation *) -| es = separated_nonempty_list(COMMA, expr); COLONEQ; e = new_or_expr; SEMICOLON { +| es = separated_nonempty_list(COMMA, expr); COLONEQ; e = assign_rhs; SEMICOLON { let open Stmt in match e with | Basic (New new_descr) -> @@ -436,7 +436,14 @@ stmt_wo_trailing_substmt: (match es with | [Expr.(App (Read, [field_write_ref; App (Var field_write_field, [], _)], _))] -> Basic (FieldWrite { field_write_ref; field_write_field; field_write_val = assign.assign_rhs }) - | _ -> Basic (Assign { assign with assign_lhs = es })) + | _ -> + let vs = List.map (function + | Expr.(App (Var qual_ident, [], _)) + when QualIdent.is_local qual_ident -> qual_ident + | e -> Error.syntax_error (Expr.to_loc e) "Expected single field location or local variables on left-hand side of assignment") + es + in + Basic (Assign { assign with assign_lhs = vs })) | _ -> assert false } (* bind *) @@ -549,7 +556,7 @@ with_clause: | _ -> Error.syntax_error (Loc.make $startpos $startpos) "A 'with' clause is only allowed in assert statements" } -new_or_expr: +assign_rhs: | NEW LPAREN fes = separated_list(COMMA, pair(qual_ident, option(preceded(COLON, expr)))) RPAREN { let new_descr = Stmt.{ new_lhs = QualIdent.from_ident (Ident.make Loc.dummy "" 0); @@ -1007,10 +1014,10 @@ type_expr: | REF { Type.mk_ref (Loc.make $startpos $endpos) } | PERM { Type.mk_perm (Loc.make $startpos $endpos)} | ATOMICTOKEN { Type.mk_atomic_token (Loc.make $startpos $endpos) } -//| x = IDENT { Type.mk_var (Loc.make $startpos $endpos) (QualIdent.from_ident x) } +//| x = IDENT { Type.mk_var (QualIdent.from_ident x) } | SET LBRACKET t = type_expr RBRACKET { Type.mk_set (Loc.make $startpos $endpos) t } | MAP LBRACKET; t1 = type_expr; COMMA; t2 = type_expr; RBRACKET { Type.mk_map (Loc.make $startpos $endpos) t1 t2 } -| x = mod_ident { Type.mk_var (Loc.make $startpos $endpos) x } +| x = mod_ident { Type.mk_var x } | LPAREN ts = type_expr_list RPAREN { Type.mk_prod (Loc.make $startpos $endpos) ts } | x = mod_ident LBRACKET; ts = type_expr_list; RBRACKET { Type.(App(Var x, ts, Type.mk_attr (Loc.make $startpos $endpos))) } diff --git a/lib/frontend/rewrites/heapsExplicitTrnsl.ml b/lib/frontend/rewrites/heapsExplicitTrnsl.ml index 4d5a30e..1a85048 100644 --- a/lib/frontend/rewrites/heapsExplicitTrnsl.ml +++ b/lib/frontend/rewrites/heapsExplicitTrnsl.ml @@ -223,9 +223,9 @@ let compute_env_local_var_decls ~loc (expr: expr) (conds: conditions) (universal let generate_inv_function ~loc (universal_quants : universal_quants) (conds : conditions) (inv_expr : expr) ~(arg_expr : expr) : expr Rewriter.t = - (* Logs.debug (fun m -> m "heapsExplicitTrnsl.generate_inv_function: Generating inv function for %a" Expr.pr inv_expr); + Logs.debug (fun m -> m "heapsExplicitTrnsl.generate_inv_function: Generating inv function for %a" Expr.pr inv_expr); Logs.debug (fun m -> m "arg_expr: %a" Expr.pr arg_expr); - Logs.debug (fun m -> m "inv_expr_type: %a; arg_expr_type: %a" Type.pr (Expr.to_type inv_expr) Type.pr (Expr.to_type arg_expr)); *) + Logs.debug (fun m -> m "inv_expr_type: %a; arg_expr_type: %a" Type.pr (Expr.to_type inv_expr) Type.pr (Expr.to_type arg_expr)); let open Rewriter.Syntax in let* tp1 = Typing.ProcessTypeExpr.expand_type_expr (Expr.to_type inv_expr) and* tp2 = Typing.ProcessTypeExpr.expand_type_expr (Expr.to_type arg_expr) in @@ -635,7 +635,7 @@ let generate_utils_module ~(is_field : bool) ?(is_frac_field = false) (mod_ident let* mod_def = let type_ident = Rewriter.ProgUtils.heap_utils_rep_type_ident loc in - let type_tp_expr = Type.mk_var loc (QualIdent.from_ident type_ident) in + let type_tp_expr = Type.mk_var (QualIdent.from_ident type_ident) in let type_def = { @@ -654,7 +654,7 @@ let generate_utils_module ~(is_field : bool) ?(is_frac_field = false) (mod_ident type_tp_expr; var_init = Some - (Expr.mk_var ~loc ~typ:fld_elem_type + (Expr.mk_var ~typ:fld_elem_type (Rewriter.ProgUtils.get_ra_id ra_qual_ident)); } in @@ -740,7 +740,7 @@ let generate_utils_module ~(is_field : bool) ?(is_frac_field = false) (mod_ident (Expr.mk_maplookup ~loc (Expr.from_var_decl heap_formal_arg) (Expr.mk_null ())) - (Expr.mk_var ~loc ~typ:type_tp_expr + (Expr.mk_var ~typ:type_tp_expr (Rewriter.ProgUtils.get_ra_id ra_qual_ident)); in @@ -1196,7 +1196,7 @@ let introduce_heaps_in_stmts ~loc ~fields_list ~preds_list ~au_preds_list body : in let pred_heap_elem_type = - Type.mk_var loc pred_heap_elem_type_qual_ident + Type.mk_var pred_heap_elem_type_qual_ident in let* pred_in_types = Rewriter.ProgUtils.pred_in_types pred_name in @@ -1303,7 +1303,7 @@ let introduce_heaps_in_stmts ~loc ~fields_list ~preds_list ~au_preds_list body : Rewriter.ProgUtils.get_au_utils_rep_type loc call_name in - let au_heap_elem_type = Type.mk_var loc au_heap_elem_type_qual_ident in + let au_heap_elem_type = Type.mk_var au_heap_elem_type_qual_ident in (* Done so that Ident is aware of this name being used; prevents the same name from being generated again during SSA transform *) let _ = Ident.fresh loc (au_heap_name call_name).ident_name in @@ -1422,7 +1422,7 @@ let rec rewrite_fpu (stmt : Stmt.t) : Stmt.t Rewriter.t = in let field_expr = - Expr.mk_var ~loc:stmt.stmt_loc ~typ:field_symbol.field_type + Expr.mk_var ~typ:field_symbol.field_type fpu_desc.fpu_field in @@ -1893,7 +1893,7 @@ module TrnslInhale = struct Stmt.mk_assign ~loc:(Expr.to_loc bind_desc.bind_rhs) - [ Expr.from_var_decl var_decl ] + [ var_decl.var_name |> QualIdent.from_ident ] rhs) in @@ -2113,7 +2113,7 @@ module TrnslInhale = struct (* field$Heap := field$Heap2 *) let eq_stmt = - Stmt.mk_assign ~loc [ field_heap_expr ] field_heap2_expr + Stmt.mk_assign ~loc [ field_heap_expr |> Expr.to_qual_ident ] field_heap2_expr in let assume_heap_valid = @@ -2151,7 +2151,7 @@ module TrnslInhale = struct Rewriter.ProgUtils.get_au_utils_rep_type loc call_qual_ident in - let heap_elem_type = Type.mk_var loc heap_elem_type_qual_iden in + let heap_elem_type = Type.mk_var heap_elem_type_qual_iden in let call_name = call_qual_ident in let au_heap_name = au_heap_name call_name in @@ -2348,7 +2348,7 @@ module TrnslInhale = struct in (* au$Heap := au$Heap2 *) - let eq_stmt = Stmt.mk_assign ~loc [ au_heap_expr ] au_heap2_expr in + let eq_stmt = Stmt.mk_assign ~loc [ au_heap_expr |> Expr.to_qual_ident ] au_heap2_expr in let assume_heap_valid = Stmt.mk_assume_expr ~loc @@ -2408,7 +2408,7 @@ module TrnslInhale = struct in let heap_elem_type = - Type.mk_var loc heap_elem_type_qual_iden + Type.mk_var heap_elem_type_qual_iden in let pred_name = qual_ident in @@ -2657,7 +2657,7 @@ module TrnslInhale = struct (* pred$Heap := pred$Heap2 *) let eq_stmt = - Stmt.mk_assign ~loc [ pred_heap_expr ] pred_heap2_expr + Stmt.mk_assign ~loc [ pred_heap_expr |> Expr.to_qual_ident ] pred_heap2_expr in let assume_heap_valid = @@ -2764,7 +2764,7 @@ module TrnslInhale = struct Rewriter.ProgUtils.get_au_utils_rep_type loc call_qual_ident in - let heap_elem_type = Type.mk_var loc heap_elem_type_qual_iden in + let heap_elem_type = Type.mk_var heap_elem_type_qual_iden in let call_name = call_qual_ident in let au_heap_name = au_heap_name call_name in @@ -2855,7 +2855,7 @@ module TrnslInhale = struct in let heap_elem_type = - Type.mk_var loc heap_elem_type_qual_iden + Type.mk_var heap_elem_type_qual_iden in let pred_name = qual_ident in @@ -4150,7 +4150,7 @@ module TrnslExhale = struct (* field$Heap := field$Heap2 *) let eq_stmt = - Stmt.mk_assign ~loc [ field_heap_expr ] field_heap2_expr + Stmt.mk_assign ~loc [ field_heap_expr |> Expr.to_qual_ident ] field_heap2_expr in let assert_heap_valid = @@ -4183,7 +4183,7 @@ module TrnslExhale = struct Rewriter.ProgUtils.get_au_utils_rep_type loc call_qual_ident in - let heap_elem_type = Type.mk_var loc heap_elem_type_qual_iden in + let heap_elem_type = Type.mk_var heap_elem_type_qual_iden in let call_name = call_qual_ident in let au_heap_name = au_heap_name call_name in @@ -4377,7 +4377,7 @@ module TrnslExhale = struct in (* pred$Heap := pred$Heap2 *) - let eq_stmt = Stmt.mk_assign ~loc [ au_heap_expr ] au_heap2_expr in + let eq_stmt = Stmt.mk_assign ~loc [ au_heap_expr |> Expr.to_qual_ident ] au_heap2_expr in let assert_heap_valid = Stmt.mk_assert_expr ~loc ~spec_error @@ -4439,7 +4439,7 @@ module TrnslExhale = struct in let heap_elem_type = - Type.mk_var loc heap_elem_type_qual_iden + Type.mk_var heap_elem_type_qual_iden in let pred_name = qual_ident in @@ -4687,7 +4687,7 @@ module TrnslExhale = struct (* pred$Heap := pred$Heap2 *) let eq_stmt = - Stmt.mk_assign ~loc [ pred_heap_expr ] pred_heap2_expr + Stmt.mk_assign ~loc [ pred_heap_expr |> Expr.to_qual_ident ] pred_heap2_expr in let assert_heap_valid = @@ -4726,8 +4726,10 @@ end let rec rewrite_make_heaps_explicit (s : Stmt.t) : Stmt.t Rewriter.t = let open Rewriter.Syntax in match s.stmt_desc with - | Stmt.Basic basic_stmt -> ( + | Stmt.Basic basic_stmt -> begin match basic_stmt with + | VarDef _ | Use _ | New _ | Assign _ | Bind _ | Cas _ | Havoc _ | Return _ | AUAction _ | Fpu _ | Call _ -> + Rewriter.return s | Spec (spec_kind, spec) -> ( match spec_kind with | Inhale -> @@ -4893,7 +4895,7 @@ let rec rewrite_make_heaps_explicit (s : Stmt.t) : Stmt.t Rewriter.t = let assign_stmt = Stmt.mk_assign ~loc:s.stmt_loc - [ Expr.from_var_decl lhs_var ] + [ fr_desc.field_read_lhs ] (Expr.mk_app ~typ:lhs_var.var_type (DataDestr field_val_destr) [ Expr.mk_maplookup field_heap_expr fr_desc.field_read_ref ]) in @@ -4965,7 +4967,7 @@ let rec rewrite_make_heaps_explicit (s : Stmt.t) : Stmt.t Rewriter.t = assert_expr in let assign_stmt = - Stmt.mk_assign ~loc:s.stmt_loc [ field_heap_expr ] + Stmt.mk_assign ~loc:s.stmt_loc [ field_heap_expr |> Expr.to_qual_ident ] (Expr.mk_app ~typ:(Type.mk_map s.stmt_loc Type.ref field_ra_type) MapUpdate @@ -4975,82 +4977,7 @@ let rec rewrite_make_heaps_explicit (s : Stmt.t) : Stmt.t Rewriter.t = Rewriter.return (Stmt.mk_block_stmt ~loc:s.stmt_loc [ assert_stmt; assign_stmt ]) else Error.type_error s.stmt_loc "Expected a FracRA type." - | Assign - { - assign_lhs = [ Expr.App (Read, [ ref_expr; field_expr ], _) ]; - assign_rhs; - _; - } -> - let field_name = Expr.to_qual_ident field_expr in - - let* field_symbol = Rewriter.find_and_reify s.stmt_loc field_name in - - let field_symbol = - match field_symbol with - | FieldDef field_symbol -> field_symbol - | _ -> Error.type_error s.stmt_loc "Expected a field definition." - in - - let field_ra = - Rewriter.ProgUtils.field_get_ra_qual_iden field_symbol - in - - let* orig_ra_name, ra_def, _ = Rewriter.find s.stmt_loc field_ra in - - if QualIdent.(orig_ra_name = Predefs.lib_frac_mod_qual_ident) then - let field_ra_type = Rewriter.ProgUtils.get_ra_rep_type field_ra in - - let field_heap_name = field_heap_name field_name in - let field_heap_expr = - Expr.mk_var - ~typ:(Type.mk_map s.stmt_loc Type.ref field_ra_type) - (QualIdent.from_ident field_heap_name) - in - - let field_frac_destr = - QualIdent.append field_ra Predefs.lib_frac_chunk_destr2_ident - in - let field_frac_constr = - QualIdent.append field_ra Predefs.lib_frac_chunk_constr_ident - in - - let assert_expr = - Expr.mk_app ~loc:s.stmt_loc ~typ:Type.bool Expr.Geq - [ - Expr.mk_app ~typ:Type.real (DataDestr field_frac_destr) - [ Expr.mk_maplookup field_heap_expr ref_expr ]; - Expr.mk_real 1.; - ] - in - - let new_val = - Expr.mk_app ~typ:field_ra_type (DataConstr field_frac_constr) - [ assign_rhs; Expr.mk_real 1. ] - in - - let assert_stmt = - let error = - ( Error.Verification, - s.stmt_loc, - "Could not assert sufficient permissions to assign this field" - ) - in - Stmt.mk_assert_expr ~loc:s.stmt_loc - ~spec_error:[ Stmt.mk_const_spec_error error ] - assert_expr - in - let assign_stmt = - Stmt.mk_assign ~loc:s.stmt_loc [ field_heap_expr ] - (Expr.mk_app - ~typ:(Type.mk_map s.stmt_loc Type.ref field_ra_type) - MapUpdate - [ field_heap_expr; ref_expr; new_val ]) - in - - Rewriter.return - (Stmt.mk_block_stmt ~loc:s.stmt_loc [ assert_stmt; assign_stmt ]) - else Error.type_error s.stmt_loc "Expected a FracRA type." - | _ -> Rewriter.return s) + end | _ -> let* s = Rewriter.Stmt.descend s ~f:rewrite_make_heaps_explicit in diff --git a/lib/frontend/rewrites/rewrites.ml b/lib/frontend/rewrites/rewrites.ml index dc630af..b51a8c3 100644 --- a/lib/frontend/rewrites/rewrites.ml +++ b/lib/frontend/rewrites/rewrites.ml @@ -99,7 +99,7 @@ let rec rewrite_compr_expr (expr : expr) : expr Rewriter.t = var_implicit = false; } :: formals, - Expr.mk_var ~loc:(Expr.to_loc inner_expr) ~typ:data key + Expr.set_loc (Expr.mk_var ~typ:data key) (Expr.to_loc inner_expr) :: actuals )) in @@ -603,7 +603,7 @@ let rec rewrite_loops (stmt : Stmt.t) : Stmt.t Rewriter.t = let set_ret_vals_to_initial_args = List.map (Map.to_alist loop_ret_renaming_map) ~f:(fun (old_var, new_expr) -> - Stmt.mk_assign ~loc [ new_expr ] + Stmt.mk_assign ~loc [ new_expr |> Expr.to_qual_ident ] (Map.find_exn loop_arg_renaming_map old_var)) in @@ -746,7 +746,7 @@ let rec rewrite_ret_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t = let postconds_exhale_stmts = if Callable.is_atomic callable_decl then let atomic_token_var = - Expr.mk_var ~loc ~typ:Type.atomic_token + Expr.mk_var ~typ:Type.atomic_token (QualIdent.from_ident (Rewriter.ProgUtils.callable_au_token_ident ~loc callable_decl.call_decl_name)) @@ -840,8 +840,8 @@ let rec rewrite_new_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t = let inhale_expr = Expr.mk_app ~typ:Type.perm ~loc:stmt.stmt_loc Expr.Own [ - Expr.mk_var ~loc:stmt.stmt_loc ~typ:Type.ref new_desc.new_lhs; - Expr.mk_var ~loc:stmt.stmt_loc ~typ:field_type field_name; + Expr.mk_var ~typ:Type.ref new_desc.new_lhs; + Expr.mk_var ~typ:field_type field_name; field_val; ] in @@ -885,46 +885,15 @@ let rec rewrite_cas_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t = (Expr.from_var_decl new_var_decl) cas_desc.cas_old_val) in - let* symbol = Rewriter.find_and_reify stmt.stmt_loc cas_desc.cas_lhs in - let lhs_var_decl = - match symbol with - | VarDef v -> v.var_decl - | _ -> - Error.error stmt.stmt_loc - ("Expected a variable (3); found " ^ Symbol.to_string symbol) - in - let lhs_expr = Expr.from_var_decl lhs_var_decl in let then1_ = - Stmt.mk_assign ~loc:stmt.stmt_loc [ lhs_expr ] (Expr.mk_bool true) - in - let* field_symbol = - Rewriter.find_and_reify stmt.stmt_loc cas_desc.cas_field - in - let field_type, field_underlying_type = - match field_symbol with - | FieldDef f -> ( - match f.field_type with - | App (Fld, [ tp_expr ], _) -> (f.field_type, tp_expr) - | _ -> Error.type_error stmt.stmt_loc "Expected field identifier.") - | _ -> Error.error stmt.stmt_loc "Expected a field_def" - in - let expr_attr = Expr.mk_attr stmt.stmt_loc field_underlying_type in - let field_expr_attr = Expr.mk_attr stmt.stmt_loc field_type in - let read_expr = - Expr.App - ( Read, - [ - cas_desc.cas_ref; - Expr.App (Var cas_desc.cas_field, [], field_expr_attr); - ], - expr_attr ) + Stmt.mk_assign ~loc:stmt.stmt_loc [ cas_desc.cas_lhs ] (Expr.mk_bool true) in let then2_ = - Stmt.mk_assign ~loc:stmt.stmt_loc [ read_expr ] cas_desc.cas_new_val + Stmt.mk_field_write ~loc:stmt.stmt_loc cas_desc.cas_ref cas_desc.cas_field cas_desc.cas_new_val in let then_ = Stmt.mk_block_stmt ~loc:stmt.stmt_loc [ then1_; then2_ ] in let else_ = - Stmt.mk_assign ~loc:stmt.stmt_loc [ lhs_expr ] (Expr.mk_bool false) + Stmt.mk_assign ~loc:stmt.stmt_loc [ cas_desc.cas_lhs ] (Expr.mk_bool false) in let ite_stmt = Stmt.mk_cond ~loc:stmt.stmt_loc test_ then_ else_ in let new_stmts = @@ -1314,7 +1283,7 @@ let rec rewrite_call_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t = let reassign_lhs_stmt = Stmt.mk_assign ~loc:stmt.stmt_loc - (List.map lhs_list ~f:Expr.from_var_decl) + (List.map lhs_list ~f:(fun decl -> QualIdent.from_ident decl.var_name)) (Expr.mk_tuple new_lhs_list) in @@ -1350,14 +1319,14 @@ let rec rewrite_call_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t = in let new_assign_stmt = - Stmt.mk_assign ~loc:stmt.stmt_loc new_lhs_list + Stmt.mk_assign ~loc:stmt.stmt_loc (List.map new_lhs_list ~f:Expr.to_qual_ident) (Expr.mk_app ~loc:stmt.stmt_loc ~typ:ret_typ (Expr.Var call_desc.call_name) call_desc.call_args) in let reassign_lhs_stmt = Stmt.mk_assign ~loc:stmt.stmt_loc - (List.map lhs_list ~f:Expr.from_var_decl) + (List.map lhs_list ~f:(fun decl -> QualIdent.from_ident decl.var_name)) (Expr.mk_tuple new_lhs_list) in @@ -1418,7 +1387,7 @@ let rewrite_callable_pre_post_conds (c : Callable.t) : Callable.t Rewriter.t = let* callable_fully_qual_name = Rewriter.current_scope_id in let atomic_token_var = - Expr.mk_var ~loc ~typ:Type.atomic_token + Expr.mk_var ~typ:Type.atomic_token (QualIdent.from_ident (Rewriter.ProgUtils.callable_au_token_ident ~loc c.call_decl.call_decl_name)) @@ -1607,7 +1576,7 @@ let rec rewrite_frac_field_types (symbol : Module.symbol) : let frac_type = Type.mk_fld f.field_loc - (Type.mk_var f.field_loc + (Type.mk_var (QualIdent.append frac_mod_name (Ident.make f.field_loc "T" 0))) in @@ -1675,7 +1644,7 @@ let rec rewrite_own_expr_4_arg (expr : Expr.t) : Expr.t Rewriter.t = in let frac_type = - Type.mk_var (Expr.to_loc expr) + Type.mk_var (QualIdent.append frac_mod_name (Ident.make (Expr.to_loc expr) "T" 0)) in @@ -1746,7 +1715,7 @@ let rec rewrite_new_fpu_stmt_heap_arg (stmt : Stmt.t) : Stmt.t Rewriter.t = in let frac_type = - Type.mk_var (Expr.to_loc expr) + Type.mk_var (QualIdent.append frac_mod_name (Ident.make (Expr.to_loc expr) "T" 0)) in @@ -1803,7 +1772,7 @@ let rec rewrite_new_fpu_stmt_heap_arg (stmt : Stmt.t) : Stmt.t Rewriter.t = in let frac_type = - Type.mk_var loc + Type.mk_var (QualIdent.append frac_mod_name (Ident.make loc "T" 0)) in let new_expr = @@ -2155,26 +2124,14 @@ module AtomicityAnalysis = struct Rewriter.return stmt | Basic (Assign assign_desc) -> let* is_assign_desc_lhs_ghost = - Rewriter.List.for_all assign_desc.assign_lhs ~f:(fun expr -> - match expr with - | App (Var qual_iden, [], _) -> ( - let* symbol = - Rewriter.find_and_reify stmt.stmt_loc qual_iden - in - match symbol with - | VarDef v -> Rewriter.return v.var_decl.var_ghost - | _ -> Error.error stmt.stmt_loc "Expected a var_def" - ) - | App (Read, [ loc_expr; field_expr ], _) -> - let* field_symbol = Rewriter.find_and_reify loc (Expr.to_qual_ident field_expr) in - - begin match field_symbol with - | FieldDef f -> - Rewriter.return @@ f.field_is_ghost - | _ -> - Rewriter.return false - end - | _ -> Error.error stmt.stmt_loc "Expected a variable") + Rewriter.List.for_all assign_desc.assign_lhs ~f:(fun qual_ident -> + let* symbol = + Rewriter.find_and_reify stmt.stmt_loc qual_ident + in + match symbol with + | VarDef v -> Rewriter.return v.var_decl.var_ghost + | _ -> Error.error stmt.stmt_loc "Expected a var_def" + ) in if is_assign_desc_lhs_ghost then Rewriter.return stmt @@ -2368,7 +2325,7 @@ module AtomicityAnalysis = struct let assign_stmt = Stmt.mk_assign ~loc - [ Expr.from_var_decl qual_iden_var.var_decl ] + [ qual_iden ] (Expr.from_var_decl au_token_var.var_decl) in @@ -2714,11 +2671,10 @@ let rec rewrite_ssa_stmts (s : Stmt.t) : Expr.alpha_renaming assign_stmt.assign_rhs subst_map in let* assign_lhs = - Rewriter.List.map assign_stmt.assign_lhs ~f:(fun lhs_expr -> + Rewriter.List.map assign_stmt.assign_lhs ~f:(fun qual_ident -> let* var_map = Rewriter.current_user_state in - if Expr.is_ident lhs_expr then ( - let local_var = Expr.to_ident lhs_expr in + let local_var = QualIdent.to_ident qual_ident in Logs.debug (fun m -> m @@ -2747,8 +2703,7 @@ let rec rewrite_ssa_stmts (s : Stmt.t) : let+ _ = Rewriter.set_user_state var_map in - Expr.from_var_decl new_var_decl) - else Rewriter.return lhs_expr) + QualIdent.from_ident new_var_decl.var_name) in Rewriter.return @@ -2902,7 +2857,7 @@ let rec rewrite_ssa_stmts (s : Stmt.t) : let new_var_decl = Map.find_exn new_var_map var in Stmt.mk_assign ~loc:s.stmt_loc - [ Expr.from_var_decl new_var_decl ] + [ QualIdent.from_ident new_var_decl.var_name ] (Expr.from_var_decl old_var_decl)) in @@ -2916,7 +2871,7 @@ let rec rewrite_ssa_stmts (s : Stmt.t) : let new_var_decl = Map.find_exn new_var_map var in Stmt.mk_assign ~loc:s.stmt_loc - [ Expr.from_var_decl new_var_decl ] + [ QualIdent.from_ident new_var_decl.var_name ] (Expr.from_var_decl old_var_decl)) in @@ -3001,14 +2956,27 @@ let rec rewrite_assign_stmts (s : Stmt.t) : Stmt.t Rewriter.t = match s.stmt_desc with | Basic (Assign assign_stmt) -> - let assume_stmt = - Stmt.mk_assume_expr ~loc:s.stmt_loc - (Expr.mk_eq - (Expr.mk_tuple assign_stmt.assign_lhs) - assign_stmt.assign_rhs) - in + let+ assign_lhs = + Rewriter.List.map assign_stmt.assign_lhs ~f:(fun qual_ident -> + let* qual_ident, symbol = + Rewriter.resolve_and_find (QualIdent.to_loc qual_ident) qual_ident + in + let+ symbol = Rewriter.Symbol.reify symbol in + match symbol with + | VarDef { var_decl; _ } -> + Expr.mk_var ~typ:var_decl.var_type qual_ident + | _ -> assert false + ) + in + + let assume_stmt = + Stmt.mk_assume_expr ~loc:s.stmt_loc + (Expr.mk_eq + (Expr.mk_tuple assign_lhs) + assign_stmt.assign_rhs) + in - Rewriter.return assume_stmt + assume_stmt | _ -> let* s = Rewriter.Stmt.descend s ~f:rewrite_assign_stmts in diff --git a/lib/frontend/typing.ml b/lib/frontend/typing.ml index e1d7461..3ac84f2 100644 --- a/lib/frontend/typing.ml +++ b/lib/frontend/typing.ml @@ -797,16 +797,17 @@ module ProcessCallable = struct let disambiguate_ident (qual_ident : qual_ident) (disam_tbl : DisambiguationTbl.t) : qual_ident Rewriter.t = let open Rewriter.Syntax in - if List.is_empty qual_ident.qual_path then + if QualIdent.is_local qual_ident then + let ident = qual_ident |> QualIdent.unqualify in let* base = if Predefs.is_qual_ident_au_cmnd qual_ident then - Rewriter.return qual_ident.qual_base + Rewriter.return ident else - match DisambiguationTbl.find disam_tbl qual_ident.qual_base with + match DisambiguationTbl.find disam_tbl ident with | Some iden -> Rewriter.return iden | None -> let* is_local = - Rewriter.is_local qual_ident.qual_base.ident_loc qual_ident + Rewriter.is_local qual_ident in if is_local then (* if variable is local and it doesn't exist in DisambiguationTbl, then it is not defined in scope *) @@ -895,226 +896,719 @@ module ProcessCallable = struct (* Takes an expr, and returns a pure expression along with a set of temp variables that need to be defined *) () *) - let process_au_action_stmt (stmt : Stmt.stmt_desc) (loc : location) + let process_au_action_stmt (assign_lhs: qual_ident list) (var_decls_lhs: var_decl list) assign_rhs (loc : location) (disam_tbl : DisambiguationTbl.t) : - (Stmt.stmt_desc * DisambiguationTbl.t) Rewriter.t = + (Stmt.basic_stmt_desc * DisambiguationTbl.t) Rewriter.t = let open Rewriter.Syntax in - match stmt with - | Basic (Assign assign_desc) -> ( - Logs.debug (fun m -> - m "process_au_action_stmt: Assign: %a" Stmt.pr_basic_stmt - (Assign assign_desc)); - match assign_desc.assign_rhs with - | App (Var qual_ident, args, expr_attr) -> - if - QualIdent.(qual_ident = QualIdent.from_ident Predefs.bindAU_ident) - then - match (args, assign_desc.assign_lhs) with - | [], [ token ] -> ( - let+ token_expr = - disambiguate_process_expr token Type.atomic_token disam_tbl - in + match assign_rhs with + | Expr.App (Var qual_ident, args, expr_attr) -> + (* bindAU *) + if + QualIdent.(qual_ident = QualIdent.from_ident Predefs.bindAU_ident) + then + match (args, assign_lhs) with + | [], [ token_qual_ident ] -> + (* TODO: check type Type.atomic_token *) + Rewriter.return + ( Stmt.AUAction { auaction_kind = BindAU token_qual_ident }, + disam_tbl ) + | _ -> Error.type_error loc "bindAU takes exactly one argument" + else + (* openAU *) + if + QualIdent.(qual_ident = QualIdent.from_ident Predefs.openAU_ident) + then + let bound_vars = + Base.List.map2_exn assign_lhs var_decls_lhs ~f:(fun qual_ident var_decl -> + Expr.mk_var ~typ:var_decl.var_type qual_ident) + in - match token_expr with - | App (Var token_qual_ident, [], _) -> - ( Stmt.Basic - (AUAction { auaction_kind = BindAU token_qual_ident }), - disam_tbl ) - | _ -> - Error.type_error (Expr.to_loc token) - "The left-hand side of this ghost assignment must be a ghost variable") - | _ -> Error.type_error loc "bindAU takes exactly one argument" - else if - QualIdent.(qual_ident = QualIdent.from_ident Predefs.openAU_ident) - then - - let* bound_vars = - Rewriter.List.map assign_desc.assign_lhs ~f:(fun expr -> - let+ expr = - disambiguate_process_expr expr Type.any disam_tbl - in - match expr with - | App (Var qual_ident, [], _) -> expr - | _ -> - Error.type_error loc - "openAU bound_variables expected to be a variable") - in + match args with + | [ token ] -> ( + let* token_expr = + disambiguate_process_expr token Type.atomic_token disam_tbl + in + match token_expr with + | App (Var token_qual_ident, [], _) -> + Rewriter.return + ( Stmt.AUAction + { + auaction_kind = + OpenAU (token_qual_ident, None, bound_vars); + }, + disam_tbl ) + | _ -> + Error.type_error loc + "openAU token expected to be a variable") + | [ token; proc_name ] -> ( + let* token_expr = + disambiguate_process_expr token Type.atomic_token disam_tbl + in + let+ proc_name_expr = + disambiguate_process_expr proc_name Type.any disam_tbl + in + + match (token_expr, proc_name_expr) with + | ( App (Var token_qual_ident, [], _), + App (Var proc_qual_ident, [], _) ) -> + ( Stmt.AUAction + { + auaction_kind = + OpenAU + ( token_qual_ident, + Some proc_qual_ident, + bound_vars ); + }, + disam_tbl ) + | _ -> + Error.type_error loc + "openAU token and process name expected to be a \ + variable") + | _ -> + Error.type_error loc + "openAU takes exactly one or two arguments" + else if + QualIdent.( + qual_ident = QualIdent.from_ident Predefs.commitAU_ident) + then + match args with + | token :: args -> ( + let* token_expr = + disambiguate_process_expr token Type.atomic_token disam_tbl + in + + let+ args = + Rewriter.List.map args ~f:(fun arg -> + disambiguate_process_expr arg Type.any disam_tbl) + in + + match token_expr with + | App (Var token_qual_ident, [], _) -> + ( Stmt.AUAction + { + auaction_kind = CommitAU (token_qual_ident, args); + }, + disam_tbl ) + | _ -> + Error.type_error loc + "commitAU token expected to be a variable") + | _ -> Error.type_error loc "commitAU takes at least one argument" + else if + QualIdent.( + qual_ident = QualIdent.from_ident Predefs.abortAU_ident) + then + match args with + | [ token ] -> ( + let+ token_expr = + disambiguate_process_expr token Type.atomic_token disam_tbl + in + + match token_expr with + | App (Var token_qual_ident, [], _) -> + ( Stmt.AUAction { auaction_kind = AbortAU token_qual_ident }, + disam_tbl ) + | _ -> + Error.type_error loc + "abortAU token expected to be a variable") + | _ -> Error.type_error loc "abortAU takes exactly one argument" + else if + QualIdent.(qual_ident = QualIdent.from_ident Predefs.fpu_ident) + then + let field_opt = function + | Expr.App (Var qual_ident, [], _) as field_expr -> + let* field_expr = + disambiguate_process_expr field_expr Type.any disam_tbl + in + let* field_qual_ident, symbol = + Rewriter.resolve_and_find (QualIdent.to_loc qual_ident) qual_ident + in + let+ symbol = Rewriter.Symbol.reify symbol in + begin match symbol with + | FieldDef { field_type = App (Fld, [ given_type ], _); _ } -> + Some (field_qual_ident, given_type) + | _ -> None + end + | _ -> Rewriter.return None + in + let* ref_expr, field, fpu_exprs = + let* opt_list = + match args with + | Expr.App (Read, [ref_expr; field_expr], _) :: expr2 :: expr3_opt -> + let* field = field_opt field_expr in + field |> Rewriter.Option.map ~f:(fun field -> + Rewriter.return (ref_expr, field, expr2 :: expr3_opt)) + | _ -> Rewriter.return None + in + opt_list + |> Rewriter.Option.lazy_value ~default:(fun () -> match args with - | [ token ] -> ( - let* token_expr = - disambiguate_process_expr token Type.atomic_token disam_tbl - in - - match token_expr with - | App (Var token_qual_ident, [], _) -> - Rewriter.return - ( Stmt.Basic - (AUAction - { - auaction_kind = - OpenAU (token_qual_ident, None, bound_vars); - }), - disam_tbl ) - | _ -> - Error.type_error loc - "openAU token expected to be a variable") - | [ token; proc_name ] -> ( - let* token_expr = - disambiguate_process_expr token Type.atomic_token disam_tbl - in - let+ proc_name_expr = - disambiguate_process_expr proc_name Type.any disam_tbl - in + | expr1 :: expr2 :: expr3_opt -> + let* field = field_opt expr2 in + let* field =Rewriter.Option.map field ~f:(fun field_qual_ident -> + Rewriter.return (expr1, field_qual_ident, expr3_opt)) + in + Rewriter.Option.lazy_value field + ~default:(fun () -> Error.type_error (Expr.to_loc expr1) "Expected field location") + | _ -> Error.type_error loc "Could not find field location in fpu" + ) + in + let* ref_expr = + disambiguate_process_expr ref_expr Type.ref disam_tbl + in + let field_qual_ident, given_type = field in + let+ fpu_exprs = + Rewriter.List.map fpu_exprs ~f:(fun fpu_expr -> + disambiguate_process_expr fpu_expr given_type + disam_tbl) + in + + (* let* old_val_expr = disambiguate_process_expr old_val_expr given_type disam_tbl in + let+ new_val_expr = disambiguate_process_expr new_val_expr given_type disam_tbl in *) + let old_val_expr, new_val_expr = + match fpu_exprs with + | [ old_val_expr; new_val_expr ] -> + (Some old_val_expr, new_val_expr) + | [ new_val_expr ] -> (None, new_val_expr) + | _ -> + Error.type_error loc + "fpu takes exactly three or four arguments" + in + + ( Stmt.Fpu + { + fpu_ref = ref_expr; + fpu_field = field_qual_ident; + fpu_old_val = old_val_expr; + fpu_new_val = new_val_expr; + }, + disam_tbl ) + else Error.type_error loc "Unknown AU action" + | _ -> + Error.error loc + "Internal error: process_au_action_stmt called with non-callable \ + expression" - match (token_expr, proc_name_expr) with - | ( App (Var token_qual_ident, [], _), - App (Var proc_qual_ident, [], _) ) -> - ( Stmt.Basic - (AUAction - { - auaction_kind = - OpenAU - ( token_qual_ident, - Some proc_qual_ident, - bound_vars ); - }), - disam_tbl ) - | _ -> - Error.type_error loc - "openAU token and process name expected to be a \ - variable") + let rec process_basic_stmt (expected_return_type : Type.t) + (basic_stmt : Stmt.basic_stmt_desc) (stmt_loc: Loc.t) (disam_tbl : DisambiguationTbl.t) : + (Stmt.basic_stmt_desc * DisambiguationTbl.t) Rewriter.t = + let open Rewriter.Syntax in + match basic_stmt with + | VarDef var_def -> + let* var_decl = + ProcessTypeExpr.process_var_decl var_def.var_decl + in + let var_decl, disam_tbl' = + DisambiguationTbl.add_var_decl var_decl disam_tbl + in + let* _ = + Rewriter.introduce_symbol + (VarDef { var_decl; var_init = None }) + in + let var = QualIdent.from_ident var_decl.var_name in + begin match var_def.var_init with + | None -> + Rewriter.return @@ (Stmt.Havoc var, disam_tbl') + | Some expr -> + let assign_desc = + Stmt. + { + assign_lhs = [ var_def.var_decl.var_name |> QualIdent.from_ident ]; + assign_rhs = expr; + assign_is_init = true; + } + in + + process_basic_stmt + expected_return_type + (Assign assign_desc) + stmt_loc + disam_tbl' + end + | Spec (sk, spec) -> + let+ spec = process_stmt_spec disam_tbl spec in + (Stmt.Spec (sk, spec), disam_tbl) + | Assign assign_desc -> begin + let* assign_lhs, var_decls_lhs = + Rewriter.List.fold_right assign_desc.assign_lhs ~init:([], []) ~f:(fun qual_ident (assign_lhs, var_decls_lhs) -> + let* qual_ident = disambiguate_ident qual_ident disam_tbl in + let* qual_ident, symbol = + Rewriter.resolve_and_find (QualIdent.to_loc qual_ident) qual_ident + in + let+ symbol = Rewriter.Symbol.reify symbol in + match symbol with + | VarDef { var_decl; _ } when not var_decl.var_const || assign_desc.assign_is_init -> + qual_ident :: assign_lhs, var_decl :: var_decls_lhs | _ -> - Error.type_error loc - "openAU takes exactly one or two arguments" - else if - QualIdent.( - qual_ident = QualIdent.from_ident Predefs.commitAU_ident) - then - match args with - | token :: args -> ( - let* token_expr = - disambiguate_process_expr token Type.atomic_token disam_tbl - in + Error.type_error (QualIdent.to_loc qual_ident) + (Printf.sprintf !"Cannot assign to %s %{QualIdent}" (Symbol.kind symbol) qual_ident) + ) + in - let+ args = - Rewriter.List.map args ~f:(fun arg -> - disambiguate_process_expr arg Type.any disam_tbl) - in + match assign_desc.assign_rhs with + (* Field read *) + | App (Read, [ ref_expr; field_expr ], _) -> + Logs.debug (fun m -> + m "process_stmt: read_assign_rhs: %a" Expr.pr + assign_desc.assign_rhs); + let field_qual_ident = Expr.to_qual_ident field_expr in + let field_read_lhs = + match assign_desc.assign_lhs with + | [ lhs ] -> lhs + | _ -> + Error.type_error stmt_loc + "Expected exactly one variable on left-hand side of field read" + in + + let field_read_desc = + Stmt. + { + field_read_lhs; + field_read_field = field_qual_ident; + field_read_ref = ref_expr; + } + in + process_basic_stmt expected_return_type (Stmt.FieldRead field_read_desc) stmt_loc disam_tbl + (* AU action *) + | App (Var qual_ident, _, _) as assign_rhs when Predefs.is_qual_ident_au_cmnd qual_ident -> + process_au_action_stmt assign_lhs var_decls_lhs assign_rhs stmt_loc disam_tbl + | _ -> + Logs.debug (fun m -> + m "process_stmt: assign_desc: %a" Stmt.pr_basic_stmt + (Assign assign_desc)); + + let expected_type = + Type.mk_prod + (Expr.to_loc assign_desc.assign_rhs) + (List.map var_decls_lhs ~f:(fun var -> var.var_type)) + in - match token_expr with - | App (Var token_qual_ident, [], _) -> - ( Stmt.Basic - (AUAction - { - auaction_kind = CommitAU (token_qual_ident, args); - }), - disam_tbl ) - | _ -> - Error.type_error loc - "commitAU token expected to be a variable") - | _ -> Error.type_error loc "commitAU takes at least one argument" - else if - QualIdent.( - qual_ident = QualIdent.from_ident Predefs.abortAU_ident) - then - match args with - | [ token ] -> ( - let+ token_expr = - disambiguate_process_expr token Type.atomic_token disam_tbl - in + let* assign_rhs = + disambiguate_process_expr assign_desc.assign_rhs expected_type disam_tbl + in - match token_expr with - | App (Var token_qual_ident, [], _) -> - ( Stmt.Basic - (AUAction { auaction_kind = AbortAU token_qual_ident }), - disam_tbl ) - | _ -> - Error.type_error loc - "abortAU token expected to be a variable") - | _ -> Error.type_error loc "abortAU takes exactly one argument" - else if - QualIdent.(qual_ident = QualIdent.from_ident Predefs.fpu_ident) - then - let field_opt = function - | Expr.App (Var qual_ident, [], _) as field_expr -> - let* field_expr = - disambiguate_process_expr field_expr Type.any disam_tbl - in - let* field_qual_ident, symbol = - Rewriter.resolve_and_find (QualIdent.to_loc qual_ident) qual_ident - in - let+ symbol = Rewriter.Symbol.reify symbol in - begin match symbol with - | FieldDef { field_type = App (Fld, [ given_type ], _); _ } -> - Some (field_qual_ident, given_type) - | _ -> None - end - | _ -> Rewriter.return None - in - let* ref_expr, field, fpu_exprs = - let* opt_list = - match args with - | Expr.App (Read, [ref_expr; field_expr], _) :: expr2 :: expr3_opt -> - let* field = field_opt field_expr in - field |> Rewriter.Option.map ~f:(fun field -> - Rewriter.return (ref_expr, field, expr2 :: expr3_opt)) - | _ -> Rewriter.return None - in - opt_list - |> Rewriter.Option.lazy_value ~default:(fun () -> - match args with - | expr1 :: expr2 :: expr3_opt -> - let* field = field_opt expr2 in - let* field =Rewriter.Option.map field ~f:(fun field_qual_ident -> - Rewriter.return (expr1, field_qual_ident, expr3_opt)) - in - Rewriter.Option.lazy_value field - ~default:(fun () -> Error.type_error (Expr.to_loc expr1) "Expected field location") - | _ -> Error.type_error loc "Could not find field location in fpu" - ) - in - let* ref_expr = - disambiguate_process_expr ref_expr Type.ref disam_tbl - in - let field_qual_ident, given_type = field in - let+ fpu_exprs = - Rewriter.List.map fpu_exprs ~f:(fun fpu_expr -> - disambiguate_process_expr fpu_expr given_type - disam_tbl) + Logs.debug (fun m -> + m "process_stmt: disam_assign_rhs: %a" Expr.pr + assign_rhs); + + let+ assign_rhs_callable_opt = + match assign_rhs with + | App (Var qual_ident, args, _) -> ( + let* qual_ident, symbol = + Rewriter.resolve_and_find (QualIdent.to_loc qual_ident) qual_ident in + let+ symbol = Rewriter.Symbol.reify symbol in + match symbol with CallDef call_def -> Some (symbol, qual_ident, args) | _ -> None) + | _ -> Rewriter.return None + in + - (* let* old_val_expr = disambiguate_process_expr old_val_expr given_type disam_tbl in - let+ new_val_expr = disambiguate_process_expr new_val_expr given_type disam_tbl in *) - let old_val_expr, new_val_expr = - match fpu_exprs with - | [ old_val_expr; new_val_expr ] -> - (Some old_val_expr, new_val_expr) - | [ new_val_expr ] -> (None, new_val_expr) + match assign_rhs_callable_opt with + | Some (symbol, proc_qual_ident, args) -> + begin + Logs.debug (fun m -> + m "process_stmt: assign_rhs_qual_ident: %a; %b" + QualIdent.pr proc_qual_ident + QualIdent.( + proc_qual_ident + = QualIdent.from_ident Predefs.bindAU_ident)); + + let (call_desc : Stmt.call_desc) = + { + call_lhs = + List.map var_decls_lhs ~f:(fun var -> var.var_name |> QualIdent.from_ident); + call_name = proc_qual_ident; + call_args = args; + } + in + (Stmt.Call call_desc, disam_tbl) + end + | None -> + match assign_rhs with + | App + ( Cas, + [ + App (Read, [ ref_expr; field_expr ], _); + old_val_expr; + new_val_expr; + ], + _ ) -> + Logs.debug (fun m -> + m "process_stmt: cas_assign_rhs: %a" Expr.pr + assign_rhs); + let field_qual_ident = Expr.to_qual_ident field_expr in + let cas_lhs = + match var_decls_lhs with + | [ lhs ] -> lhs.var_name |> QualIdent.from_ident | _ -> - Error.type_error loc - "fpu takes exactly three or four arguments" + Error.type_error stmt_loc + "Expected exactly one variable on left-hand side of cas" in - ( Stmt.Basic - (Fpu - { - fpu_ref = ref_expr; - fpu_field = field_qual_ident; - fpu_old_val = old_val_expr; - fpu_new_val = new_val_expr; - }), - disam_tbl ) - else Error.type_error loc "Unknown AU action" + let cas_desc = + Stmt. + { + cas_lhs; + cas_field = field_qual_ident; + cas_ref = ref_expr; + cas_old_val = old_val_expr; + cas_new_val = new_val_expr; + } + in + (Stmt.Cas cas_desc, disam_tbl) + | _ -> + let assign_desc = + Stmt.{ assign_desc with assign_lhs; assign_rhs } + in + (Stmt.Assign assign_desc, disam_tbl) + end + | Bind bind_desc -> + let* bind_lhs = + Rewriter.List.map bind_desc.bind_lhs ~f:(fun e -> + match e with + | App (Var qual_ident, [], _) + when not (QualIdent.is_qualified qual_ident) -> + disambiguate_process_expr e Type.any disam_tbl + | _ -> + Error.type_error stmt_loc + "Expected var identifier on left-hand side of bind") + in + let* bind_rhs = + disambiguate_process_expr bind_desc.bind_rhs Type.any + disam_tbl + in + let bind_desc = Stmt.{ bind_lhs; bind_rhs } in + Rewriter.return (Stmt.Bind bind_desc, disam_tbl) + | FieldWrite fw_desc -> + let* field_write_field, symbol = + Rewriter.resolve_and_find (QualIdent.to_loc fw_desc.field_write_field) fw_desc.field_write_field + in + let* symbol = Rewriter.Symbol.reify symbol in + let field_type = match symbol with + | FieldDef { field_type = App (Fld, [ field_type ], _); _ } -> + field_type + | _ -> Error.type_error (QualIdent.to_loc fw_desc.field_write_field) "Expected field" + in + let* field_write_ref = + disambiguate_process_expr fw_desc.field_write_ref Type.ref + disam_tbl + in + let+ field_write_val = + disambiguate_process_expr fw_desc.field_write_val field_type + disam_tbl + in + Stmt.FieldWrite { field_write_ref; field_write_field; field_write_val }, disam_tbl + + | FieldRead fr_desc -> + let* fr_var_qual_ident = + disambiguate_ident fr_desc.field_read_lhs disam_tbl + in + let* fr_var_qual_ident, symbol = + Rewriter.resolve_and_find stmt_loc fr_var_qual_ident + in + let* symbol = Rewriter.Symbol.reify symbol in + let* fr_type = + match symbol with + | VarDef var_def -> + let* var_type_expanded = + ProcessTypeExpr.expand_type_expr + var_def.var_decl.var_type + in + Rewriter.return var_type_expanded | _ -> - Error.error loc - "Internal error: process_au_action_stmt called with non-callable \ - expression") - | _ -> - Error.error loc - "Internal error: process_au_action_stmt called with non-assignment \ - statement" + Error.type_error stmt_loc + "Expected var identifier on left-hand side of field read" + in + let field_read_expr = + Expr.App + ( Read, + [ + fr_desc.field_read_ref; + App + ( Var fr_desc.field_read_field, + [], + { + Expr.expr_loc = stmt_loc; + expr_type = Type.bot; + } ); + ], + { Expr.expr_loc = stmt_loc; expr_type = Type.bot } ) + in + + let+ field_read_expr = + disambiguate_process_expr field_read_expr fr_type disam_tbl + in + + begin match field_read_expr with + | App + ( Read, + [ field_read_ref; App (Var field_read_field, [], _) ], + _ ) -> + let field_read_desc = + Stmt. + { + field_read_lhs = fr_var_qual_ident; + field_read_field; + field_read_ref; + } + in + (Stmt.FieldRead field_read_desc, disam_tbl) + | _ -> failwith "Unexpected error during type checking." + end + | Cas cs_desc -> ( + let* cs_var_qual_ident = + disambiguate_ident cs_desc.cas_lhs disam_tbl + in + let* cs_var_qual_ident, symbol = + Rewriter.resolve_and_find stmt_loc cs_var_qual_ident + in + let* symbol = Rewriter.Symbol.reify symbol in + let* cs_type = + match symbol with + | VarDef var_def -> + let* var_type_expanded = + ProcessTypeExpr.expand_type_expr + var_def.var_decl.var_type + in + Rewriter.return var_type_expanded + | _ -> + Error.type_error stmt_loc + "Expected var identifier on left-hand side of cas" + in + let expr_attr = + { Expr.expr_loc = stmt_loc; expr_type = Type.bot } + in + let cas_expr = + Expr.App + ( Cas, + [ + App + ( Read, + [ + cs_desc.cas_ref; + App (Var cs_desc.cas_field, [], expr_attr); + ], + expr_attr ); + cs_desc.cas_old_val; + cs_desc.cas_new_val; + ], + { Expr.expr_loc = stmt_loc; expr_type = Type.bool } + ) + in + + let+ cas_expr = + disambiguate_process_expr cas_expr cs_type disam_tbl + in + + match cas_expr with + | App + ( Cas, + [ + App (Read, [ cas_ref; App (Var cas_field, [], _) ], _); + cas_old_val; + cas_new_val; + ], + _ ) -> + let cas_desc = + Stmt. + { + cas_lhs = cs_var_qual_ident; + cas_field; + cas_ref; + cas_old_val; + cas_new_val; + } + in + (Stmt.Cas cas_desc, disam_tbl) + | _ -> failwith "Unexpected error during type checking.") + | Havoc qual_ident -> + let* qual_ident = disambiguate_ident qual_ident disam_tbl in + Rewriter.return (Stmt.Havoc qual_ident, disam_tbl) + | Return expr -> + let+ expr = + disambiguate_process_expr expr expected_return_type disam_tbl + in + (Stmt.Return expr, disam_tbl) + | Use use_desc -> + let* use_name, symbol = + let* id = disambiguate_ident use_desc.use_name disam_tbl in + Rewriter.resolve_and_find stmt_loc id + in + let* symbol = Rewriter.Symbol.reify symbol in + + let pred_decl, pred_def = + match symbol with + | CallDef + { + call_decl = { call_decl_kind = Pred; _ } as pred_decl; + call_def = FuncDef {func_body = pred_def} + } -> + pred_decl, pred_def + | CallDef + { + call_decl = + { call_decl_kind = Invariant; _ } as pred_decl; + call_def = FuncDef {func_body = pred_def} + } -> + pred_decl, pred_def + | _ -> + Error.type_error stmt_loc + ("Expected predicate or invariant identifier, but found " + ^ QualIdent.to_string use_name) + in + + let exists_vars = + Option.value pred_def ~default:(Expr.mk_unit Loc.dummy) + |> Expr.existential_vars_type + in + let find_type ident = + let ty_opt = Map.fold exists_vars ~init:None ~f:(fun ~key ~data acc -> + if Option.is_none acc + && String.(Ident.name ident = Ident.name key) + then Some data else acc) + in + match ty_opt with + | Some ty -> ty + | _ -> Error.type_error (Ident.to_loc ident) + (Printf.sprintf !"Could not find existential variable %{Ident} in %s %{QualIdent}" ident (Symbol.kind symbol) use_desc.use_name) + in + + let* use_args = + Rewriter.List.map use_desc.use_args ~f:(fun expr -> + disambiguate_expr expr disam_tbl) + in + + let* use_args = + process_callable_args stmt_loc pred_decl use_args + in + + let+ use_witnesses_or_binds = + Rewriter.List.map use_desc.use_witnesses_or_binds ~f:(fun (i, e) -> + match use_desc.use_kind with + | Fold -> + let ty = find_type i in + let+ e = disambiguate_process_expr e ty disam_tbl in + (i, e) + | Unfold -> + match e with + | App (Var qual_ident, [], _) when QualIdent.is_local qual_ident -> + let ty = + find_type (QualIdent.unqualify qual_ident) + in + let+ ie = disambiguate_process_expr (Expr.mk_var ~typ:(Type.mk_any (Ident.to_loc i)) (QualIdent.from_ident i)) ty disam_tbl in + (Expr.to_ident ie, e) + | _ -> Error.type_error (Expr.to_loc e) "Expected local identifier" + ) + in + + ( Stmt.Use { use_desc with use_name; use_args; use_witnesses_or_binds }, + disam_tbl ) + | New new_desc -> + let* new_qual_ident = + disambiguate_ident new_desc.new_lhs disam_tbl + in + let* new_qual_ident, symbol = + Rewriter.resolve_and_find stmt_loc new_qual_ident + in + let* symbol = Rewriter.Symbol.reify symbol in + let var_decl = + match symbol with + | VarDef var_def -> var_def.var_decl + | _ -> + Error.type_error stmt_loc + "Expected variable identifier on left-hand side of new" + in + let* var_type_expanded = + ProcessTypeExpr.expand_type_expr var_decl.var_type + in + + if Type.equal var_type_expanded Type.ref then + let process_field_init (field_name, expr_opt) = + let* field_name, symbol = + Rewriter.resolve_and_find stmt_loc field_name + in + let* field_type = + Rewriter.Symbol.reify_field_type stmt_loc symbol + in + let+ expr_opt = + Rewriter.Option.map expr_opt ~f:(fun expr -> + disambiguate_process_expr expr field_type disam_tbl) + in + (field_name, expr_opt) + in + let+ new_args = + Rewriter.List.map new_desc.new_args ~f:process_field_init + in + + let new_desc = Stmt.{ new_lhs = new_qual_ident; new_args } in + + (Stmt.New new_desc, disam_tbl) + else + type_mismatch_error stmt_loc Type.ref var_decl.var_type + (* The following constructs are not expected here because the parser stores these commands as Assign stmts. + The job of this function is to intercept the Assign stmts with the specific expressions on the RHS, and then transform + them to the appropriate construct, ie Call, New, BindAU, OpenAU, AbortAU, CommitAU etc. + + This function is not expected to go over these parts of the AST again. If the following constructs are + discovered by this function, then something unexpected has happened. *) + (* Now that we call process_symbol on arbitrarily AST elements, we need to deal with these constructs too *) + | Call call_desc -> ( + let* call_lhs = + Rewriter.List.map call_desc.call_lhs ~f:(fun qual_iden -> + let* qual_iden = disambiguate_ident qual_iden disam_tbl in + Rewriter.resolve stmt_loc qual_iden) + in + let* call_lhs_types = + Rewriter.List.map call_lhs ~f:(fun qual_iden -> + let* qual_iden, symbol = + Rewriter.resolve_and_find stmt_loc qual_iden + in + let* symbol = Rewriter.Symbol.reify symbol in + match symbol with + | VarDef var_def -> + let* var_type_expanded = + ProcessTypeExpr.expand_type_expr + var_def.var_decl.var_type + in + Rewriter.return var_type_expanded + | _ -> + Error.type_error stmt_loc + "Expected variable identifier on left-hand side of \ + call") + in + + let expected_return_type = + Type.mk_prod stmt_loc call_lhs_types + in + + let+ call_expr = + Expr.App + ( Var call_desc.call_name, + call_desc.call_args, + { Expr.expr_loc = stmt_loc; expr_type = Type.bot } ) + |> fun expr -> + disambiguate_process_expr expr expected_return_type disam_tbl + in + match call_expr with + | App (Var proc_qual_ident, args, _expr_attr) -> + let (call_desc : Stmt.call_desc) = + { + call_lhs; + call_name = proc_qual_ident; + call_args = args; + } + in + + (Stmt.Call call_desc, disam_tbl) + | _ -> failwith "Unexpected error during type checking.") + | AUAction _au_action_kind -> + internal_error stmt_loc + "Did not expect AU action stmts in AST at this stage." + | Fpu _fpu_desc -> + internal_error stmt_loc + "Did not expect Fpu stmts in AST at this stage." + let process_stmt ?(new_scope = true) (expected_return_type : Type.t) (stmt : Stmt.t) (disam_tbl : DisambiguationTbl.t) : (Stmt.t * DisambiguationTbl.t) Rewriter.t = @@ -1140,571 +1634,11 @@ module ProcessCallable = struct in (Stmt.Block { block_desc with block_body = stmt_list }, disam_tbl) - | Basic basic_stmt -> ( - match basic_stmt with - | VarDef var_def -> - let* var_decl = - ProcessTypeExpr.process_var_decl var_def.var_decl - in - let var_decl, disam_tbl' = - DisambiguationTbl.add_var_decl var_decl disam_tbl - in - let* _ = - Rewriter.introduce_symbol - (VarDef { var_decl; var_init = None }) - in - let+ stmt, disam_tbl' = - let var = QualIdent.from_ident var_decl.var_name in - match var_def.var_init with - | None -> - Rewriter.return @@ (Stmt.Basic (Havoc var), disam_tbl') - | Some expr -> - (* let* expr = disambiguate_process_expr expr var_decl.var_type disam_tbl in *) - let var_expr = Expr.from_var_decl var_def.var_decl in - - (* Expr.App (Var var, [], {expr_loc = stmt.stmt_loc; expr_type = var_decl.var_type}) in *) - let assign_desc = - Stmt. - { - assign_lhs = [ var_expr ]; - assign_rhs = expr; - assign_is_init = true; - } - in - - let+ stmt, disam_tbl' = - process_stmt - { - stmt_desc = Stmt.Basic (Assign assign_desc); - stmt_loc = stmt.stmt_loc; - } - disam_tbl' - in - - (stmt.stmt_desc, disam_tbl') - in - (stmt, disam_tbl') - | Spec (sk, spec) -> - let+ spec = process_stmt_spec disam_tbl spec in - (Stmt.Basic (Spec (sk, spec)), disam_tbl) - | Assign assign_desc -> ( - let* assign_lhs = - Rewriter.List.map assign_desc.assign_lhs ~f:(fun expr -> - disambiguate_process_expr expr Type.any disam_tbl) - in - - let* _ = - Rewriter.List.iter assign_lhs ~f:(fun expr -> - match expr with - | App (Var qual_ident, [], _) -> ( - let+ _, symbol = - Rewriter.resolve_and_find stmt.stmt_loc qual_ident - in - match Rewriter.Symbol.orig_symbol symbol with - | VarDef { var_decl = { var_const = true; _ }; _ } - when not assign_desc.assign_is_init -> - Error.type_error (Expr.to_loc expr) - (Printf.sprintf !"Cannot assign to val %{QualIdent}" qual_ident) - | VarDef _ -> () - | _ -> - Error.type_error (Expr.to_loc expr) - "Expected assignable expression on left-hand \ - side of assignment") - | App (Read, [ ref_expr; field_expr ], _) -> - Rewriter.return () - | _ -> - Error.type_error (Expr.to_loc expr) - "Expected assignable expression on left-hand side \ - of assignment") - in - - Logs.debug (fun m -> - m "process_stmt: assign_desc: %a" Stmt.pr_basic_stmt - (Assign assign_desc)); - let* disam_assign_rhs = - disambiguate_expr assign_desc.assign_rhs disam_tbl - in - - Logs.debug (fun m -> - m "process_stmt: disam_assign_rhs: %a" Expr.pr - disam_assign_rhs); - - let* is_assign_rhs_callable = - match disam_assign_rhs with - | App (Var qual_ident, _, _) -> ( - if Predefs.is_qual_ident_au_cmnd qual_ident then - Rewriter.return true - else - let+ _, symbol, _ = - Logs.debug (fun m -> - m - "process_stmt: disam_find: \ - assign_rhs_qual_ident: %a" - QualIdent.pr qual_ident); - Rewriter.find stmt.stmt_loc qual_ident - in - match symbol with CallDef _ -> true | _ -> false) - | _ -> Rewriter.return false - in - - match is_assign_rhs_callable with - | true -> ( - match assign_desc.assign_rhs with - | App (Var proc_qual_ident, args, expr_attr) -> ( - Logs.debug (fun m -> - m "process_stmt: assign_rhs_qual_ident: %a; %b" - QualIdent.pr proc_qual_ident - QualIdent.( - proc_qual_ident - = QualIdent.from_ident Predefs.bindAU_ident)); - if Predefs.is_qual_ident_au_cmnd proc_qual_ident then - process_au_action_stmt stmt.stmt_desc stmt.stmt_loc - disam_tbl - else - let expected_return_type = - Type.mk_prod - (Expr.to_loc assign_desc.assign_rhs) - (List.map assign_lhs ~f:Expr.to_type) - in - - let+ call_expr = - Expr.App (Var proc_qual_ident, args, expr_attr) - |> fun expr -> - disambiguate_process_expr expr expected_return_type - disam_tbl - in - - match call_expr with - | App (Var proc_qual_ident, args, _expr_attr) -> - let (call_desc : Stmt.call_desc) = - { - call_lhs = - List.map assign_lhs ~f:Expr.to_qual_ident; - call_name = proc_qual_ident; - call_args = args; - } - in - - (Stmt.Basic (Call call_desc), disam_tbl) - | _ -> - failwith "Unexpected error during type checking.") - | _ -> failwith "Unexpected error during type checking.") - | false -> ( - let expected_type = - Type.mk_prod stmt.stmt_loc - (List.map assign_lhs ~f:Expr.to_type) - in - - let+ assign_rhs = - disambiguate_process_expr assign_desc.assign_rhs - expected_type disam_tbl - in - - match assign_rhs with - | App (Read, [ ref_expr; field_expr ], _) -> - Logs.debug (fun m -> - m "process_stmt: read_assign_rhs: %a" Expr.pr - assign_rhs); - let field_qual_ident = Expr.to_qual_ident field_expr in - let field_read_lhs = - match assign_lhs with - | [ lhs ] -> Expr.to_qual_ident lhs - | _ -> - Error.type_error stmt.stmt_loc - "Expected exactly one left-hand side expression of field read" - in - - let field_read_desc = - Stmt. - { - field_read_lhs; - field_read_field = field_qual_ident; - field_read_ref = ref_expr; - } - in - (Stmt.Basic (FieldRead field_read_desc), disam_tbl) - | App - ( Cas, - [ - App (Read, [ ref_expr; field_expr ], _); - old_val_expr; - new_val_expr; - ], - _ ) -> - Logs.debug (fun m -> - m "process_stmt: cas_assign_rhs: %a" Expr.pr - assign_rhs); - let field_qual_ident = Expr.to_qual_ident field_expr in - let cas_lhs = - match assign_lhs with - | [ lhs ] -> Expr.to_qual_ident lhs - | _ -> - Error.type_error stmt.stmt_loc - "Expected exactly one left-hand side expression in cas" - in - - let cas_desc = - Stmt. - { - cas_lhs; - cas_field = field_qual_ident; - cas_ref = ref_expr; - cas_old_val = old_val_expr; - cas_new_val = new_val_expr; - } - in - (Stmt.Basic (Cas cas_desc), disam_tbl) - | _ -> - let assign_desc = - Stmt.{ assign_desc with assign_lhs; assign_rhs } - in - - (Stmt.Basic (Assign assign_desc), disam_tbl))) - | Bind bind_desc -> - let* bind_lhs = - Rewriter.List.map bind_desc.bind_lhs ~f:(fun e -> - match e with - | App (Var qual_ident, [], _) - when not (QualIdent.is_qualified qual_ident) -> - disambiguate_process_expr e Type.any disam_tbl - | _ -> - Error.type_error stmt.stmt_loc - "Expected var identifier on left-hand side of bind") - in - let* bind_rhs = - disambiguate_process_expr bind_desc.bind_rhs Type.any - disam_tbl - in - let bind_desc = Stmt.{ bind_lhs; bind_rhs } in - Rewriter.return (Stmt.Basic (Bind bind_desc), disam_tbl) - | FieldWrite fw_desc -> - let* field_write_field, symbol = - Rewriter.resolve_and_find (QualIdent.to_loc fw_desc.field_write_field) fw_desc.field_write_field - in - let* symbol = Rewriter.Symbol.reify symbol in - let field_type = match symbol with - | FieldDef { field_type = App (Fld, [ field_type ], _); _ } -> - field_type - | _ -> Error.type_error (QualIdent.to_loc fw_desc.field_write_field) "Expected field" - in - let* field_write_ref = - disambiguate_process_expr fw_desc.field_write_ref Type.ref - disam_tbl - in - let+ field_write_val = - disambiguate_process_expr fw_desc.field_write_val field_type - disam_tbl - in - Stmt.Basic (FieldWrite { field_write_ref; field_write_field; field_write_val }), disam_tbl - - | FieldRead fr_desc -> ( - let* fr_var_qual_ident = - disambiguate_ident fr_desc.field_read_lhs disam_tbl - in - let* fr_var_qual_ident, symbol = - Rewriter.resolve_and_find stmt.stmt_loc fr_var_qual_ident - in - let* symbol = Rewriter.Symbol.reify symbol in - let* fr_type = - match symbol with - | VarDef var_def -> - let* var_type_expanded = - ProcessTypeExpr.expand_type_expr - var_def.var_decl.var_type - in - Rewriter.return var_type_expanded - | _ -> - Error.type_error stmt.stmt_loc - "Expected var identifier on left-hand side of field read" - in - let field_read_expr = - Expr.App - ( Read, - [ - fr_desc.field_read_ref; - App - ( Var fr_desc.field_read_field, - [], - { - Expr.expr_loc = stmt.stmt_loc; - expr_type = Type.bot; - } ); - ], - { Expr.expr_loc = stmt.stmt_loc; expr_type = Type.bot } ) - in - - let+ field_read_expr = - disambiguate_process_expr field_read_expr fr_type disam_tbl - in - - match field_read_expr with - | App - ( Read, - [ field_read_ref; App (Var field_read_field, [], _) ], - _ ) -> - let field_read_desc = - Stmt. - { - field_read_lhs = fr_var_qual_ident; - field_read_field; - field_read_ref; - } - in - (Stmt.Basic (FieldRead field_read_desc), disam_tbl) - | _ -> failwith "Unexpected error during type checking.") - | Cas cs_desc -> ( - let* cs_var_qual_ident = - disambiguate_ident cs_desc.cas_lhs disam_tbl - in - let* cs_var_qual_ident, symbol = - Rewriter.resolve_and_find stmt.stmt_loc cs_var_qual_ident - in - let* symbol = Rewriter.Symbol.reify symbol in - let* cs_type = - match symbol with - | VarDef var_def -> - let* var_type_expanded = - ProcessTypeExpr.expand_type_expr - var_def.var_decl.var_type - in - Rewriter.return var_type_expanded - | _ -> - Error.type_error stmt.stmt_loc - "Expected var identifier on left-hand side of cas" - in - let expr_attr = - { Expr.expr_loc = stmt.stmt_loc; expr_type = Type.bot } - in - let cas_expr = - Expr.App - ( Cas, - [ - App - ( Read, - [ - cs_desc.cas_ref; - App (Var cs_desc.cas_field, [], expr_attr); - ], - expr_attr ); - cs_desc.cas_old_val; - cs_desc.cas_new_val; - ], - { Expr.expr_loc = stmt.stmt_loc; expr_type = Type.bool } - ) - in - - let+ cas_expr = - disambiguate_process_expr cas_expr cs_type disam_tbl - in - - match cas_expr with - | App - ( Cas, - [ - App (Read, [ cas_ref; App (Var cas_field, [], _) ], _); - cas_old_val; - cas_new_val; - ], - _ ) -> - let cas_desc = - Stmt. - { - cas_lhs = cs_var_qual_ident; - cas_field; - cas_ref; - cas_old_val; - cas_new_val; - } - in - (Stmt.Basic (Cas cas_desc), disam_tbl) - | _ -> failwith "Unexpected error during type checking.") - | Havoc qual_ident -> - let* qual_ident = disambiguate_ident qual_ident disam_tbl in - Rewriter.return (Stmt.Basic (Havoc qual_ident), disam_tbl) - | Return expr -> - let+ expr = - disambiguate_process_expr expr expected_return_type disam_tbl - in - (Stmt.Basic (Return expr), disam_tbl) - | Use use_desc -> - let* use_name, symbol = - let* id = disambiguate_ident use_desc.use_name disam_tbl in - Rewriter.resolve_and_find stmt.stmt_loc id - in - let* symbol = Rewriter.Symbol.reify symbol in - - let pred_decl, pred_def = - match symbol with - | CallDef - { - call_decl = { call_decl_kind = Pred; _ } as pred_decl; - call_def = FuncDef {func_body = pred_def} - } -> - pred_decl, pred_def - | CallDef - { - call_decl = - { call_decl_kind = Invariant; _ } as pred_decl; - call_def = FuncDef {func_body = pred_def} - } -> - pred_decl, pred_def - | _ -> - Error.type_error stmt.stmt_loc - ("Expected predicate or invariant identifier, but found " - ^ QualIdent.to_string use_name) - in - - let exists_vars = - Option.value pred_def ~default:(Expr.mk_unit Loc.dummy) - |> Expr.existential_vars_type - in - let find_type ident = - let ty_opt = Map.fold exists_vars ~init:None ~f:(fun ~key ~data acc -> - if Option.is_none acc - && String.(Ident.name ident = Ident.name key) - then Some data else acc) - in - match ty_opt with - | Some ty -> ty - | _ -> Error.type_error (Ident.to_loc ident) - (Printf.sprintf !"Could not find existential variable %{Ident} in %s %{QualIdent}" ident (Symbol.kind symbol) use_desc.use_name) - in - - let* use_args = - Rewriter.List.map use_desc.use_args ~f:(fun expr -> - disambiguate_expr expr disam_tbl) - in - - let* use_args = - process_callable_args stmt.stmt_loc pred_decl use_args - in - - let+ use_witnesses_or_binds = - Rewriter.List.map use_desc.use_witnesses_or_binds ~f:(fun (i, e) -> - match use_desc.use_kind with - | Fold -> - let ty = find_type i in - let+ e = disambiguate_process_expr e ty disam_tbl in - (i, e) - | Unfold -> - match e with - | App (Var qual_ident, [], _) when QualIdent.is_local qual_ident -> - let ty = - find_type (QualIdent.unqualify qual_ident) - in - let+ ie = disambiguate_process_expr (Expr.mk_var ~typ:(Type.mk_any (Ident.to_loc i)) (QualIdent.from_ident i)) ty disam_tbl in - (Expr.to_ident ie, e) - | _ -> Error.type_error (Expr.to_loc e) "Expected local identifier" - ) - in - - ( Stmt.Basic (Use { use_desc with use_name; use_args; use_witnesses_or_binds }), - disam_tbl ) - | New new_desc -> - let* new_qual_ident = - disambiguate_ident new_desc.new_lhs disam_tbl - in - let* new_qual_ident, symbol = - Rewriter.resolve_and_find stmt.stmt_loc new_qual_ident - in - let* symbol = Rewriter.Symbol.reify symbol in - let var_decl = - match symbol with - | VarDef var_def -> var_def.var_decl - | _ -> - Error.type_error stmt.stmt_loc - "Expected variable identifier on left-hand side of new" - in - let* var_type_expanded = - ProcessTypeExpr.expand_type_expr var_decl.var_type - in - - if Type.equal var_type_expanded Type.ref then - let process_field_init (field_name, expr_opt) = - let* field_name, symbol = - Rewriter.resolve_and_find stmt.stmt_loc field_name - in - let* field_type = - Rewriter.Symbol.reify_field_type stmt.stmt_loc symbol - in - let+ expr_opt = - Rewriter.Option.map expr_opt ~f:(fun expr -> - disambiguate_process_expr expr field_type disam_tbl) - in - (field_name, expr_opt) - in - let+ new_args = - Rewriter.List.map new_desc.new_args ~f:process_field_init - in - - let new_desc = Stmt.{ new_lhs = new_qual_ident; new_args } in - - (Stmt.Basic (New new_desc), disam_tbl) - else - type_mismatch_error stmt.stmt_loc Type.ref var_decl.var_type - (* The following constructs are not expected here because the parser stores these commands as Assign stmts. - The job of this function is to intercept the Assign stmts with the specific expressions on the RHS, and then transform - them to the appropriate construct, ie Call, New, BindAU, OpenAU, AbortAU, CommitAU etc. - - This function is not expected to go over these parts of the AST again. If the following constructs are - discovered by this function, then something unexpected has happened. *) - (* Now that we call process_symbol on arbitrarily AST elements, we need to deal with these constructs too *) - | Call call_desc -> ( - let* call_lhs = - Rewriter.List.map call_desc.call_lhs ~f:(fun qual_iden -> - let* qual_iden = disambiguate_ident qual_iden disam_tbl in - Rewriter.resolve stmt.stmt_loc qual_iden) - in - let* call_lhs_types = - Rewriter.List.map call_lhs ~f:(fun qual_iden -> - let* qual_iden, symbol = - Rewriter.resolve_and_find stmt.stmt_loc qual_iden - in - let* symbol = Rewriter.Symbol.reify symbol in - match symbol with - | VarDef var_def -> - let* var_type_expanded = - ProcessTypeExpr.expand_type_expr - var_def.var_decl.var_type - in - Rewriter.return var_type_expanded - | _ -> - Error.type_error stmt.stmt_loc - "Expected variable identifier on left-hand side of \ - call") - in - - let expected_return_type = - Type.mk_prod stmt.stmt_loc call_lhs_types - in - - let+ call_expr = - Expr.App - ( Var call_desc.call_name, - call_desc.call_args, - { Expr.expr_loc = stmt.stmt_loc; expr_type = Type.bot } ) - |> fun expr -> - disambiguate_process_expr expr expected_return_type disam_tbl - in - - match call_expr with - | App (Var proc_qual_ident, args, _expr_attr) -> - let (call_desc : Stmt.call_desc) = - { - call_lhs; - call_name = proc_qual_ident; - call_args = args; - } - in - - (Stmt.Basic (Call call_desc), disam_tbl) - | _ -> failwith "Unexpected error during type checking.") - | AUAction _au_action_kind -> - internal_error (Stmt.to_loc stmt) - "Did not expect AU action stmts in AST at this stage." - | Fpu _fpu_desc -> - internal_error (Stmt.to_loc stmt) - "Did not expect Fpu stmts in AST at this stage.") + | Basic basic_stmt -> + let+ basic_stmt, disam_tbl' = + process_basic_stmt expected_return_type basic_stmt (Stmt.to_loc stmt) disam_tbl + in + (Stmt.Basic basic_stmt, disam_tbl') | Loop loop_desc -> let* loop_contract = Rewriter.List.map loop_desc.loop_contract