Skip to content

Commit

Permalink
SV: Further annotations to optimize mappings
Browse files Browse the repository at this point in the history
This introduces 3 annotations that work together to help the Jib
compilation process optimize mappings

- Whenever we elaborate a mapping 'X' we always generate a
  X_backwards (or forwards) call under a X_backwards_matches
  guard. The rewrite notes this fact by placing a $[mapping_guarded]
  attribute on the call to X_backwards

- The last clause of a generated mapping is annotated with a $[mapping_last]
  attribute. This required extending the AST (and ANF) with attribute support
  here.

- Generated mappings themselves are annotation with a $[mapping_function]
  attribute.

For each mapping_function (X_backwards or similar) we generated a
X_backwards_infallible function, which generates a match statement that
returns a dummy value rather than match_failure when any $[mapping_last]
clause is not matched. Any call to X_backwards is changed to
X_backwards_infallible when it has a $[mapping_guarded] attribute.

SMT: Always optimize pure matches

(Note that match_failure is a side-effect)
  • Loading branch information
Alasdair committed Nov 20, 2024
1 parent 018a013 commit 120fe61
Show file tree
Hide file tree
Showing 15 changed files with 361 additions and 182 deletions.
118 changes: 81 additions & 37 deletions src/lib/anf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ module Big_int = Nat_big_num
(* 1. Conversion to A-normal form (ANF) *)
(**************************************************************************)

type aexp_annot = { loc : l; env : Env.t; uannot : uannot }
type anf_annot = { loc : l; env : Env.t; uannot : uannot }

type 'a aexp = AE_aux of 'a aexp_aux * aexp_annot
type 'a aexp = AE_aux of 'a aexp_aux * anf_annot

and 'a aexp_aux =
| AE_val of 'a aval
Expand All @@ -73,16 +73,16 @@ and 'a aexp_aux =
| AE_throw of 'a aval * 'a
| AE_if of 'a aval * 'a aexp * 'a aexp * 'a
| AE_field of 'a aval * id * 'a
| AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp) list * 'a
| AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp) list * 'a
| AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a
| AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a
| AE_struct_update of 'a aval * 'a aval Bindings.t * 'a
| AE_for of id * 'a aexp * 'a aexp * 'a aexp * order * 'a aexp
| AE_loop of loop * 'a aexp * 'a aexp
| AE_short_circuit of sc_op * 'a aval * 'a aexp

and sc_op = SC_and | SC_or

and 'a apat = AP_aux of 'a apat_aux * Env.t * l
and 'a apat = AP_aux of 'a apat_aux * anf_annot

and 'a apat_aux =
| AP_tuple of 'a apat list
Expand Down Expand Up @@ -112,7 +112,7 @@ let aexp_loc (AE_aux (_, { loc = l; _ })) = l

(* Renaming variables in ANF expressions *)

let rec apat_bindings (AP_aux (apat_aux, _, _)) =
let rec apat_bindings (AP_aux (apat_aux, _)) =
match apat_aux with
| AP_tuple apats -> List.fold_left IdSet.union IdSet.empty (List.map apat_bindings apats)
| AP_id (id, _) -> IdSet.singleton id
Expand All @@ -127,7 +127,7 @@ let rec apat_bindings (AP_aux (apat_aux, _, _)) =

(** This function returns the types of all bound variables in a
pattern. It ignores AP_global, apat_globals is used for that. *)
let rec apat_types (AP_aux (apat_aux, env, _)) =
let rec apat_types (AP_aux (apat_aux, { env; _ })) =
let merge id b1 b2 =
match (b1, b2) with
| None, None -> None
Expand All @@ -148,7 +148,7 @@ let rec apat_types (AP_aux (apat_aux, env, _)) =
| AP_struct (afpats, _) ->
List.fold_left (Bindings.merge merge) Bindings.empty (List.map (fun (_, apat) -> apat_types apat) afpats)

let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) =
let rec apat_rename from_id to_id (AP_aux (apat_aux, annot)) =
let apat_aux =
match apat_aux with
| AP_tuple apats -> AP_tuple (List.map (apat_rename from_id to_id) apats)
Expand All @@ -164,7 +164,7 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) =
| AP_struct (afpats, typ) ->
AP_struct (List.map (fun (field, apat) -> (field, apat_rename from_id to_id apat)) afpats, typ)
in
AP_aux (apat_aux, env, l)
AP_aux (apat_aux, annot)

