Skip to content

Commit

Permalink
clean up AST wrt. assignment statement rep; other minor signature cha…
Browse files Browse the repository at this point in the history
…nges
  • Loading branch information
wies committed Nov 16, 2024
1 parent 861d6bb commit 2f20d82
Show file tree
Hide file tree
Showing 7 changed files with 830 additions and 999 deletions.
44 changes: 22 additions & 22 deletions lib/ast/astDef.ml
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ module Type = struct
let mk_fld loc tpf = App (Fld, [tpf], mk_attr loc)
let mk_perm loc = App (Perm, [], mk_attr loc)
let mk_data id decls loc = App (Data (id, decls), [], mk_attr loc)
let mk_var loc qid = App (Var qid, [], mk_attr loc)
let mk_var qid = App (Var qid, [], mk_attr (QualIdent.to_loc qid))
let mk_atomic_token loc = App (AtomicToken, [], mk_attr loc)
let mk_prod loc tp_list =
match tp_list with
Expand All @@ -387,7 +387,7 @@ module Type = struct
let map = mk_map Loc.dummy
let perm = mk_perm Loc.dummy
let data id decls = mk_data id decls Loc.dummy
let var qid = mk_var Loc.dummy qid
let var qid = mk_var qid
let atomic_token = mk_atomic_token Loc.dummy

(** Equality and Subtyping *)
Expand Down Expand Up @@ -443,7 +443,7 @@ module Type = struct

(** Auxiliary utility functions *)

let mk_var_decl ?(loc = Loc.dummy) ?(const = false) ?(ghost = false) ?(implicit = false) name tp =
let mk_var_decl ?(const = false) ?(ghost = false) ?(implicit = false) name ?(loc = Ident.to_loc name) tp =
{ var_name = name; var_loc = loc; var_type = tp; var_const = const; var_ghost = ghost; var_implicit = implicit }

let is_num tp =
Expand Down Expand Up @@ -764,8 +764,8 @@ module Expr = struct
let mk_app ?(loc = Loc.dummy) ~typ c es =
App (c, es, mk_attr loc typ)

let mk_var ?(loc = Loc.dummy) ~typ (qual_ident: qual_ident) =
mk_app ~loc ~typ (Var qual_ident) []
let mk_var ~typ (qual_ident: qual_ident) =
mk_app ~loc:(QualIdent.to_loc qual_ident) ~typ (Var qual_ident) []

let mk_binder ?(loc = Loc.dummy) ?(typ = Type.bool) ?(trigs = []) b vs e =
match vs with
Expand Down Expand Up @@ -866,7 +866,7 @@ module Expr = struct
App (MapLookUp, [ e1; e2 ], mk_attr loc t)

let from_var_decl (var_decl:var_decl) =
mk_var ~loc:var_decl.var_loc ~typ:var_decl.var_type (QualIdent.from_ident var_decl.var_name)
mk_var ~typ:var_decl.var_type (QualIdent.from_ident var_decl.var_name)

(** Auxiliary functions *)

Expand Down Expand Up @@ -1169,7 +1169,7 @@ module Stmt = struct
new_args : (qual_ident * expr option) list;
}

type assign_desc = { assign_lhs : expr list; assign_rhs : expr; assign_is_init : bool }
type assign_desc = { assign_lhs : qual_ident list; assign_rhs : expr; assign_is_init : bool }
type bind_desc = { bind_lhs : expr list; bind_rhs : expr }

