Skip to content

Commit

Permalink
make module type checking more permissive; allow implicit return valu…
Browse files Browse the repository at this point in the history
…es to be dropped on calls
  • Loading branch information
wies committed Dec 5, 2024
1 parent eeb1048 commit ffa32f5
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 56 deletions.
13 changes: 13 additions & 0 deletions lib/ast/rewriter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,11 @@ module List = struct
(s, Base.List.Or_unequal_lengths.Ok res)
| Unequal_lengths -> (s, Unequal_lengths)

let map2_exn (xs : 'a list) (ys : 'b list) ~f s =
match map2 xs ys ~f s with
| (s, Base.List.Or_unequal_lengths.Ok zs) -> (s, zs)
| _ -> failwith "Rewriter.List.map2 unequal length"

let fold_right (xs : 'a list) ~(init : 'b) ~f : ('b, 'c) t_ext =
fun s -> List.fold_right xs ~f:(fun x (s, acc) -> f x acc s) ~init:(s, init)

Expand Down Expand Up @@ -1395,6 +1400,14 @@ let find_and_reify_field name : (AstDef.Module.field_def, 'a) t_ext =
| FieldDef field -> field
| _ -> Error.type_error (QualIdent.to_loc name) (Printf.sprintf "Expected field but found %s" (AstDef.Symbol.kind symbol))

let find_and_reify_module name : (AstDef.Module.t, 'a) t_ext =
let open Syntax in
let+ symbol = find_and_reify name in
match symbol with
| ModDef m -> m
| _ -> Error.type_error (QualIdent.to_loc name) (Printf.sprintf "Expected module or interface but found %s" (AstDef.Symbol.kind symbol))


let is_local qual_ident s =
let s, qual_ident = resolve qual_ident s in
(s, Base.List.is_empty qual_ident.qual_path)
Expand Down
106 changes: 72 additions & 34 deletions lib/frontend/rewrites/rewrites.ml
Original file line number Diff line number Diff line change
Expand Up @@ -720,30 +720,60 @@ let rec rewrite_ret_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t =
m "Rewrites.rewrite_ret_stmts: curr_proc_name: %a" QualIdent.pr
curr_proc_name);

let+ symbol = Rewriter.find_and_reify curr_proc_name in

match symbol with
| CallDef c ->
Logs.debug (fun m ->
m "Rewrites.rewrite_ret_stmts: curr_proc: %a" Callable.pr c);
c.call_decl
| _ -> Error.error stmt.stmt_loc "Expected a call_def"
Rewriter.find_and_reify_callable curr_proc_name |+> fun c -> c.call_decl
in

let ret_expr_list = Expr.unfold_tuple ret_expr in

let truncated_returns, dropped_returns =
List.split_n callable_decl.call_decl_returns
(List.length ret_expr_list)
in

(*let fresh_dropped_returns =
List.map dropped_formal_args ~f:(fun var_decl ->
{
var_decl with
var_name =
Ident.fresh stmt.stmt_loc var_decl.var_name.ident_name;
var_loc = stmt.stmt_loc;
})
in*)

let dropped_returns_exprs =
List.map dropped_returns ~f:Expr.from_var_decl
in

let dropped_returns_vars =
List.map dropped_returns ~f:(fun decl -> QualIdent.from_ident decl.var_name)
in

let ret_expr = Expr.mk_tuple (ret_expr_list @ dropped_returns_exprs) in

(* Need to ensure that call_decl_returns and call_desc.call_lhs line up *)
let renaming_map =
List.fold2_exn
callable_decl.call_decl_returns
(ret_expr_list @ dropped_returns_exprs)
~init:(Map.empty (module QualIdent))
~f:(fun map var_decl arg_expr ->
Map.add_exn map
~key:(QualIdent.from_ident var_decl.var_name)
~data:arg_expr)
in

(*let renaming_map =
List.fold2_exn callable_decl.call_decl_returns ret_expr_list
~init:(Map.empty (module QualIdent))
~f:(fun map var_decl expr ->
Map.add_exn map
~key:(QualIdent.from_ident var_decl.var_name)
~data:expr)
in
in*)

let postconds_spec = callable_decl.call_decl_postcond in

let postconds_exhale_stmts =
let postcond, spec_error =
if Callable.is_atomic callable_decl then
let atomic_token_var =
Expr.mk_var ~typ:(Type.atomic_token curr_proc_name)
Expand All @@ -762,34 +792,42 @@ let rec rewrite_ret_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t =

let error =
( Error.Verification,
stmt.stmt_loc,
loc,
"The atomic specification may not have been committed before \
reaching this return point" )
in

let loc = Stmt.to_loc stmt in
[
Stmt.mk_exhale_expr ~cmnt:("au_return_stmt") ~loc
~spec_error:[ Stmt.mk_const_spec_error error ]
(Expr.mk_app ~loc ~typ:Type.perm
(Expr.AUPredCommit curr_proc_name)
((atomic_token_var :: concrete_args_expr) @ [ ret_expr ]));
]
[Expr.mk_app ~loc ~typ:Type.perm
(Expr.AUPredCommit curr_proc_name)
((atomic_token_var :: concrete_args_expr) @ [ ret_expr ])],
[ Stmt.mk_const_spec_error error ]
else
List.map postconds_spec ~f:(fun spec ->
let expr = Expr.alpha_renaming spec.spec_form renaming_map in

let error =
( Error.Verification,
stmt.stmt_loc,
"A postcondition may not hold at this return point" )
in

Stmt.mk_exhale_expr ~loc:stmt.stmt_loc
~cmnt:
("postconds added for ret_stmt: " ^ Stmt.to_string stmt)
~spec_error:(Stmt.mk_const_spec_error error :: spec.spec_error)
expr)
Expr.alpha_renaming spec.spec_form renaming_map),
let error =
( Error.Verification,
loc,
"A postcondition may not hold at this return point" )
in
[ Stmt.mk_const_spec_error error ]
in