let rec aval_typ = function
| AV_lit (_, typ) -> typ
Expand Down Expand Up @@ -251,9 +251,9 @@ let rec aexp_rename from_id to_id (AE_aux (aexp, annot)) =
in
AE_aux (aexp, annot)

and apexp_rename from_id to_id (apat, aexp1, aexp2) =
if IdSet.mem from_id (apat_bindings apat) then (apat, aexp1, aexp2)
else (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2)
and apexp_rename from_id to_id (apat, aexp1, aexp2, uannot) =
if IdSet.mem from_id (apat_bindings apat) then (apat, aexp1, aexp2, uannot)
else (apat, aexp_rename from_id to_id aexp1, aexp_rename from_id to_id aexp2, uannot)

let rec fold_aexp f (AE_aux (aexp, annot)) =
let aexp =
Expand All @@ -269,17 +269,39 @@ let rec fold_aexp f (AE_aux (aexp, annot)) =
| AE_for (id, aexp1, aexp2, aexp3, order, aexp4) ->
AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4)
| AE_match (aval, cases, typ) ->
AE_match (aval, List.map (fun (pat, aexp1, aexp2) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2)) cases, typ)
AE_match
( aval,
List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2, uannot)) cases,
typ
)
| AE_try (aexp, cases, typ) ->
AE_try
( fold_aexp f aexp,
List.map (fun (pat, aexp1, aexp2) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2)) cases,
List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, fold_aexp f aexp1, fold_aexp f aexp2, uannot)) cases,
typ
)
| (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v
in
f (AE_aux (aexp, annot))

let rec is_pure_aexp effect_info (AE_aux (aexp, { uannot; _ })) =
match get_attribute "anf_pure" uannot with
| Some _ -> true
| None -> (
match aexp with
| AE_app (f, _, _) -> Effects.function_is_pure f effect_info
| AE_typ (aexp, _) -> is_pure_aexp effect_info aexp
| AE_let (Immutable, _, _, aexp1, aexp2, _) -> is_pure_aexp effect_info aexp1 && is_pure_aexp effect_info aexp2
| AE_match (_, arms, _) ->
List.for_all (fun (_, guard, aexp, _) -> is_pure_aexp effect_info guard && is_pure_aexp effect_info aexp) arms
| AE_short_circuit (_, _, aexp) -> is_pure_aexp effect_info aexp
| AE_if (_, then_aexp, else_aexp, _) -> is_pure_aexp effect_info then_aexp && is_pure_aexp effect_info else_aexp
| AE_val _ | AE_field _ -> true
| _ -> false
)

let is_pure_case effect_info (_, guard, body, _) = is_pure_aexp effect_info guard && is_pure_aexp effect_info body

let aexp_bindings aexp =
let ids = ref IdSet.empty in
let collect_lets = function
Expand Down Expand Up @@ -336,14 +358,14 @@ let rec no_shadow ids (AE_aux (aexp, annot)) =
in
AE_aux (aexp, annot)

and no_shadow_apexp ids (apat, aexp1, aexp2) =
and no_shadow_apexp ids (apat, aexp1, aexp2, uannot) =
let shadows = IdSet.inter (apat_bindings apat) ids in
let shadows = List.map (fun id -> (id, new_shadow id)) (IdSet.elements shadows) in
let rename aexp = List.fold_left (fun aexp (from_id, to_id) -> aexp_rename from_id to_id aexp) aexp shadows in
let rename_apat apat = List.fold_left (fun apat (from_id, to_id) -> apat_rename from_id to_id apat) apat shadows in
let ids = IdSet.union (apat_bindings apat) (IdSet.union ids (IdSet.of_list (List.map snd shadows))) in
let new_guard = no_shadow ids (rename aexp1) in
(rename_apat apat, new_guard, no_shadow (IdSet.union ids (aexp_bindings new_guard)) (rename aexp2))
(rename_apat apat, new_guard, no_shadow (IdSet.union ids (aexp_bindings new_guard)) (rename aexp2), uannot)

(* Map over all the avals in an aexp. *)
let rec map_aval f (AE_aux (aexp, annot)) =
Expand All @@ -366,10 +388,16 @@ let rec map_aval f (AE_aux (aexp, annot)) =
| AE_field (aval, field, typ) -> AE_field (f annot aval, field, typ)
| AE_match (aval, cases, typ) ->
AE_match
(f annot aval, List.map (fun (pat, aexp1, aexp2) -> (pat, map_aval f aexp1, map_aval f aexp2)) cases, typ)
( f annot aval,
List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, map_aval f aexp1, map_aval f aexp2, uannot)) cases,
typ
)
| AE_try (aexp, cases, typ) ->
AE_try
(map_aval f aexp, List.map (fun (pat, aexp1, aexp2) -> (pat, map_aval f aexp1, map_aval f aexp2)) cases, typ)
( map_aval f aexp,
List.map (fun (pat, aexp1, aexp2, uannot) -> (pat, map_aval f aexp1, map_aval f aexp2, uannot)) cases,
typ
)
| AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, f annot aval, map_aval f aexp)
in
AE_aux (aexp, annot)
Expand All @@ -391,11 +419,18 @@ let rec map_functions f (AE_aux (aexp, annot)) =
AE_for (id, map_functions f aexp1, map_functions f aexp2, map_functions f aexp3, order, map_functions f aexp4)
| AE_match (aval, cases, typ) ->
AE_match
(aval, List.map (fun (pat, aexp1, aexp2) -> (pat, map_functions f aexp1, map_functions f aexp2)) cases, typ)
( aval,
List.map
(fun (pat, aexp1, aexp2, uannot) -> (pat, map_functions f aexp1, map_functions f aexp2, uannot))
cases,
typ
)
| AE_try (aexp, cases, typ) ->
AE_try
( map_functions f aexp,
List.map (fun (pat, aexp1, aexp2) -> (pat, map_functions f aexp1, map_functions f aexp2)) cases,
List.map
(fun (pat, aexp1, aexp2, uannot) -> (pat, map_functions f aexp1, map_functions f aexp2, uannot))
cases,
typ
)
| (AE_field _ | AE_struct_update _ | AE_val _ | AE_return _ | AE_exit _ | AE_throw _) as v -> v
Expand Down Expand Up @@ -426,12 +461,14 @@ let rec pp_alexp = function
| AL_addr (id, typ) -> string "*" ^^ parens (pp_annot typ (pp_id id))
| AL_field (alexp, field) -> pp_alexp alexp ^^ dot ^^ pp_id field