type field_read_desc = {
Expand Down Expand Up @@ -1322,8 +1322,8 @@ module Stmt = struct
| Assign astm -> (
match astm.assign_lhs with
| [] -> Expr.pr ppf astm.assign_rhs
| es ->
fprintf ppf "@[<2>%a@ :=@ %a@]" Expr.pr_list es Expr.pr
| vs ->
fprintf ppf "@[<2>%a@ :=@ %a@]" QualIdent.pr_list vs Expr.pr
astm.assign_rhs)
| Bind bstm -> (
match bstm.bind_lhs with
Expand Down Expand Up @@ -1497,6 +1497,11 @@ module Stmt = struct
let mk_assign ~loc lhs rhs =
{ stmt_desc = Basic (Assign { assign_lhs = lhs; assign_rhs = rhs; assign_is_init = false }); stmt_loc = loc }

let mk_field_write ~loc ref field v =
{ stmt_desc = Basic (FieldWrite {field_write_ref = ref; field_write_field = field; field_write_val = v});
stmt_loc = loc
}

let mk_return ~loc e = { stmt_desc = Basic (Return e); stmt_loc = loc }

let mk_bind ~loc lhs rhs =
Expand Down Expand Up @@ -1544,7 +1549,10 @@ module Stmt = struct
Set.add accesses new_desc.new_lhs

| Assign assign_desc ->
scan_expr_list accesses (assign_desc.assign_rhs :: assign_desc.assign_lhs)
let accesses =
List.fold assign_desc.assign_lhs ~init:accesses ~f:Set.add
in
scan_expr_list accesses [assign_desc.assign_rhs]

| Bind bind_desc ->
scan_expr_list accesses (bind_desc.bind_rhs :: bind_desc.bind_lhs)
Expand Down Expand Up @@ -1641,15 +1649,8 @@ module Stmt = struct
[]

| Assign assign_desc ->
List.filter_map assign_desc.assign_lhs ~f:(fun e ->
match e with
| App (Var qi, _, _) ->
if List.is_empty qi.qual_path then
Some qi.qual_base
else
None
| _ -> None
)
List.map assign_desc.assign_lhs ~f:QualIdent.unqualify


| Bind bind_desc ->
List.filter_map bind_desc.bind_lhs ~f:(fun e ->
Expand Down Expand Up @@ -1748,7 +1749,7 @@ module Stmt = struct
[]

| Assign assign_desc ->
List.concat_map (assign_desc.assign_rhs :: assign_desc.assign_lhs) ~f:(fun e -> Expr.expr_fields_accessed e)
Expr.expr_fields_accessed assign_desc.assign_rhs

| Bind bind_desc ->
List.concat_map (bind_desc.bind_rhs :: bind_desc.bind_lhs) ~f:(fun e -> Expr.expr_fields_accessed e)
Expand Down Expand Up @@ -2489,8 +2490,7 @@ module ProgStats = struct
| Assign assign_desc ->
let is_ghost = try
(List.exists assign_desc.assign_lhs
~f:(fun e ->
let qi = Expr.to_qual_ident e in
~f:(fun qi ->
List.exists proc_decl.Callable.call_decl_locals ~f:(fun vd -> Ident.(vd.var_name = (QualIdent.to_ident qi)) )
)
)
Expand Down
29 changes: 12 additions & 17 deletions lib/ast/rewriter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -641,11 +641,10 @@ module Stmt = struct
Basic (Spec (sk, { spec with spec_form = new_spec_form }));
}
| Assign assign ->
let* assign_lhs = List.map assign.assign_lhs ~f in
let+ assign_rhs = f assign.assign_rhs in
{
stmt with
stmt_desc = Basic (Assign { assign with assign_lhs; assign_rhs });
stmt_desc = Basic (Assign { assign with assign_rhs });
}
| Bind bind_desc ->
let* bind_lhs = List.map bind_desc.bind_lhs ~f in
Expand Down Expand Up @@ -781,10 +780,8 @@ module Stmt = struct
in
{ stmt with stmt_desc = Basic (VarDef { var_decl; var_init }) }
| Assign assign ->
let+ assign_rhs = Expr.rewrite_qual_idents ~f assign.assign_rhs
and+ assign_lhs =
List.map assign.assign_lhs ~f:(Expr.rewrite_qual_idents ~f)
in
let+ assign_rhs = Expr.rewrite_qual_idents ~f assign.assign_rhs in
let assign_lhs = Base.List.map assign.assign_lhs ~f in
{
stmt with
stmt_desc = Basic (Assign { assign with assign_lhs; assign_rhs });
Expand Down Expand Up @@ -1279,8 +1276,7 @@ module Symbol = struct
Some tp_expr
| ModDef { mod_decl = { mod_decl_rep = Some rep_id; _ }; _ } ->
let+ tp_expr =
AstDef.Type.mk_var (QualIdent.to_loc name)
(QualIdent.append name rep_id)
AstDef.Type.mk_var (QualIdent.append name rep_id)
|> Type.rewrite_qual_idents ~f:(QualIdent.requalify subst)
in
Some tp_expr
Expand Down Expand Up @@ -1351,8 +1347,8 @@ let find_and_reify loc name : (AstDef.Module.symbol, 'a) t_ext =
m "Rewriter.find_and_reify: symbol = %a" Symbol.pr symbol);*)
Symbol.reify symbol

let is_local loc qual_ident s =
let s, qual_ident = resolve loc qual_ident s in
let is_local qual_ident s =
let s, qual_ident = resolve (QualIdent.to_loc qual_ident) qual_ident s in
(s, Base.List.is_empty qual_ident.qual_path)

module ProgUtils = struct
Expand Down Expand Up @@ -1640,7 +1636,6 @@ module ProgUtils = struct

let get_ra_rep_type (ra_qual_iden : qual_ident) : type_expr =
AstDef.Type.mk_var
(QualIdent.to_loc ra_qual_iden)
(QualIdent.append ra_qual_iden
(Ident.make (QualIdent.to_loc ra_qual_iden) "T" 0))

Expand Down Expand Up @@ -1842,7 +1837,7 @@ module ProgUtils = struct
QualIdent.append field_utils_module (heap_utils_id_ident loc)
in

return @@ AstDef.Expr.mk_var ~loc id_qual_ident ~typ:field_elem_type
return @@ AstDef.Expr.mk_var id_qual_ident ~typ:field_elem_type

let get_pred_utils_id loc pred_name : expr t =
let open Syntax in
Expand All @@ -1852,13 +1847,13 @@ module ProgUtils = struct

let* pred_elem_type_name = get_pred_utils_rep_type loc pred_name in

let pred_elem_type = AstDef.Type.mk_var loc pred_elem_type_name in
let pred_elem_type = AstDef.Type.mk_var pred_elem_type_name in

let id_qual_ident =
QualIdent.append pred_utils_module (heap_utils_id_ident loc)
in

return @@ AstDef.Expr.mk_var ~loc id_qual_ident ~typ:pred_elem_type
return @@ AstDef.Expr.mk_var id_qual_ident ~typ:pred_elem_type

let get_au_utils_id loc call_name : expr t =
let open Syntax in
Expand All @@ -1868,13 +1863,13 @@ module ProgUtils = struct

let* call_elem_type_name = get_au_utils_rep_type loc call_name in

let call_elem_type = AstDef.Type.mk_var loc call_elem_type_name in
let call_elem_type = AstDef.Type.mk_var call_elem_type_name in

let id_qual_ident =
QualIdent.append call_utils_module (heap_utils_id_ident loc)
in

return @@ AstDef.Expr.mk_var ~loc id_qual_ident ~typ:call_elem_type
return @@ AstDef.Expr.mk_var id_qual_ident ~typ:call_elem_type

(* ======================== *)

Expand Down Expand Up @@ -1987,7 +1982,7 @@ module ProgUtils = struct
AstDef.Type.mk_map
(QualIdent.to_loc pred_name)
(AstDef.Type.mk_prod (QualIdent.to_loc pred_name) pred_in_types)
(AstDef.Type.mk_var (QualIdent.to_loc pred_name) pred_rep_type)
(AstDef.Type.mk_var pred_rep_type)

let rec is_expr_pure (expr : expr) : (bool, 'a) t_ext =
let open Syntax in
Expand Down
2 changes: 1 addition & 1 deletion lib/backend/smt_solver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ let declare_tuple_sort (arity : int) : command =

let destrs_sorts =
Base.List.map2_exn destrs params ~f:(fun destr param ->
(destr, Ast.Type.mk_var Util.Loc.dummy param))
(destr, Ast.Type.mk_var param))
in

mk_declare_datatype (tuple_sort_name, params, [ (constr, destrs_sorts) ])
Expand Down
17 changes: 12 additions & 5 deletions lib/frontend/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ stmt_wo_trailing_substmt:
Stmt.(Basic (Assign assign))
}
(* assignment / allocation *)
| es = separated_nonempty_list(COMMA, expr); COLONEQ; e = new_or_expr; SEMICOLON {
| es = separated_nonempty_list(COMMA, expr); COLONEQ; e = assign_rhs; SEMICOLON {
let open Stmt in
match e with
| Basic (New new_descr) ->
Expand All @@ -436,7 +436,14 @@ stmt_wo_trailing_substmt:
(match es with
| [Expr.(App (Read, [field_write_ref; App (Var field_write_field, [], _)], _))] ->
Basic (FieldWrite { field_write_ref; field_write_field; field_write_val = assign.assign_rhs })
| _ -> Basic (Assign { assign with assign_lhs = es }))
| _ ->
let vs = List.map (function
| Expr.(App (Var qual_ident, [], _))
when QualIdent.is_local qual_ident -> qual_ident
| e -> Error.syntax_error (Expr.to_loc e) "Expected single field location or local variables on left-hand side of assignment")
es
in
Basic (Assign { assign with assign_lhs = vs }))
| _ -> assert false
}
(* bind *)
Expand Down Expand Up @@ -549,7 +556,7 @@ with_clause:
| _ -> Error.syntax_error (Loc.make $startpos $startpos) "A 'with' clause is only allowed in assert statements"
}

new_or_expr:
assign_rhs:
| NEW LPAREN fes = separated_list(COMMA, pair(qual_ident, option(preceded(COLON, expr)))) RPAREN {
let new_descr = Stmt.{
new_lhs = QualIdent.from_ident (Ident.make Loc.dummy "" 0);
Expand Down Expand Up @@ -1007,10 +1014,10 @@ type_expr:
| REF { Type.mk_ref (Loc.make $startpos $endpos) }
| PERM { Type.mk_perm (Loc.make $startpos $endpos)}
| ATOMICTOKEN { Type.mk_atomic_token (Loc.make $startpos $endpos) }
//| x = IDENT { Type.mk_var (Loc.make $startpos $endpos) (QualIdent.from_ident x) }
//| x = IDENT { Type.mk_var (QualIdent.from_ident x) }
| SET LBRACKET t = type_expr RBRACKET { Type.mk_set (Loc.make $startpos $endpos) t }
| MAP LBRACKET; t1 = type_expr; COMMA; t2 = type_expr; RBRACKET { Type.mk_map (Loc.make $startpos $endpos) t1 t2 }
| x = mod_ident { Type.mk_var (Loc.make $startpos $endpos) x }
| x = mod_ident { Type.mk_var x }
| LPAREN ts = type_expr_list RPAREN { Type.mk_prod (Loc.make $startpos $endpos) ts }
| x = mod_ident LBRACKET; ts = type_expr_list; RBRACKET {
Type.(App(Var x, ts, Type.mk_attr (Loc.make $startpos $endpos))) }
Expand Down
Loading

0 comments on commit 2f20d82

Please sign in to comment.