From 120fe618e3d11f0294962c70e9cd1fe5ba22c91d Mon Sep 17 00:00:00 2001 From: Alasdair Date: Mon, 18 Nov 2024 21:18:16 +0000 Subject: [PATCH] SV: Further annotations to optimize mappings This introduces 3 annotations that work together to help the Jib compilation process optimize mappings - Whenever we elaborate a mapping 'X' we always generate a X_backwards (or forwards) call under a X_backwards_matches guard. The rewrite notes this fact by placing a $[mapping_guarded] attribute on the call to X_backwards - The last clause of a generated mapping is annotated with a $[mapping_last] attribute. This required extending the AST (and ANF) with attribute support here. - Generated mappings themselves are annotation with a $[mapping_function] attribute. For each mapping_function (X_backwards or similar) we generated a X_backwards_infallible function, which generates a match statement that returns a dummy value rather than match_failure when any $[mapping_last] clause is not matched. Any call to X_backwards is changed to X_backwards_infallible when it has a $[mapping_guarded] attribute. SMT: Always optimize pure matches (Note that match_failure is a side-effect) --- src/lib/anf.ml | 118 +++++++++++------ src/lib/anf.mli | 18 ++- src/lib/chunk_ast.ml | 13 +- src/lib/initial_check.ml | 13 +- src/lib/jib_compile.ml | 227 ++++++++++++++++++++------------ src/lib/jib_compile.mli | 2 - src/lib/mappings.ml | 9 +- src/lib/parse_ast.ml | 7 +- src/lib/parser.mly | 10 +- src/lib/pretty_print_sail.ml | 16 ++- src/lib/rewrites.ml | 30 +++-- src/lib/type_check.ml | 6 +- src/sail_c_backend/c_backend.ml | 9 +- src/sail_smt_backend/jib_smt.ml | 14 +- src/sail_sv_backend/jib_sv.ml | 51 +++++-- 15 files changed, 361 insertions(+), 182 deletions(-) diff --git a/src/lib/anf.ml b/src/lib/anf.ml index 46e42914c..b57602ca3 100644 --- a/src/lib/anf.ml +++ b/src/lib/anf.ml @@ -57,9 +57,9 @@ module Big_int = Nat_big_num (* 1. Conversion to A-normal form (ANF) *) (**************************************************************************) -type aexp_annot = { loc : l; env : Env.t; uannot : uannot } +type anf_annot = { loc : l; env : Env.t; uannot : uannot } -type 'a aexp = AE_aux of 'a aexp_aux * aexp_annot +type 'a aexp = AE_aux of 'a aexp_aux * anf_annot and 'a aexp_aux = | AE_val of 'a aval @@ -73,8 +73,8 @@ and 'a aexp_aux = | AE_throw of 'a aval * 'a | AE_if of 'a aval * 'a aexp * 'a aexp * 'a | AE_field of 'a aval * id * 'a - | AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp) list * 'a - | AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp) list * 'a + | AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a + | AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a | AE_struct_update of 'a aval * 'a aval Bindings.t * 'a | AE_for of id * 'a aexp * 'a aexp * 'a aexp * order * 'a aexp | AE_loop of loop * 'a aexp * 'a aexp @@ -82,7 +82,7 @@ and 'a aexp_aux = and sc_op = SC_and | SC_or -and 'a apat = AP_aux of 'a apat_aux * Env.t * l +and 'a apat = AP_aux of 'a apat_aux * anf_annot and 'a apat_aux = | AP_tuple of 'a apat list @@ -112,7 +112,7 @@ let aexp_loc (AE_aux (_, { loc = l; _ })) = l (* Renaming variables in ANF expressions *) -let rec apat_bindings (AP_aux (apat_aux, _, _)) = +let rec apat_bindings (AP_aux (apat_aux, _)) = match apat_aux with | AP_tuple apats -> List.fold_left IdSet.union IdSet.empty (List.map apat_bindings apats) | AP_id (id, _) -> IdSet.singleton id @@ -127,7 +127,7 @@ let rec apat_bindings (AP_aux (apat_aux, _, _)) = (** This function returns the types of all bound variables in a pattern. It ignores AP_global, apat_globals is used for that. *) -let rec apat_types (AP_aux (apat_aux, env, _)) = +let rec apat_types (AP_aux (apat_aux, { env; _ })) = let merge id b1 b2 = match (b1, b2) with | None, None -> None @@ -148,7 +148,7 @@ let rec apat_types (AP_aux (apat_aux, env, _)) = | AP_struct (afpats, _) -> List.fold_left (Bindings.merge merge) Bindings.empty (List.map (fun (_, apat) -> apat_types apat) afpats) -let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) = +let rec apat_rename from_id to_id (AP_aux (apat_aux, annot)) = let apat_aux = match apat_aux with | AP_tuple apats -> AP_tuple (List.map (apat_rename from_id to_id) apats) @@ -164,7 +164,7 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) = | AP_struct (afpats, typ) -> AP_struct (List.map (fun (field, apat) -> (field, apat_rename from_id to_id apat)) afpats, typ) in - AP_aux (apat_aux, env, l) + AP_aux (apat_aux, annot) let rec aval_typ = function | AV_lit (_, typ) -> typ @@ -251,9 +251,9 @@ let rec aexp_rename from_id to_id (AE_aux (aexp, annot)) = in AE_aux (aexp, annot) -and apexp_rename from_id to_id (apat, aexp1, aexp2) = - if IdSet.mem from_id (apat_bindings apat) then (apat, aexp1, aexp2) - else (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2) +and apexp_rename from_id to_id (apat, aexp1, aexp2, uannot) = + if IdSet.mem from_id (apat_bindings apat) then (apat, aexp1, aexp2, uannot) + else (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2, uannot) let rec fold_aexp f (AE_aux (aexp, annot)) = let aexp = @@ -269,17 +269,39 @@ let rec fold_aexp f (AE_aux (aexp, annot)) = | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) -> AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4) | AE_match (aval, cases, typ) -> - AE_match (aval, List.map (fun (pat, aexp1, aexp2) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2)) cases, typ) + AE_match + ( aval, + List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2, uannot)) cases, + typ + ) | AE_try (aexp, cases, typ) -> AE_try ( fold_aexp f aexp, - List.map (fun (pat, aexp1, aexp2) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2)) cases, + List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2, uannot)) cases, typ ) | (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v in f (AE_aux (aexp, annot)) +let rec is_pure_aexp effect_info (AE_aux (aexp, { uannot; _ })) = + match get_attribute "anf_pure" uannot with + | Some _ -> true + | None -> ( + match aexp with + | AE_app (f, _, _) -> Effects.function_is_pure f effect_info + | AE_typ (aexp, _) -> is_pure_aexp effect_info aexp + | AE_let (Immutable, _, _, aexp1, aexp2, _) -> is_pure_aexp effect_info aexp1 && is_pure_aexp effect_info aexp2 + | AE_match (_, arms, _) -> + List.for_all (fun (_, guard, aexp, _) -> is_pure_aexp effect_info guard && is_pure_aexp effect_info aexp) arms + | AE_short_circuit (_, _, aexp) -> is_pure_aexp effect_info aexp + | AE_if (_, then_aexp, else_aexp, _) -> is_pure_aexp effect_info then_aexp && is_pure_aexp effect_info else_aexp + | AE_val _ | AE_field _ -> true + | _ -> false + ) + +let is_pure_case effect_info (_, guard, body, _) = is_pure_aexp effect_info guard && is_pure_aexp effect_info body + let aexp_bindings aexp = let ids = ref IdSet.empty in let collect_lets = function @@ -336,14 +358,14 @@ let rec no_shadow ids (AE_aux (aexp, annot)) = in AE_aux (aexp, annot) -and no_shadow_apexp ids (apat, aexp1, aexp2) = +and no_shadow_apexp ids (apat, aexp1, aexp2, uannot) = let shadows = IdSet.inter (apat_bindings apat) ids in let shadows = List.map (fun id -> (id, new_shadow id)) (IdSet.elements shadows) in let rename aexp = List.fold_left (fun aexp (from_id, to_id) -> aexp_rename from_id to_id aexp) aexp shadows in let rename_apat apat = List.fold_left (fun apat (from_id, to_id) -> apat_rename from_id to_id apat) apat shadows in let ids = IdSet.union (apat_bindings apat) (IdSet.union ids (IdSet.of_list (List.map snd shadows))) in let new_guard = no_shadow ids (rename aexp1) in - (rename_apat apat, new_guard, no_shadow (IdSet.union ids (aexp_bindings new_guard)) (rename aexp2)) + (rename_apat apat, new_guard, no_shadow (IdSet.union ids (aexp_bindings new_guard)) (rename aexp2), uannot) (* Map over all the avals in an aexp. *) let rec map_aval f (AE_aux (aexp, annot)) = @@ -366,10 +388,16 @@ let rec map_aval f (AE_aux (aexp, annot)) = | AE_field (aval, field, typ) -> AE_field (f annot aval, field, typ) | AE_match (aval, cases, typ) -> AE_match - (f annot aval, List.map (fun (pat, aexp1, aexp2) -> (pat, map_aval f aexp1, map_aval f aexp2)) cases, typ) + ( f annot aval, + List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, map_aval f aexp1, map_aval f aexp2, uannot)) cases, + typ + ) | AE_try (aexp, cases, typ) -> AE_try - (map_aval f aexp, List.map (fun (pat, aexp1, aexp2) -> (pat, map_aval f aexp1, map_aval f aexp2)) cases, typ) + ( map_aval f aexp, + List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, map_aval f aexp1, map_aval f aexp2, uannot)) cases, + typ + ) | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, f annot aval, map_aval f aexp) in AE_aux (aexp, annot) @@ -391,11 +419,18 @@ let rec map_functions f (AE_aux (aexp, annot)) = AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4) | AE_match (aval, cases, typ) -> AE_match - (aval, List.map (fun (pat, aexp1, aexp2) -> (pat, map_functions f aexp1, map_functions f aexp2)) cases, typ) + ( aval, + List.map + (fun (pat, aexp1, aexp2, uannot) -> (pat, map_functions f aexp1, map_functions f aexp2, uannot)) + cases, + typ + ) | AE_try (aexp, cases, typ) -> AE_try ( map_functions f aexp, - List.map (fun (pat, aexp1, aexp2) -> (pat, map_functions f aexp1, map_functions f aexp2)) cases, + List.map + (fun (pat, aexp1, aexp2, uannot) -> (pat, map_functions f aexp1, map_functions f aexp2, uannot)) + cases, typ ) | (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v @@ -426,12 +461,14 @@ let rec pp_alexp = function | AL_addr (id, typ) -> string "*" ^^ parens (pp_annot typ (pp_id id)) | AL_field (alexp, field) -> pp_alexp alexp ^^ dot ^^ pp_id field -let rec pp_aexp (AE_aux (aexp, annot)) = - ( match get_attributes annot.uannot with +let pp_anf_uannot uannot = + match get_attributes uannot with | [] -> empty | attrs -> concat_map (fun (_, attr, arg) -> string (string_of_attribute attr arg |> Util.magenta |> Util.clear)) attrs - ) + +let rec pp_aexp (AE_aux (aexp, annot)) = + pp_anf_uannot annot.uannot ^^ match aexp with | AE_val v -> pp_aval v @@ -492,7 +529,9 @@ let rec pp_aexp (AE_aux (aexp, annot)) = (List.map (fun (id, aval) -> pp_id id ^^ string " = " ^^ pp_aval aval) (Bindings.bindings updates)) ) -and pp_apat (AP_aux (apat_aux, _, _)) = +and pp_apat (AP_aux (apat_aux, annot)) = + pp_anf_uannot annot.uannot + ^^ match apat_aux with | AP_wild _ -> string "_" | AP_id (id, typ) -> pp_annot typ (pp_id id) @@ -513,7 +552,8 @@ and pp_apat (AP_aux (apat_aux, _, _)) = and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace -and pp_case (apat, guard, body) = separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body] +and pp_case (apat, guard, body, uannot) = + pp_anf_uannot uannot ^^ parens (separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body]) and pp_block = function | [] -> string "()" @@ -553,8 +593,8 @@ let rec split_block l = function (exp :: exps, last) | [] -> Reporting.unreachable l __POS__ "empty block found when converting to ANF" [@coverage off] -let rec anf_pat ?(global = false) (P_aux (p_aux, annot) as pat) = - let mk_apat aux = AP_aux (aux, env_of_annot annot, fst annot) in +let rec anf_pat ?(global = false) (P_aux (p_aux, (l, tannot)) as pat) = + let mk_apat aux = AP_aux (aux, { loc = l; env = env_of_tannot tannot; uannot = untyped_annot tannot }) in match p_aux with | P_id id when global -> mk_apat (AP_global (id, typ_of_pat pat)) | P_id id -> mk_apat (AP_id (id, typ_of_pat pat)) @@ -575,11 +615,9 @@ let rec anf_pat ?(global = false) (P_aux (p_aux, annot) as pat) = | P_as (pat, id) -> mk_apat (AP_as (anf_pat ~global pat, id, typ_of_pat pat)) | P_struct (fpats, FP_no_wild) -> mk_apat (AP_struct (List.map (fun (field, pat) -> (field, anf_pat ~global pat)) fpats, typ_of_pat pat)) - | _ -> - Reporting.unreachable (fst annot) __POS__ - ("Could not convert pattern to ANF: " ^ string_of_pat pat) [@coverage off] + | _ -> Reporting.unreachable l __POS__ ("Could not convert pattern to ANF: " ^ string_of_pat pat) [@coverage off] -let rec apat_globals (AP_aux (aux, _, _)) = +let rec apat_globals (AP_aux (aux, _)) = match aux with | AP_nil _ | AP_wild _ | AP_id _ -> [] | AP_global (id, typ) -> [(id, typ)] @@ -737,22 +775,28 @@ let rec anf (E_aux (e_aux, (l, tannot)) as exp) = mk_aexp (AE_val (AV_ref (id, lvar))) | E_match (match_exp, pexps) -> let match_aval, match_wrap = to_aval (anf match_exp) in - let anf_pexp (Pat_aux (pat_aux, (l, _))) = + let anf_pexp (Pat_aux (pat_aux, (l, tannot))) = match pat_aux with - | Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body) + | Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body, untyped_annot tannot) | Pat_exp (pat, body) -> ( anf_pat pat, AE_aux (AE_val (AV_lit (mk_lit L_true, bool_typ)), { loc = l; env = env_of body; uannot = empty_uannot }), - anf body + anf body, + untyped_annot tannot ) in match_wrap (mk_aexp (AE_match (match_aval, List.map anf_pexp pexps, typ_of exp))) | E_try (match_exp, pexps) -> let match_aexp = anf match_exp in - let anf_pexp (Pat_aux (pat_aux, _)) = + let anf_pexp (Pat_aux (pat_aux, (l, tannot))) = match pat_aux with - | Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body) - | Pat_exp (pat, body) -> (anf_pat pat, mk_aexp (AE_val (AV_lit (mk_lit L_true, bool_typ))), anf body) + | Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body, untyped_annot tannot) + | Pat_exp (pat, body) -> + ( anf_pat pat, + AE_aux (AE_val (AV_lit (mk_lit L_true, bool_typ)), { loc = l; env = env_of body; uannot = empty_uannot }), + anf body, + untyped_annot tannot + ) in mk_aexp (AE_try (match_aexp, List.map anf_pexp pexps, typ_of exp)) | ( E_var (LE_aux (LE_id id, _), binding, body) diff --git a/src/lib/anf.mli b/src/lib/anf.mli index a9cdccda0..29228c3c5 100644 --- a/src/lib/anf.mli +++ b/src/lib/anf.mli @@ -82,9 +82,9 @@ open Type_check the original Sail expression, it's typing environment, and the uannot type containing any attributes attached to the original expression. *) -type aexp_annot = { loc : l; env : Env.t; uannot : uannot } +type anf_annot = { loc : l; env : Env.t; uannot : uannot } -type 'a aexp = AE_aux of 'a aexp_aux * aexp_annot +type 'a aexp = AE_aux of 'a aexp_aux * anf_annot and 'a aexp_aux = | AE_val of 'a aval @@ -98,8 +98,8 @@ and 'a aexp_aux = | AE_throw of 'a aval * 'a | AE_if of 'a aval * 'a aexp * 'a aexp * 'a | AE_field of 'a aval * id * 'a - | AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp) list * 'a - | AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp) list * 'a + | AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a + | AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a | AE_struct_update of 'a aval * 'a aval Bindings.t * 'a | AE_for of id * 'a aexp * 'a aexp * 'a aexp * order * 'a aexp | AE_loop of loop * 'a aexp * 'a aexp @@ -110,7 +110,7 @@ and 'a aexp_aux = and sc_op = SC_and | SC_or -and 'a apat = AP_aux of 'a apat_aux * Env.t * l +and 'a apat = AP_aux of 'a apat_aux * anf_annot and 'a apat_aux = | AP_tuple of 'a apat list @@ -152,10 +152,10 @@ val aval_typ : typ aval -> typ val aexp_typ : typ aexp -> typ (** Map over all values in an ANF expression *) -val map_aval : (aexp_annot -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp +val map_aval : (anf_annot -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp (** Map over all function calls in an ANF expression *) -val map_functions : (aexp_annot -> id -> 'a aval list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp +val map_functions : (anf_annot -> id -> 'a aval list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp (** This function 'folds' an [aexp] applying the provided function to all leaf subexpressions, then applying the function to their @@ -164,6 +164,10 @@ val fold_aexp : ('a aexp -> 'a aexp) -> 'a aexp -> 'a aexp val aexp_bindings : 'a aexp -> IdSet.t +val is_pure_aexp : Effects.side_effect_info -> 'a aexp -> bool + +val is_pure_case : Effects.side_effect_info -> 'a apat * 'a aexp * 'a aexp * uannot -> bool + (** Remove all variable shadowing in an ANF expression *) val no_shadow : IdSet.t -> 'a aexp -> 'a aexp diff --git a/src/lib/chunk_ast.ml b/src/lib/chunk_ast.ml index e6e374916..b17ef70e6 100644 --- a/src/lib/chunk_ast.ml +++ b/src/lib/chunk_ast.ml @@ -743,7 +743,8 @@ let flatten_block exps = (* Check if a sequence of cases in a match or try statement is aligned *) let is_aligned pexps = - let pexp_exp_column = function + let rec pexp_exp_column = function + | Pat_aux (Pat_attribute (_, _, pexp), _) -> pexp_exp_column pexp | Pat_aux (Pat_exp (_, E_aux (_, l)), _) -> starting_column_num l | Pat_aux (Pat_when (_, _, E_aux (_, l)), _) -> starting_column_num l in @@ -945,7 +946,7 @@ let rec chunk_exp comments chunks (E_aux (aux, l)) = let kind = match match_exp with E_match _ -> Match_match | _ -> Try_match in let exp_chunks = rec_chunk_exp exp in let aligned = is_aligned cases in - let cases = List.map (chunk_pexp ~delim:"," comments) cases in + let cases = List.map (chunk_pexp ~delim:"," comments chunks) cases in Match { kind; exp = exp_chunks; aligned; cases } |> add_chunk chunks | E_vector_update _ | E_vector_update_subrange _ -> let vec_chunks, updates = chunk_vector_update comments (E_aux (aux, l)) in @@ -1052,8 +1053,12 @@ and chunk_vector_update comments (E_aux (aux, l) as exp) = chunk_exp comments exp_chunks exp; (exp_chunks, []) -and chunk_pexp ?delim comments (Pat_aux (aux, l)) = +and chunk_pexp ?delim comments chunks (Pat_aux (aux, l)) = match aux with + | Pat_attribute (attr, arg, pexp) -> + Queue.add (Atom (Ast_util.string_of_attribute attr arg)) chunks; + Queue.add (Spacer (false, 1)) chunks; + chunk_pexp ?delim comments chunks pexp | Pat_exp (pat, exp) -> let funcl_space = match pat with P_aux (P_tuple _, _) -> false | _ -> true in let pat_chunks = Queue.create () in @@ -1087,7 +1092,7 @@ let chunk_funcl comments funcl = Queue.add (Spacer (false, 1)) chunks; chunk_funcl' comments funcl | FCL_doc (_, funcl) -> chunk_funcl' comments funcl - | FCL_funcl (_, pexp) -> chunk_pexp comments pexp + | FCL_funcl (_, pexp) -> chunk_pexp comments chunks pexp in (chunks, chunk_funcl' comments funcl) diff --git a/src/lib/initial_check.ml b/src/lib/initial_check.ml index 088053d53..030969dc6 100644 --- a/src/lib/initial_check.ml +++ b/src/lib/initial_check.ml @@ -707,9 +707,12 @@ module KindInference = struct let* pat = infer_pat ctx pat in wrap (P.P_attribute (attr, arg, pat)) - let infer_case ctx (P.Pat_aux (pexp, l)) = + let rec infer_case ctx (P.Pat_aux (pexp, l)) = let wrap aux = return (P.Pat_aux (aux, l)) in match pexp with + | P.Pat_attribute (attr, arg, pexp) -> + let* pexp = infer_case ctx pexp in + wrap (P.Pat_attribute (attr, arg, pexp)) | P.Pat_exp (pat, exp) -> let* pat = infer_pat ctx pat in wrap (P.Pat_exp (pat, exp)) @@ -1381,8 +1384,12 @@ and to_ast_lexp_vector_concat ctx (P.E_aux (exp_aux, l) as exp) = | P.E_vector_append (exp1, exp2) -> to_ast_lexp ctx exp1 :: to_ast_lexp_vector_concat ctx exp2 | _ -> [to_ast_lexp ctx exp] -and to_ast_case ctx (P.Pat_aux (pex, l) : P.pexp) : uannot pexp = - match pex with +and to_ast_case ctx (P.Pat_aux (pexp_aux, l) : P.pexp) : uannot pexp = + match pexp_aux with + | P.Pat_attribute (attr, arg, pexp) -> + let (Pat_aux (pexp, (pexp_l, annot))) = to_ast_case ctx pexp in + let annot = add_attribute l attr arg annot in + Pat_aux (pexp, (pexp_l, annot)) | P.Pat_exp (pat, exp) -> Pat_aux (Pat_exp (to_ast_pat ctx pat, to_ast_exp ctx exp), (l, empty_uannot)) | P.Pat_when (pat, guard, exp) -> Pat_aux (Pat_when (to_ast_pat ctx pat, to_ast_exp ctx guard, to_ast_exp ctx exp), (l, empty_uannot)) diff --git a/src/lib/jib_compile.ml b/src/lib/jib_compile.ml index 596a063bb..548717449 100644 --- a/src/lib/jib_compile.ml +++ b/src/lib/jib_compile.ml @@ -150,22 +150,6 @@ let ctx_get_extern id ctx = let ctx_has_val_spec id ctx = Bindings.mem id ctx.valspecs || Bindings.mem id (Env.get_val_specs ctx.tc_env) -let rec is_pure_aexp ctx (AE_aux (aexp, { uannot; _ })) = - match get_attribute "anf_pure" uannot with - | Some _ -> true - | None -> ( - match aexp with - | AE_app (f, _, _) -> Effects.function_is_pure f ctx.effect_info - | AE_let (Immutable, _, _, aexp1, aexp2, _) -> is_pure_aexp ctx aexp1 && is_pure_aexp ctx aexp2 - | AE_match (_, arms, _) -> - List.for_all (fun (_, guard, aexp) -> is_pure_aexp ctx guard && is_pure_aexp ctx aexp) arms - | AE_short_circuit (_, _, aexp) -> is_pure_aexp ctx aexp - | AE_val _ -> true - | _ -> false - ) - -let is_pure_case ctx (_, guard, body) = is_pure_aexp ctx guard && is_pure_aexp ctx body - let initial_ctx ?for_target env effect_info = let initial_valspecs = [ @@ -583,7 +567,14 @@ module Make (C : CONFIG) = struct [iclear (CT_list ctyp) gs] ) - let compile_funcall l ctx id args = + (** Compile a function call. + + If called as [compile_funcall ~override_id:foo l ctx bar args], then we will compile + as if we are calling [bar], but insert a call to [foo] in the IR. This is used for + optimizations where we can generate a more efficient version of [foo] that doesn't exist + in the original Sail. + *) + let compile_funcall ?override_id l ctx id args = let setup = ref [] in let cleanup = ref [] in @@ -594,7 +585,7 @@ module Make (C : CONFIG) = struct try Env.get_val_spec id ctx.local_env with Type_error.Type_error _ -> Env.get_val_spec id ctx.tc_env in let arg_typs, ret_typ = match fn_typ with Typ_fn (arg_typs, ret_typ) -> (arg_typs, ret_typ) | _ -> assert false in - let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.tc_env } in + let ctx' = { ctx with local_env = Env.add_typquant (id_loc id) quant ctx.local_env } in let arg_ctyps, ret_ctyp = (List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ) in assert (List.length arg_ctyps = List.length args); @@ -611,19 +602,21 @@ module Make (C : CONFIG) = struct let setup_args = List.map2 setup_arg arg_ctyps args in + let call_id = Option.value ~default:id override_id in + ( List.rev !setup, begin fun clexp -> let instantiation = KBindings.union merge_unifiers (ctyp_unify l ret_ctyp (clexp_ctyp clexp)) !instantiation in - ifuncall l clexp (id, KBindings.bindings instantiation |> List.map snd) setup_args + ifuncall l clexp (call_id, KBindings.bindings instantiation |> List.map snd) setup_args (* iblock1 (optimize_call l ctx clexp (id, KBindings.bindings unifiers |> List.map snd) setup_args arg_ctyps ret_ctyp) *) end, !cleanup ) - let rec apat_ctyp ctx (AP_aux (apat, env, _)) = + let rec apat_ctyp ctx (AP_aux (apat, { env; _ })) = let ctx = { ctx with local_env = env } in match apat with | AP_tuple apats -> CT_tup (List.map (apat_ctyp ctx) apats) @@ -634,7 +627,7 @@ module Make (C : CONFIG) = struct | AP_as (_, _, typ) -> ctyp_of_typ ctx typ | AP_struct (_, typ) -> ctyp_of_typ ctx typ - let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval on_failure = + let rec compile_match ctx (AP_aux (apat_aux, { env; loc = l; _ })) cval on_failure = let ctx = { ctx with local_env = env } in let ctyp = cval_ctyp cval in match apat_aux with @@ -753,6 +746,23 @@ module Make (C : CONFIG) = struct | Some def_annot -> Option.is_some (get_def_attribute "optimize_control_flow_order" def_annot) | None -> false + (** Returns true if we have an infalliable mapping case. This occurs + only if the final case is marked with $[mapping_last] by the + mappings.ml rewrite, and we have a $[mapping_infallible] + attribute attached to the containing function in the context. *) + let has_infallible_mapping_case ctx = function + | [] -> true + | cases -> + let in_infallible_mapping = + match ctx.def_annot with + | None -> false + | Some def_annot -> Option.is_some (get_def_attribute "mapping_infallible" def_annot) + in + in_infallible_mapping + && + let _, _, _, uannot = Util.last cases in + Option.is_some (get_attribute "mapping_last" uannot) + let rec compile_aexp ctx (AE_aux (aexp_aux, { env; loc = l; uannot })) = let ctx = { ctx with local_env = env } in match aexp_aux with @@ -767,16 +777,24 @@ module Make (C : CONFIG) = struct let ctx = { ctx with locals = Bindings.add id (mut, binding_ctyp) ctx.locals } in let setup, call, cleanup = compile_aexp ctx body in (letb_setup @ setup, call, cleanup @ letb_cleanup) - | AE_app (id, vs, _) -> compile_funcall l ctx id vs + | AE_app (id, vs, _) -> + if Option.is_some (get_attribute "mapping_guarded" uannot) then ( + let override_id = append_id id "_infallible" in + if Bindings.mem override_id ctx.valspecs then compile_funcall ~override_id l ctx id vs + else compile_funcall l ctx id vs + ) + else compile_funcall l ctx id vs | AE_val aval -> let setup, cval, cleanup = compile_aval l ctx aval in (setup, (fun clexp -> icopy l clexp cval), cleanup) (* Compile case statements *) - | AE_match (aval, cases, typ) when C.eager_control_flow && can_optimize_control_flow_order ctx -> + | AE_match (aval, cases, typ) + when C.eager_control_flow + && (can_optimize_control_flow_order ctx || Option.is_some (get_attribute "anf_pure" uannot)) -> let ctyp = ctyp_of_typ ctx typ in let aval_setup, cval, aval_cleanup = compile_aval l ctx aval in - let compile_case case_match_id case_return_id (apat, guard, body) = - if is_dead_aexp body then [] + let compile_case case_match_id case_return_id (apat, guard, body, case_uannot) = + if is_dead_aexp body then None else ( let trivial_guard = match guard with @@ -790,28 +808,37 @@ module Make (C : CONFIG) = struct in let guard_setup, guard_call, guard_cleanup = compile_aexp ctx guard in let body_setup, body_call, body_cleanup = compile_aexp ctx body in - [idecl l ctyp case_return_id] - @ pre_destructure @ destructure - @ ( if not trivial_guard then - [idecl l CT_bool case_match_id] - @ guard_setup - @ [guard_call (CL_id (case_match_id, CT_bool))] - @ guard_cleanup - else [iinit l CT_bool case_match_id (V_lit (VL_bool true, CT_bool))] + Some + ([idecl l ctyp case_return_id; iinit l CT_bool case_match_id (V_lit (VL_bool true, CT_bool))] + @ pre_destructure @ destructure + @ ( if not trivial_guard then ( + let gs = ngensym () in + guard_setup + @ [ + idecl l CT_bool gs; + guard_call (CL_id (gs, CT_bool)); + icopy l + (CL_id (case_match_id, CT_bool)) + (V_call (Band, [V_id (case_match_id, CT_bool); V_id (gs, CT_bool)])); + ] + @ guard_cleanup + ) + else [] + ) + @ body_setup + @ [body_call (CL_id (case_return_id, ctyp))] + @ body_cleanup @ destructure_cleanup ) - @ body_setup - @ [body_call (CL_id (case_return_id, ctyp))] - @ body_cleanup @ destructure_cleanup ) in let case_ids, cases = - List.map + List.filter_map (fun case -> + let open Util.Option_monad in let case_match_id = ngensym () in let case_return_id = ngensym () in - ( (V_id (case_match_id, CT_bool), V_id (case_return_id, ctyp)), - compile_case case_match_id case_return_id case - ) + let* case = compile_case case_match_id case_return_id case in + Some ((V_id (case_match_id, CT_bool), V_id (case_return_id, ctyp)), case) ) cases |> List.split @@ -823,7 +850,7 @@ module Make (C : CONFIG) = struct in (aval_setup @ List.concat cases, (fun clexp -> icopy l clexp (build_ite case_ids)), aval_cleanup) | AE_match (aval, cases, typ) -> - let is_complete = Option.is_some (get_attribute "complete" uannot) in + let is_complete = Option.is_some (get_attribute "complete" uannot) || has_infallible_mapping_case ctx cases in let ctx = update_coverage_override uannot ctx in let ctyp = ctyp_of_typ ctx typ in let aval_setup, cval, aval_cleanup = compile_aval l ctx aval in @@ -833,7 +860,7 @@ module Make (C : CONFIG) = struct let branch_id, on_reached = if num_cases > 1 then coverage_branch_reached ctx l else (0, []) in let case_return_id = ngensym () in let finish_match_label = label "finish_match_" in - let compile_case is_last (apat, guard, body) = + let compile_case is_last (apat, guard, body, case_uannot) = let case_label = label "case_" in if is_dead_aexp body then [ilabel case_label] else ( @@ -890,7 +917,7 @@ module Make (C : CONFIG) = struct let aexp_setup, aexp_call, aexp_cleanup = compile_aexp ctx aexp in let try_return_id = ngensym () in let post_exception_handlers_label = label "post_exception_handlers_" in - let compile_case (apat, guard, body) = + let compile_case (apat, guard, body, case_uannot) = let trivial_guard = match guard with | AE_aux (AE_val (AV_lit (L_aux (L_true, _), _)), _) @@ -946,36 +973,37 @@ module Make (C : CONFIG) = struct let if_ctyp = ctyp_of_typ ctx if_typ in let setup, cval, cleanup = compile_aval l ctx aval in let pure_attr = get_attribute "anf_pure" uannot in - let eager = C.eager_control_flow && can_optimize_control_flow_order ctx in - match (pure_attr, eager) with - | Some _, _ | _, true -> - let then_gs = ngensym () in - let then_setup, then_call, then_cleanup = compile_aexp ctx then_aexp in - let else_gs = ngensym () in - let else_setup, else_call, else_cleanup = compile_aexp ctx else_aexp in - ( setup @ then_setup @ else_setup - @ [ - idecl l if_ctyp then_gs; - idecl l if_ctyp else_gs; - then_call (CL_id (then_gs, if_ctyp)); - else_call (CL_id (else_gs, if_ctyp)); - ], - (fun clexp -> icopy l clexp (V_call (Ite, [cval; V_id (then_gs, if_ctyp); V_id (else_gs, if_ctyp)]))), - [iclear if_ctyp else_gs; iclear if_ctyp then_gs] @ else_cleanup @ then_cleanup @ cleanup - ) - | _ -> - let branch_id, on_reached = coverage_branch_reached ctx l in - let compile_branch aexp = - let setup, call, cleanup = compile_aexp ctx aexp in - fun clexp -> coverage_branch_target_taken ctx branch_id aexp @ setup @ [call clexp] @ cleanup - in - ( setup, - (fun clexp -> - append_into_block on_reached - (iif l cval (compile_branch then_aexp clexp) (compile_branch else_aexp clexp) if_ctyp) - ), - cleanup - ) + let eager = C.eager_control_flow && (can_optimize_control_flow_order ctx || Option.is_some pure_attr) in + if eager then ( + let then_gs = ngensym () in + let then_setup, then_call, then_cleanup = compile_aexp ctx then_aexp in + let else_gs = ngensym () in + let else_setup, else_call, else_cleanup = compile_aexp ctx else_aexp in + ( setup @ then_setup @ else_setup + @ [ + idecl l if_ctyp then_gs; + idecl l if_ctyp else_gs; + then_call (CL_id (then_gs, if_ctyp)); + else_call (CL_id (else_gs, if_ctyp)); + ], + (fun clexp -> icopy l clexp (V_call (Ite, [cval; V_id (then_gs, if_ctyp); V_id (else_gs, if_ctyp)]))), + [iclear if_ctyp else_gs; iclear if_ctyp then_gs] @ else_cleanup @ then_cleanup @ cleanup + ) + ) + else ( + let branch_id, on_reached = coverage_branch_reached ctx l in + let compile_branch aexp = + let setup, call, cleanup = compile_aexp ctx aexp in + fun clexp -> coverage_branch_target_taken ctx branch_id aexp @ setup @ [call clexp] @ cleanup + in + ( setup, + (fun clexp -> + append_into_block on_reached + (iif l cval (compile_branch then_aexp clexp) (compile_branch else_aexp clexp) if_ctyp) + ), + cleanup + ) + ) ) (* FIXME: AE_struct_update could be AV_record_update - would reduce some copying. *) | AE_struct_update (aval, fields, typ) -> @@ -1562,9 +1590,11 @@ module Make (C : CONFIG) = struct let compile_funcl ctx def_annot id pat guard exp = let debug_attr = get_def_attribute "jib_debug" def_annot in + let mapping_function_attr = get_def_attribute "mapping_function" def_annot in if Option.is_some debug_attr then ( - prerr_endline Util.("Rewritten source for " ^ string_of_id id ^ ":" |> yellow |> bold |> clear); + let extra = if Option.is_some mapping_function_attr then " (mapping)" else "" in + prerr_endline Util.("Rewritten source for " ^ string_of_id id ^ extra ^ ":" |> yellow |> bold |> clear); prerr_endline (Document.to_string (Pretty_print_sail.doc_exp (Type_check.strip_exp exp))) ); @@ -1626,27 +1656,50 @@ module Make (C : CONFIG) = struct prerr_endline (Document.to_string (pp_aexp aexp)) ); - let setup, call, cleanup = compile_aexp ctx aexp in - let destructure, destructure_cleanup = - compiled_args |> List.map snd |> combine_destructure_cleanup |> fix_destructure (id_loc id) fundef_label - in + let compile_body ctx = + let setup, call, cleanup = compile_aexp ctx aexp in + let destructure, destructure_cleanup = + compiled_args |> List.map snd |> combine_destructure_cleanup |> fix_destructure (id_loc id) fundef_label + in - let instrs = - arg_setup @ destructure @ guard_instrs @ setup - @ [call (CL_id (return, ret_ctyp))] - @ cleanup @ destructure_cleanup @ arg_cleanup + let instrs = + arg_setup @ destructure @ guard_instrs @ setup + @ [call (CL_id (return, ret_ctyp))] + @ cleanup @ destructure_cleanup @ arg_cleanup + in + let instrs = fix_early_return (exp_loc exp) (CL_id (return, ret_ctyp)) instrs in + let instrs = unique_names instrs in + let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in + coverage_function_entry ctx id (exp_loc exp) @ instrs in - let instrs = fix_early_return (exp_loc exp) (CL_id (return, ret_ctyp)) instrs in - let instrs = unique_names instrs in - let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in - let instrs = coverage_function_entry ctx id (exp_loc exp) @ instrs in + + let compiled_args = List.map fst compiled_args in + let instrs = compile_body ctx in if Option.is_some debug_attr then ( prerr_endline Util.("IR for " ^ string_of_id id ^ ":" |> yellow |> bold |> clear); List.iter (fun instr -> prerr_endline (string_of_instr instr)) instrs ); - ([CDEF_aux (CDEF_fundef (id, None, List.map fst compiled_args, instrs), def_annot)], orig_ctx) + (* If the function is a mapping, we generate an infallible version (that never causes a match_failure) *) + let mapping_infallible, return_ctx = + match mapping_function_attr with + | Some (attr_l, _) -> + let instrs = + compile_body + { ctx with def_annot = Some (add_def_attribute (gen_loc attr_l) "mapping_infallible" None def_annot) } + in + let id = append_id id "_infallible" in + ( [ + CDEF_aux (CDEF_val (id, None, arg_ctyps, ret_ctyp), def_annot); + CDEF_aux (CDEF_fundef (id, None, compiled_args, instrs), def_annot); + ], + { orig_ctx with valspecs = Bindings.add id (None, arg_ctyps, ret_ctyp, empty_uannot) orig_ctx.valspecs } + ) + | None -> ([], orig_ctx) + in + + ([CDEF_aux (CDEF_fundef (id, None, compiled_args, instrs), def_annot)] @ mapping_infallible, return_ctx) (** Compile a Sail toplevel definition into an IR definition **) let rec compile_def n total ctx (DEF_aux (aux, _) as def) = diff --git a/src/lib/jib_compile.mli b/src/lib/jib_compile.mli index 1f277886c..4459c11e2 100644 --- a/src/lib/jib_compile.mli +++ b/src/lib/jib_compile.mli @@ -96,8 +96,6 @@ val ctx_get_extern : id -> ctx -> string val ctx_has_val_spec : id -> ctx -> bool -val is_pure_aexp : ctx -> 'a aexp -> bool - (** Create an inital Jib compilation context. The target is the name that would appear in a valspec extern section, i.e. diff --git a/src/lib/mappings.ml b/src/lib/mappings.ml index d73af354f..de860a174 100644 --- a/src/lib/mappings.ml +++ b/src/lib/mappings.ml @@ -62,7 +62,7 @@ let y = x in $[mapping_match] $[complete] match ($[complete] match y { => Some(), - z if mapping_forwards_matches(z) => $[complete] match mapping_forwards(z) { + z if mapping_forwards_matches(z) => $[complete] match $[mapping_guarded] mapping_forwards(z) { A => Some(), _ => None(), }, @@ -96,7 +96,7 @@ let y = x in { $[complete] match y { => return , - z if mapping_forwards_matches(z) => $[complete] match mapping_forwards(z) { + z if mapping_forwards_matches(z) => $[complete] match $[mapping_guarded] mapping_forwards(z) { A => return , _ => (), }, @@ -281,8 +281,9 @@ let rec mappings_match ~terminal ~return_position is_mapping subst mappings pexp let direction = mapping_direction l uannot in let mapping_fun_id = mapping_function mapping direction in let mapping_guard_id = mapping_guard mapping direction in - ( mk_exp (E_app (mapping_fun_id, [mk_exp (E_id subst_id)])), - mk_exp (E_app (mapping_guard_id, [mk_exp (E_id subst_id)])), + let guarded_attr = add_attribute (gen_loc l) "mapping_guarded" None empty_uannot in + ( E_aux (E_app (mapping_fun_id, [mk_exp (E_id subst_id)]), (gen_loc l, guarded_attr)), + mk_exp ~loc:(gen_loc l) (E_app (mapping_guard_id, [mk_exp (E_id subst_id)])), subpat ) | _ -> Reporting.unreachable l __POS__ "Non-mapping in mappings_match" diff --git a/src/lib/parse_ast.ml b/src/lib/parse_ast.ml index bd74d7e2f..4765437fb 100644 --- a/src/lib/parse_ast.ml +++ b/src/lib/parse_ast.ml @@ -257,8 +257,11 @@ and opt_default_aux = and opt_default = Def_val_aux of opt_default_aux * l -and pexp_aux = (* Pattern match *) - | Pat_exp of pat * exp | Pat_when of pat * exp * exp +and pexp_aux = + (* Pattern match *) + | Pat_exp of pat * exp + | Pat_when of pat * exp * exp + | Pat_attribute of string * attribute_data option * pexp and pexp = Pat_aux of pexp_aux * l diff --git a/src/lib/parser.mly b/src/lib/parser.mly index 96a9dde21..6e52901d8 100644 --- a/src/lib/parser.mly +++ b/src/lib/parser.mly @@ -716,10 +716,12 @@ exp0: $endpos) } case: - | pat EqGt exp - { mk_pexp (Pat_exp ($1, $3)) $startpos $endpos } - | pat If_ exp EqGt exp - { mk_pexp (Pat_when ($1, $3, $5)) $startpos $endpos } + | p = pat; EqGt; body = exp + { mk_pexp (Pat_exp (p, body)) $startpos $endpos } + | p = pat; If_; guard = exp; EqGt; body = exp + { mk_pexp (Pat_when (p, guard, body)) $startpos $endpos } + | a = attribute; Lparen; c = case; Rparen + { mk_pexp (Pat_attribute (fst a, snd a, c)) $startpos $endpos(a) } case_list: | case diff --git a/src/lib/pretty_print_sail.ml b/src/lib/pretty_print_sail.ml index 88c4fe2ad..83fabe14d 100644 --- a/src/lib/pretty_print_sail.ml +++ b/src/lib/pretty_print_sail.ml @@ -595,10 +595,18 @@ module Printer (Config : PRINT_CONFIG) = struct and doc_pexps pexps = surround 2 0 lbrace (separate_map (comma ^^ hardline) doc_pexp pexps) rbrace - and doc_pexp (Pat_aux (pat_aux, _)) = - match pat_aux with - | Pat_exp (pat, exp) -> separate space [doc_pat pat; string "=>"; doc_exp exp] - | Pat_when (pat, wh, exp) -> separate space [doc_pat pat; string "if"; doc_exp wh; string "=>"; doc_exp exp] + and doc_pexp (Pat_aux (pat_aux, (_, uannot))) = + let wrap, attrs_doc = + match get_attributes uannot with + | [] -> ((fun x -> x), empty) + | _ -> (parens, concat_map (fun (_, attr, arg) -> doc_attr attr arg) (get_attributes uannot)) + in + let pexp_doc = + match pat_aux with + | Pat_exp (pat, exp) -> separate space [doc_pat pat; string "=>"; doc_exp exp] + | Pat_when (pat, wh, exp) -> separate space [doc_pat pat; string "if"; doc_exp wh; string "=>"; doc_exp exp] + in + attrs_doc ^^ wrap pexp_doc and doc_letbind (LB_aux (lb_aux, _)) = match lb_aux with LB_val (pat, exp) -> separate space [doc_pat pat; equals; doc_exp exp] diff --git a/src/lib/rewrites.ml b/src/lib/rewrites.ml index 89864d95c..3a08a0038 100644 --- a/src/lib/rewrites.ml +++ b/src/lib/rewrites.ml @@ -3223,11 +3223,15 @@ let rewrite_ast_realize_mappings effect_info env ast = | Pat_aux (Pat_exp (pat, _), annot) -> Pat_aux (Pat_exp (pat, mk_lit_exp L_true), annot) | Pat_aux (Pat_when (pat, guard, _), annot) -> Pat_aux (Pat_when (pat, guard, mk_lit_exp L_true), annot) in - let realize_mapcl forwards id mapcl = + let annotate_pat ~last = function + | Pat_aux (pexp_aux, (l, uannot)) when last -> Pat_aux (pexp_aux, (l, add_attribute l "mapping_last" None uannot)) + | pexp -> pexp + in + let realize_mapcl ~last ~forwards id mapcl = match mapcl with - | MCL_aux (MCL_bidir (mpexp1, mpexp2), _) -> [realize_mpexps forwards mpexp1 mpexp2] - | MCL_aux (MCL_forwards pexp, _) -> if forwards then [pexp] else [] - | MCL_aux (MCL_backwards pexp, _) -> if forwards then [] else [pexp] + | MCL_aux (MCL_bidir (mpexp1, mpexp2), _) -> [annotate_pat ~last (realize_mpexps forwards mpexp1 mpexp2)] + | MCL_aux (MCL_forwards pexp, _) -> if forwards then [annotate_pat ~last pexp] else [] + | MCL_aux (MCL_backwards pexp, _) -> if forwards then [] else [annotate_pat ~last pexp] in let realize_bool_mapcl forwards id mapcl = match mapcl with @@ -3305,14 +3309,22 @@ let rewrite_ast_realize_mappings effect_info env ast = let forwards_match = mk_exp (E_match - (arg_exp, List.map (fun mapcl -> strip_mapcl mapcl |> realize_mapcl true forwards_id) mapcls |> List.flatten) + ( arg_exp, + Util.map_last + (fun last mapcl -> strip_mapcl mapcl |> realize_mapcl ~last ~forwards:true forwards_id) + mapcls + |> List.flatten + ) ) in let backwards_match = mk_exp (E_match ( arg_exp, - List.map (fun mapcl -> strip_mapcl mapcl |> realize_mapcl false backwards_id) mapcls |> List.flatten + Util.map_last + (fun last mapcl -> strip_mapcl mapcl |> realize_mapcl ~last ~forwards:false backwards_id) + mapcls + |> List.flatten ) ) in @@ -3360,8 +3372,10 @@ let rewrite_ast_realize_mappings effect_info env ast = ) in - let forwards_fun, _ = Type_check.check_fundef env def_annot forwards_fun in - let backwards_fun, _ = Type_check.check_fundef env def_annot backwards_fun in + let fun_def_annot = add_def_attribute (gen_loc l) "mapping_function" None def_annot in + + let forwards_fun, _ = Type_check.check_fundef env fun_def_annot forwards_fun in + let backwards_fun, _ = Type_check.check_fundef env fun_def_annot backwards_fun in let forwards_matches_fun, _ = Type_check.check_fundef env def_annot forwards_matches_fun in let backwards_matches_fun, _ = Type_check.check_fundef env def_annot backwards_matches_fun in diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index d9339b4a7..d36c3cec9 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -2545,7 +2545,7 @@ and check_block l env exps ret_typ = texp :: check_block l env exps ret_typ and check_case env pat_typ pexp typ = - let pat, guard, case, ((l, _) as annot) = destruct_pexp pexp in + let pat, guard, case, (l, uannot) = destruct_pexp pexp in ignore (check_pattern_duplicates env pat); let env = bind_pattern_vector_subranges pat env in match bind_pat env pat pat_typ with @@ -2566,7 +2566,7 @@ and check_case env pat_typ pexp typ = (Some checked_guard, add_opt_constraint l "guard pattern" (assert_constraint env true checked_guard) env) in let checked_case = crule check_exp env' case typ in - construct_pexp (tpat, checked_guard, checked_case, (l, empty_tannot)) + construct_pexp (tpat, checked_guard, checked_case, (l, (None, uannot))) (* AA: Not sure if we still need this *) | exception (Type_error _ as typ_exn) -> ( match pat with @@ -2575,7 +2575,7 @@ and check_case env pat_typ pexp typ = let guard = match guard with None -> guard' | Some guard -> mk_exp (E_app_infix (guard, mk_id "&", guard')) in - check_case env pat_typ (Pat_aux (Pat_when (mk_pat ~loc:l (P_id (mk_id "p#")), guard, case), annot)) typ + check_case env pat_typ (Pat_aux (Pat_when (mk_pat ~loc:l (P_id (mk_id "p#")), guard, case), (l, uannot))) typ | _ -> raise typ_exn ) diff --git a/src/sail_c_backend/c_backend.ml b/src/sail_c_backend/c_backend.ml index a214c90aa..1adfa7ab5 100644 --- a/src/sail_c_backend/c_backend.ml +++ b/src/sail_c_backend/c_backend.ml @@ -436,7 +436,7 @@ end) : CONFIG = struct let aexp4 = analyze_functions ctx f aexp4 in AE_for (id, aexp1, aexp2, aexp3, order, aexp4) | AE_match (aval, cases, typ) -> - let analyze_case ((AP_aux (_, env, _) as pat), aexp1, aexp2) = + let analyze_case ((AP_aux (_, { env; _ }) as pat), aexp1, aexp2, uannot) = let pat_bindings = Bindings.bindings (apat_types pat) in let ctx = { ctx with local_env = env } in let ctx = @@ -444,14 +444,16 @@ end) : CONFIG = struct (fun ctx (id, typ) -> { ctx with locals = Bindings.add id (Immutable, convert_typ ctx typ) ctx.locals }) ctx pat_bindings in - (pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2) + (pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2, uannot) in AE_match (aval, List.map analyze_case cases, typ) | AE_try (aexp, cases, typ) -> AE_try ( analyze_functions ctx f aexp, List.map - (fun (pat, aexp1, aexp2) -> (pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2)) + (fun (pat, aexp1, aexp2, uannot) -> + (pat, analyze_functions ctx f aexp1, analyze_functions ctx f aexp2, uannot) + ) cases, typ ) @@ -1083,6 +1085,7 @@ and sgen_call op cvals = | _ -> assert false end | Get_abstract, [v] -> sgen_cval v + | Ite, [i; t; e] -> sprintf "(%s ? %s : %s)" (sgen_cval i) (sgen_cval t) (sgen_cval e) | _, _ -> failwith "Could not generate cval primop" let sgen_cval_param cval = diff --git a/src/sail_smt_backend/jib_smt.ml b/src/sail_smt_backend/jib_smt.ml index 915e98901..c037516fd 100644 --- a/src/sail_smt_backend/jib_smt.ml +++ b/src/sail_smt_backend/jib_smt.ml @@ -1250,7 +1250,7 @@ end) : Jib_compile.CONFIG = struct let aexp1 = analyze ctx aexp1 in let aexp2 = analyze ctx aexp2 in let annot = - if is_pure_aexp ctx aexp1 && is_pure_aexp ctx aexp2 then + if is_pure_aexp ctx.effect_info aexp1 && is_pure_aexp ctx.effect_info aexp2 then { annot with uannot = add_attribute (gen_loc loc) "anf_pure" None uannot } else annot in @@ -1265,7 +1265,7 @@ end) : Jib_compile.CONFIG = struct let aexp4 = analyze ctx aexp4 in (AE_for (id, aexp1, aexp2, aexp3, order, aexp4), annot) | AE_match (aval, cases, typ) -> - let analyze_case ((AP_aux (_, env, _) as pat), aexp1, aexp2) = + let analyze_case ((AP_aux (_, { env; _ }) as pat), aexp1, aexp2, uannot) = let pat_bindings = Bindings.bindings (apat_types pat) in let ctx = { ctx with local_env = env } in let ctx = @@ -1273,13 +1273,19 @@ end) : Jib_compile.CONFIG = struct (fun ctx (id, typ) -> { ctx with locals = Bindings.add id (Immutable, convert_typ ctx typ) ctx.locals }) ctx pat_bindings in - (pat, analyze ctx aexp1, analyze ctx aexp2) + let uannot = + match get_attribute "complete" uannot with + | Some (l, _) when List.for_all (is_pure_case ctx.effect_info) cases -> + add_attribute (gen_loc l) "anf_pure" None uannot + | _ -> uannot + in + (pat, analyze ctx aexp1, analyze ctx aexp2, uannot) in (AE_match (aval, List.map analyze_case cases, typ), annot) | AE_try (aexp, cases, typ) -> ( AE_try ( analyze ctx aexp, - List.map (fun (pat, aexp1, aexp2) -> (pat, analyze ctx aexp1, analyze ctx aexp2)) cases, + List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, analyze ctx aexp1, analyze ctx aexp2, uannot)) cases, typ ), annot diff --git a/src/sail_sv_backend/jib_sv.ml b/src/sail_sv_backend/jib_sv.ml index a083ad396..8ca853a00 100644 --- a/src/sail_sv_backend/jib_sv.ml +++ b/src/sail_sv_backend/jib_sv.ml @@ -152,6 +152,7 @@ type direct_footprint = { mutable writes_mem : bool; mutable contains_assert : bool; mutable references : CTSet.t; + mutable exits : bool; } let empty_direct_footprint () : direct_footprint = @@ -165,6 +166,7 @@ let empty_direct_footprint () : direct_footprint = writes_mem = false; contains_assert = false; references = CTSet.empty; + exits = false; } let get_bool_attribute name attr_object = @@ -204,6 +206,9 @@ class footprint_visitor ctx registers (footprint : direct_footprint) : jib_visit method! vinstr = function + | I_aux (I_exit _, _) -> + footprint.exits <- true; + SkipChildren | I_aux (I_funcall (_, _, (id, _), args), (l, _)) -> let open Util.Option_monad in if ctx_is_extern id ctx then ( @@ -265,6 +270,7 @@ type footprint = { reads_mem : bool; writes_mem : bool; contains_assert : bool; + exits : bool; } let pure_footprint = @@ -280,6 +286,7 @@ let pure_footprint = reads_mem = false; writes_mem = false; contains_assert = false; + exits = false; } type spec_info = { @@ -370,6 +377,7 @@ let collect_spec_info ctx cdefs = reads_mem = direct_footprint.reads_mem; writes_mem = direct_footprint.writes_mem; contains_assert = direct_footprint.contains_assert; + exits = direct_footprint.exits; } footprints | _ -> footprints @@ -384,10 +392,18 @@ let collect_spec_info ctx cdefs = | CDEF_aux (CDEF_fundef (f, _, _, body), _) -> let footprint = Bindings.find f footprints in let callees = cfg |> IdGraph.reachable (IdSet.singleton f) IdSet.empty |> IdSet.remove f in - let all_reads, all_writes, throws, need_stdout, need_stderr, reads_mem, writes_mem, contains_assert = + let all_reads, all_writes, throws, need_stdout, need_stderr, reads_mem, writes_mem, contains_assert, exits = List.fold_left - (fun (all_reads, all_writes, throws, need_stdout, need_stderr, reads_mem, writes_mem, contains_assert) - callee -> + (fun ( all_reads, + all_writes, + throws, + need_stdout, + need_stderr, + reads_mem, + writes_mem, + contains_assert, + exits + ) callee -> match Bindings.find_opt callee footprints with | Some footprint -> ( IdSet.union all_reads footprint.direct_reads, @@ -397,10 +413,20 @@ let collect_spec_info ctx cdefs = need_stderr || footprint.need_stderr, reads_mem || footprint.reads_mem, writes_mem || footprint.writes_mem, - contains_assert || footprint.contains_assert + contains_assert || footprint.contains_assert, + exits || footprint.exits ) | _ -> - (all_reads, all_writes, throws, need_stdout, need_stderr, reads_mem, writes_mem, contains_assert) + ( all_reads, + all_writes, + throws, + need_stdout, + need_stderr, + reads_mem, + writes_mem, + contains_assert, + exits + ) ) ( footprint.direct_reads, footprint.direct_writes, @@ -409,7 +435,8 @@ let collect_spec_info ctx cdefs = footprint.need_stderr, footprint.reads_mem, footprint.writes_mem, - footprint.contains_assert + footprint.contains_assert, + footprint.exits ) (IdSet.elements callees) in @@ -426,6 +453,7 @@ let collect_spec_info ctx cdefs = reads_mem; writes_mem; contains_assert; + exits; } ) footprints @@ -1449,7 +1477,7 @@ module Make (Config : CONFIG) = struct (natural_sort_ids (IdSet.elements footprint.all_writes)) in let throws = - if footprint.throws then + if footprint.throws || footprint.exits then [CL_id (Have_exception (-1), CT_bool); CL_id (Current_exception (-1), spec_info.exception_ctyp)] else [] in @@ -1927,7 +1955,7 @@ module Make (Config : CONFIG) = struct } ) (natural_sort_ids (IdSet.elements footprint.all_writes)) - @ ( if footprint.throws then + @ ( if footprint.throws || footprint.exits then [ { name = get_final_name (Have_exception (-1)); external_name = "have_exception"; typ = CT_bool }; { @@ -2032,7 +2060,7 @@ module Make (Config : CONFIG) = struct ] in let throws_outputs = - if footprint.throws then + if footprint.throws || footprint.exits then [SVD_var (Have_exception (-1), CT_bool); SVD_var (Current_exception (-1), spec_info.exception_ctyp)] else [] in @@ -2060,7 +2088,10 @@ module Make (Config : CONFIG) = struct @ List.map (fun reg -> SVP_id (Name (prepend_id "out_" reg, -1))) (natural_sort_ids (IdSet.elements footprint.all_writes)) - @ (if footprint.throws then [SVP_id (Have_exception (-1)); SVP_id (Current_exception (-1))] else []) + @ ( if footprint.throws || footprint.exits then + [SVP_id (Have_exception (-1)); SVP_id (Current_exception (-1))] + else [] + ) @ (if footprint.need_stdout then [SVP_id (Name (mk_id "out_stdout", -1))] else []) @ (if footprint.need_stderr then [SVP_id (Name (mk_id "out_stderr", -1))] else []) @ if footprint.writes_mem then [SVP_id (Name (mk_id "out_memory_writes", -1))] else []