let bind_stmt =
match dropped_returns with
| [] -> Stmt.mk_skip ~loc
| _ ->
Stmt.mk_bind ~loc dropped_returns_vars
(Stmt.mk_spec
~atomic:false
~cmnt:("bind added for return stmt: " ^ Stmt.to_string stmt)
~spec_error
(Expr.mk_and postcond))
in

let postconds_exhale_stmts =
List.map postcond ~f:(fun expr ->
Stmt.mk_exhale_expr ~cmnt:("exhape added for return stmt: " ^ Stmt.to_string stmt) ~loc
~spec_error
expr)
in

let assume_false =
Expand All @@ -799,7 +837,7 @@ let rec rewrite_ret_stmts (stmt : Stmt.t) : Stmt.t Rewriter.t =

let new_stmt =
Stmt.mk_block_stmt ~loc:stmt.stmt_loc
(postconds_exhale_stmts @ [ assume_false ])
(bind_stmt :: postconds_exhale_stmts @ [ assume_false ])
in

Rewriter.return new_stmt
Expand Down
44 changes: 22 additions & 22 deletions lib/frontend/typing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1497,10 +1497,7 @@ module ProcessCallable = struct
Error.type_error stmt_loc "Cannot return in a ghost block";

let* expr = disambiguate_expr expr disam_tbl in
let return_list = match expr with
| App (Tuple, return_list, _) -> return_list
| _ -> [expr]
in
let return_list = Expr.unfold_tuple expr in

let+ return_list = process_callable_returns stmt_loc is_ghost_scope call_decl return_list in
let expr = Expr.mk_tuple ~loc:(Expr.to_loc expr) return_list in
Expand Down Expand Up @@ -1624,22 +1621,25 @@ module ProcessCallable = struct
qual_ident :: assign_lhs, var_decl :: var_decls_lhs
)
in
let* call_lhs_types =
Rewriter.List.map var_decls_lhs ~f:(fun var_decl ->
ProcessTypeExpr.expand_type_expr var_decl.var_type)
let* call_lhs_expr =
Rewriter.List.map2_exn call_lhs var_decls_lhs ~f:(fun qual_ident var_decl ->
let+ typ = ProcessTypeExpr.expand_type_expr var_decl.var_type in
Expr.mk_var ~typ qual_ident)
in

let* call_decl = Rewriter.find_and_reify_callable call_desc.call_name |+> fun c -> c.call_decl in
let* call_lhs_expr = process_callable_returns stmt_loc is_ghost_scope call_decl call_lhs_expr in

let expected_return_type =
Type.mk_prod stmt_loc call_lhs_types
in

let is_ghost = List.for_all call_lhs_expr ~f:(fun e -> e |> Expr.to_type |> Type.is_ghost) in

let+ call_expr =
Expr.App
( Var call_desc.call_name,
call_desc.call_args,
{ Expr.expr_loc = stmt_loc; expr_type = Type.bot } )
{ Expr.expr_loc = stmt_loc; expr_type = Type.any } )
|> fun expr ->
disambiguate_process_expr expr expected_return_type disam_tbl
disambiguate_process_expr expr (Type.any |> Type.set_ghost is_ghost) disam_tbl
in

match call_expr with
Expand Down Expand Up @@ -2189,7 +2189,6 @@ module ProcessModule = struct
with
| Some mod_typ, orig_mod_typ
when QualIdent.(mod_typ <> orig_mod_typ) ->
Logs.debug (fun m -> m !"%{QualIdent} %{QualIdent}" mod_typ orig_mod_typ);
Error.type_error loc
(Printf.sprintf
!"%s %{Ident} must implement interface %{QualIdent} \
Expand Down Expand Up @@ -2240,15 +2239,16 @@ module ProcessModule = struct
!"Cannot redeclare interface %{Ident} from interface \
%{QualIdent} as module"
ident interface_ident)
else if
QualIdent.(mod_inst.mod_inst_type <> orig_mod_inst.mod_inst_type)
then
Error.type_error loc
(Printf.sprintf
!"%s %{Ident} must implement interface %{QualIdent} according \
to interface %{QualIdent}"
(Symbol.kind symbol |> String.capitalize)
ident orig_mod_inst.mod_inst_type interface_ident)
else
let* mod_inst_def = Rewriter.find_and_reify_module mod_inst.mod_inst_type in
if not @@ Set.mem mod_inst_def.mod_decl.mod_decl_interfaces orig_mod_inst.mod_inst_type then
let _ = Logs.debug (fun m -> m !"%{QualIdent} %{QualIdent}" mod_inst.mod_inst_type orig_mod_inst.mod_inst_type) in
Error.type_error loc
(Printf.sprintf
!"%s %{Ident} must implement interface %{QualIdent} according \
to interface %{QualIdent}"
(Symbol.kind symbol |> String.capitalize)
ident orig_mod_inst.mod_inst_type interface_ident)
else
match (mod_inst.mod_inst_def, orig_mod_inst.mod_inst_def) with
| Some _, Some _ ->
Expand Down

0 comments on commit ffa32f5

Please sign in to comment.