let rec pp_aexp (AE_aux (aexp, annot)) =
( match get_attributes annot.uannot with
let pp_anf_uannot uannot =
match get_attributes uannot with
| [] -> empty
| attrs ->
concat_map (fun (_, attr, arg) -> string (string_of_attribute attr arg |> Util.magenta |> Util.clear)) attrs
)

let rec pp_aexp (AE_aux (aexp, annot)) =
pp_anf_uannot annot.uannot
^^
match aexp with
| AE_val v -> pp_aval v
Expand Down Expand Up @@ -492,7 +529,9 @@ let rec pp_aexp (AE_aux (aexp, annot)) =
(List.map (fun (id, aval) -> pp_id id ^^ string " = " ^^ pp_aval aval) (Bindings.bindings updates))
)

and pp_apat (AP_aux (apat_aux, _, _)) =
and pp_apat (AP_aux (apat_aux, annot)) =
pp_anf_uannot annot.uannot
^^
match apat_aux with
| AP_wild _ -> string "_"
| AP_id (id, typ) -> pp_annot typ (pp_id id)
Expand All @@ -513,7 +552,8 @@ and pp_apat (AP_aux (apat_aux, _, _)) =

and pp_cases cases = surround 2 0 lbrace (separate_map (comma ^^ hardline) pp_case cases) rbrace

and pp_case (apat, guard, body) = separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body]
and pp_case (apat, guard, body, uannot) =
pp_anf_uannot uannot ^^ parens (separate space [pp_apat apat; string "if"; pp_aexp guard; string "=>"; pp_aexp body])

and pp_block = function
| [] -> string "()"
Expand Down Expand Up @@ -553,8 +593,8 @@ let rec split_block l = function
(exp :: exps, last)
| [] -> Reporting.unreachable l __POS__ "empty block found when converting to ANF" [@coverage off]

