From e9c1da4abf861cb076b229c7a8974b1ec6cf92d6 Mon Sep 17 00:00:00 2001 From: Thomas Wies Date: Mon, 18 Nov 2024 19:19:20 -0500 Subject: [PATCH] add support for thread spawning --- lib/ast/astDef.ml | 7 ++++--- lib/ast/rewriter.ml | 2 +- lib/ast/symbolTbl.ml | 16 +++++++++------- lib/frontend/lexer.mll | 1 + lib/frontend/parser.mly | 14 +++++++++++++- lib/frontend/rewrites/rewrites.ml | 9 ++++++--- lib/frontend/typing.ml | 4 +++- 7 files changed, 37 insertions(+), 16 deletions(-) diff --git a/lib/ast/astDef.ml b/lib/ast/astDef.ml index 381deaf..6f5fdfc 100644 --- a/lib/ast/astDef.ml +++ b/lib/ast/astDef.ml @@ -1212,6 +1212,7 @@ module Stmt = struct call_lhs : qual_ident list; call_name : qual_ident; call_args : expr list; + call_is_spawn : bool; } @@ -1377,7 +1378,7 @@ module Stmt = struct | Call cstm -> ( match cstm.call_lhs with | [] -> - fprintf ppf "@[%a(@[%a@])@]" QualIdent.pr cstm.call_name + fprintf ppf "@[%s%a(@[%a@])@]" (if cstm.call_is_spawn then "spawn " else "") QualIdent.pr cstm.call_name Expr.pr_list cstm.call_args | _ -> fprintf ppf "@[<2>%a@ :=@ @[%a(@[%a@])@]@]" QualIdent.pr_list @@ -1508,8 +1509,8 @@ module Stmt = struct let mk_cond ~loc ?(cond_if_assumes_false = false) test then_ else_ = { stmt_desc = Cond { cond_test = test; cond_then = then_; cond_else = else_; cond_if_assumes_false }; stmt_loc = loc } - let mk_call ~loc ?(lhs=[]) name args = - { stmt_desc = Basic (Call { call_lhs = lhs; call_name = name; call_args = args }); stmt_loc = loc } + let mk_call ~loc ?(lhs=[]) name args ~is_spawn = + { stmt_desc = Basic (Call { call_lhs = lhs; call_name = name; call_args = args; call_is_spawn = is_spawn }); stmt_loc = loc } let mk_assign ~loc lhs rhs = { stmt_desc = Basic (Assign { assign_lhs = lhs; assign_rhs = rhs; assign_is_init = false }); stmt_loc = loc } diff --git a/lib/ast/rewriter.ml b/lib/ast/rewriter.ml index e2ae892..ac1d214 100644 --- a/lib/ast/rewriter.ml +++ b/lib/ast/rewriter.ml @@ -864,7 +864,7 @@ module Stmt = struct let call_name = f call.call_name in { stmt with - stmt_desc = Basic (Call { call_lhs; call_name; call_args }); + stmt_desc = Basic (Call { call with call_lhs; call_name; call_args }); } | Use use_desc -> let use_name = f use_desc.use_name in diff --git a/lib/ast/symbolTbl.ml b/lib/ast/symbolTbl.ml index 1e72d11..d9b889c 100644 --- a/lib/ast/symbolTbl.ml +++ b/lib/ast/symbolTbl.ml @@ -217,7 +217,8 @@ let resolve name (tbl : t) : else Set.add inst_scopes target_qual_ident in go_forward new_inst_scopes tbl.tbl_root subst new_path - (* /// why do we jump to tbl.root from here? *) + (* /// why do we jump to tbl.root from here? + TW: Because the remainder of the path to be traversed has been requalified relative to target_qual_ident, which is a fully qualified id. *) | Import qual_ident, _ -> (* Logs.debug (fun m -> m "SymbolTbl.resolve.go_forward: 2"); *) let target_qual_ident = (*QualIdent.requalify subst*) qual_ident in @@ -442,11 +443,6 @@ let add_symbol ?(scope : scope option = None) symbol tbl = | ModInst mod_inst -> let mod_inst_qual_ident, subst = match mod_inst.mod_inst_def with - | None -> - let _, mod_inst_type, _mod_inst_symbol, _ = - resolve_and_find_exn mod_inst.mod_inst_type tbl - in - (mod_inst_type, []) | Some (mod_inst_func, mod_inst_args) -> ( let _, mod_inst_func, mod_inst_symbol, _subst1 = resolve_and_find_exn mod_inst_func tbl @@ -467,12 +463,18 @@ let add_symbol ?(scope : scope option = None) symbol tbl = (formal_id, QualIdent.to_list arg)) in match res with - | Ok subst -> (mod_inst_func, subst) + | Ok subst -> + (mod_inst_func, subst) | Unequal_lengths -> Error.type_error symbol_loc (Printf.sprintf !"Module %{QualIdent} expects %d arguments" mod_inst_func (List.length formals))) + | None -> + let _, mod_inst_type, _mod_inst_symbol, _ = + resolve_and_find_exn mod_inst.mod_inst_type tbl + in + (mod_inst_type, []) in let is_abstract = mod_inst.mod_inst_is_interface in add_to_map diff --git a/lib/frontend/lexer.mll b/lib/frontend/lexer.mll index 9b2e704..1468388 100644 --- a/lib/frontend/lexer.mll +++ b/lib/frontend/lexer.mll @@ -61,6 +61,7 @@ let _ = ("returns", RETURNS); (* ("subseteq", SUBSETEQ); *) ("Set", SET); + ("spawn", SPAWN); ("true", CONSTVAL (Expr.Bool true)); ("type", TYPE); ("unfold", USE (Stmt.Unfold)); diff --git a/lib/frontend/parser.mly b/lib/frontend/parser.mly index ec0b937..aa9e793 100644 --- a/lib/frontend/parser.mly +++ b/lib/frontend/parser.mly @@ -20,7 +20,7 @@ open Ast %token SPEC %token USE %token HAVOC NEW RETURN OWN AU -%token IF ELSE WHILE CAS +%token IF ELSE WHILE CAS SPAWN %token FUNC %token PROC AXIOM LEMMA %token CASE DATA INT REAL BOOL PERM SET MAP ATOMICTOKEN FIELD REF @@ -432,6 +432,18 @@ stmt_wo_trailing_substmt: in Stmt.(Basic (Assign assign)) } +(* thread spawn *) +| SPAWN; id = qual_ident; LPAREN; args = separated_list(COMMA, expr); RPAREN; SEMICOLON { + let open Stmt in + let call = + { call_lhs = []; + call_name = Expr.to_qual_ident id; + call_args = args; + call_is_spawn = true + } + in + Basic (Call call) +} (* assignment / allocation *) | es = separated_nonempty_list(COMMA, expr); COLONEQ; e = assign_rhs; SEMICOLON { let open Stmt in diff --git a/lib/frontend/rewrites/rewrites.ml b/lib/frontend/rewrites/rewrites.ml index b250be4..c7a1115 100644 --- a/lib/frontend/rewrites/rewrites.ml +++ b/lib/frontend/rewrites/rewrites.ml @@ -620,7 +620,7 @@ let rec rewrite_loops (stmt : Stmt.t) : Stmt.t Rewriter.t = Stmt.mk_call ~loc ~lhs:lhs_list (QualIdent.from_ident loop_proc_name) - args_list + args_list ~is_spawn:false in (* TODO: Rename variables from curr_vars to loop_vars in loop body *) @@ -700,7 +700,7 @@ let rec rewrite_loops (stmt : Stmt.t) : Stmt.t Rewriter.t = Stmt.mk_call ~loc ~lhs:lhs_list (QualIdent.from_ident loop_proc_name) - args_list + args_list ~is_spawn:false in Logs.debug (fun m -> m "Loop new_stmt:\n %a" Stmt.pr new_stmt); @@ -1104,7 +1104,10 @@ let rec rewrite_call_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t = let call_decl, call_def = match symbol with - | CallDef c -> (c.call_decl, c.call_def) + | CallDef c -> + if call_desc.call_is_spawn + then ({c.call_decl with call_decl_postcond = []; call_decl_returns= []}, c.call_def) + else (c.call_decl, c.call_def) | _ -> Error.error stmt.stmt_loc "Expected a call_def" in diff --git a/lib/frontend/typing.ml b/lib/frontend/typing.ml index 4f89876..88f5818 100644 --- a/lib/frontend/typing.ml +++ b/lib/frontend/typing.ml @@ -1282,6 +1282,7 @@ module ProcessCallable = struct List.map var_decls_lhs ~f:(fun var -> var.var_name |> QualIdent.from_ident); call_name = proc_qual_ident; call_args = args; + call_is_spawn = false; } in (Stmt.Call call_desc, disam_tbl) @@ -1562,7 +1563,7 @@ module ProcessCallable = struct This function is not expected to go over these parts of the AST again. If the following constructs are discovered by this function, then something unexpected has happened. *) (* Now that we call process_symbol on arbitrarily AST elements, we need to deal with these constructs too *) - | Call call_desc -> ( + | Call call_desc -> ( let* call_lhs, var_decls_lhs = Rewriter.List.fold_right call_desc.call_lhs ~init:([], []) ~f:(fun orig_qual_ident (assign_lhs, var_decls_lhs) -> @@ -1592,6 +1593,7 @@ module ProcessCallable = struct | App (Var proc_qual_ident, args, _expr_attr) -> let (call_desc : Stmt.call_desc) = { + call_desc with call_lhs; call_name = proc_qual_ident; call_args = args;