let rec anf_pat ?(global = false) (P_aux (p_aux, annot) as pat) =
let mk_apat aux = AP_aux (aux, env_of_annot annot, fst annot) in
let rec anf_pat ?(global = false) (P_aux (p_aux, (l, tannot)) as pat) =
let mk_apat aux = AP_aux (aux, { loc = l; env = env_of_tannot tannot; uannot = untyped_annot tannot }) in
match p_aux with
| P_id id when global -> mk_apat (AP_global (id, typ_of_pat pat))
| P_id id -> mk_apat (AP_id (id, typ_of_pat pat))
Expand All @@ -575,11 +615,9 @@ let rec anf_pat ?(global = false) (P_aux (p_aux, annot) as pat) =
| P_as (pat, id) -> mk_apat (AP_as (anf_pat ~global pat, id, typ_of_pat pat))
| P_struct (fpats, FP_no_wild) ->
mk_apat (AP_struct (List.map (fun (field, pat) -> (field, anf_pat ~global pat)) fpats, typ_of_pat pat))
| _ ->
Reporting.unreachable (fst annot) __POS__
("Could not convert pattern to ANF: " ^ string_of_pat pat) [@coverage off]
| _ -> Reporting.unreachable l __POS__ ("Could not convert pattern to ANF: " ^ string_of_pat pat) [@coverage off]

let rec apat_globals (AP_aux (aux, _, _)) =
let rec apat_globals (AP_aux (aux, _)) =
match aux with
| AP_nil _ | AP_wild _ | AP_id _ -> []
| AP_global (id, typ) -> [(id, typ)]
Expand Down Expand Up @@ -737,22 +775,28 @@ let rec anf (E_aux (e_aux, (l, tannot)) as exp) =
mk_aexp (AE_val (AV_ref (id, lvar)))
| E_match (match_exp, pexps) ->
let match_aval, match_wrap = to_aval (anf match_exp) in
let anf_pexp (Pat_aux (pat_aux, (l, _))) =
let anf_pexp (Pat_aux (pat_aux, (l, tannot))) =
match pat_aux with
| Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body)
| Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body, untyped_annot tannot)
| Pat_exp (pat, body) ->
( anf_pat pat,
AE_aux (AE_val (AV_lit (mk_lit L_true, bool_typ)), { loc = l; env = env_of body; uannot = empty_uannot }),
anf body
anf body,
untyped_annot tannot
)
in
match_wrap (mk_aexp (AE_match (match_aval, List.map anf_pexp pexps, typ_of exp)))
| E_try (match_exp, pexps) ->
let match_aexp = anf match_exp in
let anf_pexp (Pat_aux (pat_aux, _)) =
let anf_pexp (Pat_aux (pat_aux, (l, tannot))) =
match pat_aux with
| Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body)
| Pat_exp (pat, body) -> (anf_pat pat, mk_aexp (AE_val (AV_lit (mk_lit L_true, bool_typ))), anf body)
| Pat_when (pat, guard, body) -> (anf_pat pat, anf guard, anf body, untyped_annot tannot)
| Pat_exp (pat, body) ->
( anf_pat pat,
AE_aux (AE_val (AV_lit (mk_lit L_true, bool_typ)), { loc = l; env = env_of body; uannot = empty_uannot }),
anf body,
untyped_annot tannot
)
in
mk_aexp (AE_try (match_aexp, List.map anf_pexp pexps, typ_of exp))
| ( E_var (LE_aux (LE_id id, _), binding, body)
Expand Down
18 changes: 11 additions & 7 deletions src/lib/anf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ open Type_check
the original Sail expression, it's typing environment, and the
uannot type containing any attributes attached to the original
expression. *)
type aexp_annot = { loc : l; env : Env.t; uannot : uannot }
type anf_annot = { loc : l; env : Env.t; uannot : uannot }

type 'a aexp = AE_aux of 'a aexp_aux * aexp_annot
type 'a aexp = AE_aux of 'a aexp_aux * anf_annot

and 'a aexp_aux =
| AE_val of 'a aval
Expand All @@ -98,8 +98,8 @@ and 'a aexp_aux =
| AE_throw of 'a aval * 'a
| AE_if of 'a aval * 'a aexp * 'a aexp * 'a
| AE_field of 'a aval * id * 'a
| AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp) list * 'a
| AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp) list * 'a
| AE_match of 'a aval * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a
| AE_try of 'a aexp * ('a apat * 'a aexp * 'a aexp * uannot) list * 'a
| AE_struct_update of 'a aval * 'a aval Bindings.t * 'a
| AE_for of id * 'a aexp * 'a aexp * 'a aexp * order * 'a aexp
| AE_loop of loop * 'a aexp * 'a aexp
Expand All @@ -110,7 +110,7 @@ and 'a aexp_aux =

and sc_op = SC_and | SC_or

and 'a apat = AP_aux of 'a apat_aux * Env.t * l
and 'a apat = AP_aux of 'a apat_aux * anf_annot

and 'a apat_aux =
| AP_tuple of 'a apat list
Expand Down Expand Up @@ -152,10 +152,10 @@ val aval_typ : typ aval -> typ
val aexp_typ : typ aexp -> typ

(** Map over all values in an ANF expression *)
val map_aval : (aexp_annot -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp
val map_aval : (anf_annot -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp

(** Map over all function calls in an ANF expression *)
val map_functions : (aexp_annot -> id -> 'a aval list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp
val map_functions : (anf_annot -> id -> 'a aval list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp

(** This function 'folds' an [aexp] applying the provided function to
all leaf subexpressions, then applying the function to their
Expand All @@ -164,6 +164,10 @@ val fold_aexp : ('a aexp -> 'a aexp) -> 'a aexp -> 'a aexp

val aexp_bindings : 'a aexp -> IdSet.t

val is_pure_aexp : Effects.side_effect_info -> 'a aexp -> bool

val is_pure_case : Effects.side_effect_info -> 'a apat * 'a aexp * 'a aexp * uannot -> bool

(** Remove all variable shadowing in an ANF expression *)
val no_shadow : IdSet.t -> 'a aexp -> 'a aexp

Expand Down
13 changes: 9 additions & 4 deletions src/lib/chunk_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,8 @@ let flatten_block exps =

(* Check if a sequence of cases in a match or try statement is aligned *)
let is_aligned pexps =
let pexp_exp_column = function
let rec pexp_exp_column = function
| Pat_aux (Pat_attribute (_, _, pexp), _) -> pexp_exp_column pexp
| Pat_aux (Pat_exp (_, E_aux (_, l)), _) -> starting_column_num l
| Pat_aux (Pat_when (_, _, E_aux (_, l)), _) -> starting_column_num l
in
Expand Down Expand Up @@ -945,7 +946,7 @@ let rec chunk_exp comments chunks (E_aux (aux, l)) =
let kind = match match_exp with E_match _ -> Match_match | _ -> Try_match in
let exp_chunks = rec_chunk_exp exp in
let aligned = is_aligned cases in
let cases = List.map (chunk_pexp ~delim:"," comments) cases in
let cases = List.map (chunk_pexp ~delim:"," comments chunks) cases in
Match { kind; exp = exp_chunks; aligned; cases } |> add_chunk chunks
| E_vector_update _ | E_vector_update_subrange _ ->
let vec_chunks, updates = chunk_vector_update comments (E_aux (aux, l)) in
Expand Down Expand Up @@ -1052,8 +1053,12 @@ and chunk_vector_update comments (E_aux (aux, l) as exp) =
chunk_exp comments exp_chunks exp;
(exp_chunks, [])

and chunk_pexp ?delim comments (Pat_aux (aux, l)) =
and chunk_pexp ?delim comments chunks (Pat_aux (aux, l)) =
match aux with
| Pat_attribute (attr, arg, pexp) ->
Queue.add (Atom (Ast_util.string_of_attribute attr arg)) chunks;
Queue.add (Spacer (false, 1)) chunks;
chunk_pexp ?delim comments chunks pexp
| Pat_exp (pat, exp) ->
let funcl_space = match pat with P_aux (P_tuple _, _) -> false | _ -> true in
let pat_chunks = Queue.create () in
Expand Down Expand Up @@ -1087,7 +1092,7 @@ let chunk_funcl comments funcl =
Queue.add (Spacer (false, 1)) chunks;
chunk_funcl' comments funcl
| FCL_doc (_, funcl) -> chunk_funcl' comments funcl
| FCL_funcl (_, pexp) -> chunk_pexp comments pexp
| FCL_funcl (_, pexp) -> chunk_pexp comments chunks pexp
in
(chunks, chunk_funcl' comments funcl)

Expand Down
Loading

0 comments on commit 120fe61

Please sign in to comment.