From c562dbda251bcf9fdbd30712c56b7ef6ee69a2f2 Mon Sep 17 00:00:00 2001 From: Nicholas Coughlin Date: Wed, 29 Mar 2023 22:12:11 +1000 Subject: [PATCH 1/5] Cache inline function list --- libASL/dis.ml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/libASL/dis.ml b/libASL/dis.ml index 2148feb3..9a4390b2 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -821,10 +821,13 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws = | Expr_LitString(s) -> DisEnv.pure (Val (from_stringLit s)) ) +and no_inline_pure = List.map (fun (x,y) -> FIdent(x,y)) + ["LSL",0; "LSR",0; "ASR",0; "SignExtend",0; "ZeroExtend",0] +and no_inline_impure = List.map (fun (x,y) -> FIdent (x,y)) + no_inline + (** Disassemble call to function *) and dis_funcall (loc: l) (f: ident) (tvs: sym list) (vs: sym list): sym rws = - let no_inline_pure = List.map (fun (x,y) -> FIdent(x,y)) - ["LSL",0; "LSR",0; "ASR",0; "SignExtend",0; "ZeroExtend",0] in let+ ret = DisEnv.catcherror (dis_call loc f tvs vs) in (* we always want to reduce to values if possible, but exceptions may be thrown while disassembling functions. *) @@ -857,7 +860,6 @@ and dis_call (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws = let@ fn = DisEnv.getFun loc f in - let no_inline_impure = List.map (fun (x,y) -> FIdent (x,y)) no_inline in (match fn with | Some (rty, _, targs, _, _, _) when List.mem f no_inline_impure -> (* impure functions are not visited. *) From 999efb265cd899f12e526ad0750491429c342a6b Mon Sep 17 00:00:00 2001 From: Nicholas Coughlin Date: Wed, 29 Mar 2023 22:29:45 +1000 Subject: [PATCH 2/5] Avoid exception handling around function calls --- libASL/dis.ml | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/libASL/dis.ml b/libASL/dis.ml index 9a4390b2..86837da7 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -828,22 +828,18 @@ and no_inline_impure = List.map (fun (x,y) -> FIdent (x,y)) (** Disassemble call to function *) and dis_funcall (loc: l) (f: ident) (tvs: sym list) (vs: sym list): sym rws = - let+ ret = DisEnv.catcherror (dis_call loc f tvs vs) in - (* we always want to reduce to values if possible, but exceptions may be thrown while - disassembling functions. *) - match ret with - | Ok None -> internal_error loc "function call finished without returning a value" - | Ok (Some (Val v)) -> Val v - | Error _ - | Ok (Some (Exp _)) when List.mem f no_inline_pure -> - (* this case is reached when a no_inline_pure function cannot be fully evaluated. *) - (* in this case, simply emit the primitive. *) - let expr = Exp (Expr_TApply (f, List.map sym_expr tvs, List.map sym_expr vs)) in - (match sym_prim_simplify (name_of_FIdent f) tvs vs with - | Some x -> x - | None -> expr) - | Ok (Some (Exp e)) -> Exp e - | Error (exn,bt) -> raise exn (* it is an error if a non-primitive function cannot be disassembled. *) + if List.mem f no_inline_pure && + ((List.exists (function Exp _ -> true | _ -> false) tvs) || + (List.exists (function Exp _ -> true | _ -> false) vs)) then + let expr = Exp (Expr_TApply (f, List.map sym_expr tvs, List.map sym_expr vs)) in + DisEnv.pure (match sym_prim_simplify (name_of_FIdent f) tvs vs with + | Some x -> x + | None -> expr) + else + let+ r = dis_call loc f tvs vs in + match r with + | Some x -> x + | None -> internal_error loc "function call finished without returning a value" (** Evaluate call to procedure *) and dis_proccall (loc: l) (f: ident) (tvs: sym list) (vs: sym list): unit rws = From 3fc7aeee1416171242c421399e5ee009e9b21b28 Mon Sep 17 00:00:00 2001 From: Nicholas Coughlin Date: Fri, 31 Mar 2023 13:09:51 +1000 Subject: [PATCH 3/5] Add tree structure for writer Avoid repeated appends by building a tree and flattening later --- libASL/dis.ml | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/libASL/dis.ml b/libASL/dis.ml index 86837da7..9514ebca 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -279,13 +279,33 @@ module LocalEnv = struct end +type tree = + Node of stmt list + | Branch of tree * tree + +let empty = Node [] +let single x = Node x +let append x y = + match x, y with + | Node [], _ -> y + | _, Node [] -> x + | Node [x], Node y -> Node (x::y) + | _ -> Branch (x, y) + +let rec flatten x acc = + match x with + | Branch (x,y) -> + let acc = flatten y acc in + flatten x acc + | Node i -> i@acc + module DisEnv = struct include Rws.Make(struct type r = Eval.Env.t - type w = stmt list + type w = tree type s = LocalEnv.t - let mempty = [] - let mappend = (@) + let mempty = empty + let mappend = (append) end) open Let @@ -362,6 +382,8 @@ module DisEnv = struct let warn s = debug 0 ("WARNING: " ^ s) + let write i = fun e x -> ((),x,Node i) + let scope (loc: l) (name: string) (arg: string) (pp: 'a -> string) (x: 'a rws): 'a rws = (* logging header. looks like: +- dis_expr --> 1 + 1. *) log (Printf.sprintf "\u{256d}\u{2500} %s --> %s" name arg) >> @@ -373,6 +395,7 @@ module DisEnv = struct (* run computation but obtain state and writer to output in debugging. *) let* (result,s',w') = locally (catcherror x) in + let w' = flatten w' [] in let x' = (match result with | Error ((DisTrace _) as e, bt) -> raise e | Error (exn, bt) -> Printexc.raise_with_backtrace (DisTrace (trace, exn)) bt @@ -394,6 +417,7 @@ module DisEnv = struct else unit in pure x' + end type 'a rws = 'a DisEnv.rws @@ -518,12 +542,14 @@ let rec sym_if (loc: l) (t: ty) (test: sym rws) (tcase: sym rws) (fcase: sym rws (* Evaluate true branch statements. *) let@ (tenv,tstmts) = DisEnv.locally_ (tcase >>= assign_var loc tmp) in + let tstmts = flatten tstmts [] in (* Propagate incremented counter to env'. *) let@ env' = DisEnv.gets (fun env -> LocalEnv.sequence_merge env tenv) in (* Execute false branch statements with env'. *) let@ (fenv,fstmts) = DisEnv.locally_ (DisEnv.put env' >> fcase >>= assign_var loc tmp) in let@ () = DisEnv.join_locals tenv fenv in + let fstmts = flatten fstmts [] in match ite_body tstmts tmp, ite_body fstmts tmp with | Some (Val _ as te), Some fe | Some te, Some (Val _ as fe) -> @@ -545,9 +571,11 @@ and unit_if (loc: l) (test: sym rws) (tcase: unit rws) (fcase: unit rws): unit r | Val _ -> failwith ("Split on non-boolean value") | Exp e -> let@ (tenv,tstmts) = DisEnv.locally_ tcase in + let tstmts = flatten tstmts [] in let@ env' = DisEnv.gets (fun env -> LocalEnv.sequence_merge env tenv) in let@ (fenv,fstmts) = DisEnv.locally_ (DisEnv.put env' >> fcase) in + let fstmts = flatten fstmts [] in let@ () = DisEnv.join_locals tenv fenv in DisEnv.write [Stmt_If(e, tstmts, [], fstmts, loc)]) @@ -1373,6 +1401,7 @@ and dis_decode_alt' (loc: AST.l) (DecoderAlt_Alt (ps, b)) (vs: value list) (op: let@ () = dis_stmts exec in DisEnv.modify (LocalEnv.popLevel) ) in + let stmts = flatten stmts [] in if !debug_level >= 2 then begin Printf.printf "-----------\n"; @@ -1404,6 +1433,7 @@ type env = (LocalEnv.t * IdentSet.t) let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: value): stmt list = let DecoderCase_Case (_,_,loc) = decode in let ((),lenv',stmts) = (dis_decode_case loc decode op) env lenv in + let stmts = flatten stmts [] in let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts in let stmts' = Transforms.RedundantSlice.do_transform stmts' in let stmts' = Transforms.StatefulIntToBits.run stmts' in From db0ec06082b7e4593270155f4f0eb5919106ecfd Mon Sep 17 00:00:00 2001 From: Nicholas Coughlin Date: Fri, 31 Mar 2023 13:19:32 +1000 Subject: [PATCH 4/5] Reduce partial evaluation state keys to strings --- libASL/dis.ml | 93 +++++++++++++++++++++++++++++---------------------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/libASL/dis.ml b/libASL/dis.ml index 9514ebca..c5c65ecf 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -17,6 +17,13 @@ open Value open Symbolic +module StringCmp = struct + type t = string + let compare (x: string) (y: string): int = String.compare x y +end +module StringMap = Map.Make(StringCmp) + + let debug_level = ref 0 let debug_show_trace = ref false let no_debug = fun () -> !debug_level = 0 @@ -101,19 +108,18 @@ let no_inline = [ The "stack level" is how many scopes deep it is. For example, globals are level 0 and this increases by 1 for each nested function call. *) -type var = Var of int * ident -let pp_var (Var (i,id)) = Printf.sprintf "Var(%d,%s)" i (pprint_ident id) +type var = Var of int * string +let pp_var (Var (i,id)) = Printf.sprintf "Var(%d,%s)" i (id) let var_ident (Var (i,id)) = match i,id with - | 0,Ident s -> Ident s (* special case globals with no suffix. *) - | _,Ident s -> Ident (s ^ "__" ^ string_of_int i) - | _ -> internal_error Unknown "unexpected resolved variable to function identifier" + | 0,s -> Ident s (* special case globals with no suffix. *) + | _,s -> Ident (s ^ "__" ^ string_of_int i) (** Returns the variable's name without mangling, suitable for disassembling then resolving again. WARNING: should only be used when variable is in the inner-most scope. *) -let var_expr_no_suffix_in_local_scope (Var(_,id)) = Expr_Var id +let var_expr_no_suffix_in_local_scope (Var(_,id)) = Expr_Var (Ident id) (** Returns an expression for the variable with a mangled name, suitable for emitting in a generated statement. *) @@ -141,16 +147,22 @@ module LocalEnv = struct - VUninitialized itself is only used for scalar types. thus, uninitialized structures must be expanded into structures of uninitialized scalars. *) - locals : (ty * sym) Bindings.t list; + locals : (ty * sym) StringMap.t list; returnSymbols : expr option list; numSymbols : int; indent : int; trace : dis_trace; } + let force i = + match i with Ident s -> s | _ -> unsupported Unknown "" + let pp_value_bindings = Utils.pp_list (pp_bindings pp_value) - let pp_sym_bindings (bss: (ty * sym) Bindings.t list) = + let pp_bindings (pp: 'a -> string) (bs: 'a StringMap.t): string = + String.concat ", " (List.map (fun (k, v) -> k ^"->"^ pp v) (StringMap.bindings bs)) + + let pp_sym_bindings (bss: (ty * sym) StringMap.t list) = Utils.pp_list (pp_bindings (fun (_,e) -> pp_sym e)) bss let init (env: Eval.Env.t) = @@ -174,10 +186,10 @@ module LocalEnv = struct in let globals = Bindings.mapi (fun id v -> (get_global_type id, Val v)) - globalsAndConsts - in + globalsAndConsts in + let globals = StringMap.of_seq @@ Seq.map ( fun (k,v) -> (force k,v)) @@ Bindings.to_seq globals in { - locals = [Bindings.empty ; globals]; + locals = [StringMap.empty ; globals]; returnSymbols = []; numSymbols = 0; indent = 0; @@ -192,7 +204,7 @@ module LocalEnv = struct let pp_locals (env: t): string = let last = List.length env.locals - 1 in let withoutGlobals = List.mapi - (fun i x -> if i = last then Bindings.empty else x) env.locals in + (fun i x -> if i = last then StringMap.empty else x) env.locals in Printf.sprintf "locals = %s" (pp_sym_bindings withoutGlobals) (* Printf.sprintf "locals = %s" (pp_sym_bindings env.locals) *) @@ -225,7 +237,7 @@ module LocalEnv = struct (** Adds a local scoping level within the current level. *) let addLevel (env: t): t = - {env with locals = (Bindings.empty)::env.locals} + {env with locals = (StringMap.empty)::env.locals} (** Pops the innermost scoping level. *) let popLevel (env: t): t = @@ -235,10 +247,11 @@ module LocalEnv = struct (** Adds a new local variable to the innermost scope. *) let addLocalVar (loc: l) (k: ident) (v: sym) (t: ty) (env: t): var * t = - if !Eval.trace_write then Printf.printf "TRACE: fresh %s = %s\n" (pprint_ident k) (pp_sym v); + let k = force k in + if !Eval.trace_write then Printf.printf "TRACE: fresh %s = %s\n" (k) (pp_sym v); let var = Var (List.length env.locals - 1, k) in match env.locals with - | (bs :: rest) -> var, {env with locals = (Bindings.add k (t,v) bs :: rest)} + | (bs :: rest) -> var, {env with locals = (StringMap.add k (t,v) bs :: rest)} | [] -> internal_error Unknown "attempt to add local var but no local scopes exist" let addLocalConst = addLocalVar @@ -247,24 +260,25 @@ module LocalEnv = struct let getVar (loc: l) (x: var) (env: t): (ty * sym) = let Var (i,id) = x in let n = List.length env.locals - i - 1 in - match Bindings.find_opt id (List.nth env.locals n) with + match StringMap.find_opt id (List.nth env.locals n) with | Some x -> x | None -> internal_error loc @@ "failed to get resolved variable: " ^ pp_var x (** Resolves then gets the type and value of a resolved variable. *) - let rec go loc x env i (bs: (ty * sym) Bindings.t list) = + let rec go loc x env i (bs: (ty * sym) StringMap.t list) = (match bs with - | [] -> internal_error loc @@ "cannot resolve undeclared variable: " ^ pprint_ident x ^ "\n\n" ^ pp_locals env - | b::rest -> match Bindings.find_opt x b with + | [] -> internal_error loc @@ "cannot resolve undeclared variable: " ^ x ^ "\n\n" ^ pp_locals env + | b::rest -> match StringMap.find_opt x b with | Some v -> (Var (i,x),v) | _ -> go loc x env (i - 1) rest) let resolveGetVar (loc: l) (x: ident) = fun env -> + let x = force x in let l = List.length env.locals - 1 in match env.locals with - | b::rest -> (match Bindings.find_opt x b with + | b::rest -> (match StringMap.find_opt x b with | Some v -> (Var (l,x),v) | _ -> go loc x env (l - 1) rest) - | _ -> internal_error loc @@ "cannot resolve undeclared variable: " ^ pprint_ident x ^ "\n\n" ^ pp_locals env + | _ -> internal_error loc @@ "cannot resolve undeclared variable: " ^ x ^ "\n\n" ^ pp_locals env (** Sets a resolved variable to the given value. *) let setVar (loc: l) (x: var) (v: sym) (env: t): t = @@ -272,7 +286,7 @@ module LocalEnv = struct let Var (i,id) = x in let n = List.length env.locals - i - 1 in let locals = Utils.nth_modify ( - Bindings.update id (fun e -> match e with + StringMap.update id (fun e -> match e with | Some (t,_) -> Some (t,v) | None -> internal_error loc @@ "failed to set resolved variable: " ^ pp_var x)) n env.locals in { env with locals } @@ -325,13 +339,13 @@ module DisEnv = struct let mkUninit (t: ty): value rws = reads (uninit t) - let merge_bindings env l r: (ty * sym) Bindings.t = + let merge_bindings env l r: (ty * sym) StringMap.t = if l == r then l else - Bindings.union (fun k (t1,v1) (t2,v2) -> + StringMap.union (fun k (t1,v1) (t2,v2) -> if !debug_level > 0 && t2 <> t1 then unsupported Unknown @@ Printf.sprintf "cannot merge locals with different types: %s, %s <> %s." - (pprint_ident k) (pp_type t1) (pp_type t2); + (k) (pp_type t1) (pp_type t2); let out = Some (t1,match v1 = v2 with | false -> Val (uninit t1 env) | true -> v1) in @@ -639,7 +653,7 @@ and type_of_load (loc: l) (x: expr): ty rws = and type_access_chain (loc: l) (var: var) (ref: access_chain list): ty rws = let Var (_,id) = var in - type_of_load loc (expr_access_chain (Expr_Var id) ref) + type_of_load loc (expr_access_chain (Expr_Var (Ident id)) ref) (** Disassemble type *) and dis_type (loc: l) (t: ty): ty rws = @@ -1038,7 +1052,7 @@ and dis_lexpr_chain (loc: l) (x: lexpr) (ref: access_chain list) (r: sym): unit fun fact - the only instructions i'm aware of that can actually do this don't work anyway *) let () = (match var, ref with - | Var(0, Ident("PSTATE")), ([Field(Ident("EL" | "SP" | "nRW"))]) -> + | Var(0, ("PSTATE")), ([Field(Ident("EL" | "SP" | "nRW"))]) -> unsupported loc @@ "Update to PSTATE EL/SP/nRW while disassembling" ^ pp_lexpr x; | _, _ -> () @@ -1047,11 +1061,11 @@ and dis_lexpr_chain (loc: l) (x: lexpr) (ref: access_chain list) (r: sym): unit DisEnv.modify (LocalEnv.setVar loc var (Val vv')) | [] -> (match var with - | Var(0, Ident("InGuardedPage")) -> + | Var(0, ("InGuardedPage")) -> unsupported loc @@ "Update to InGuardedPage while disassembling" ^ pp_lexpr x; - | Var(0, Ident("SCR_EL3")) -> + | Var(0, ("SCR_EL3")) -> unsupported loc @@ "Update to SCR_EL3 while disassembling" ^ pp_lexpr x; - | Var(0, Ident("SCTLR_EL1")) -> + | Var(0, ("SCTLR_EL1")) -> unsupported loc @@ "Update to SCTLR_EL1 while disassembling" ^ pp_lexpr x; | _ -> ()); @@ -1067,8 +1081,8 @@ and dis_lexpr_chain (loc: l) (x: lexpr) (ref: access_chain list) (r: sym): unit | _::_ -> (* variable contains a symbolic expression. read, modify, then write. *) let@ Var(_,tmp) = capture_expr_mutable loc t e in - let@ () = dis_lexpr_chain loc (LExpr_Var tmp) ref r in - let@ e' = dis_expr loc (Expr_Var tmp) in + let@ () = dis_lexpr_chain loc (LExpr_Var (Ident tmp)) ref r in + let@ e' = dis_expr loc (Expr_Var (Ident tmp)) in assign_var loc var e' | [] -> assign_var loc var r @@ -1455,7 +1469,7 @@ let build_env (env: Eval.Env.t): env = let loc = Unknown in (* get the pstate, then construct a new pstate where nRW=0, EL=0 & SP=0, then set the pstate *) - let (_, pstate) = LocalEnv.getVar loc (Var(0, Ident("PSTATE"))) lenv in + let (_, pstate) = LocalEnv.getVar loc (Var(0, ("PSTATE"))) lenv in let pstate = (match pstate with | Val(pstate_v) -> let pstate_v = set_access_chain loc pstate_v [Field(Ident("EL"))] (VBits({n=2; v=Z.zero;})) in @@ -1465,12 +1479,11 @@ let build_env (env: Eval.Env.t): env = | _ -> unsupported loc @@ "Initial env value of PSTATE is not a Value"; ) in - let lenv = LocalEnv.setVar loc (Var(0, Ident("PSTATE"))) (Val(pstate)) lenv in - let lenv = LocalEnv.setVar loc (Var(0, Ident("SCR_EL3"))) (Val(VBits({n=64; v=Z.zero;}))) lenv in - let lenv = LocalEnv.setVar loc (Var(0, Ident("SCTLR_EL1"))) (Val(VBits({n=64; v=Z.zero;}))) lenv in - + let lenv = LocalEnv.setVar loc (Var(0, ("PSTATE"))) (Val(pstate)) lenv in + let lenv = LocalEnv.setVar loc (Var(0, ("SCR_EL3"))) (Val(VBits({n=64; v=Z.zero;}))) lenv in + let lenv = LocalEnv.setVar loc (Var(0, ("SCTLR_EL1"))) (Val(VBits({n=64; v=Z.zero;}))) lenv in (* set InGuardedPage to false *) - let lenv = LocalEnv.setVar loc (Var(0, Ident("InGuardedPage"))) (Val (VBool false)) lenv in + let lenv = LocalEnv.setVar loc (Var(0, ("InGuardedPage"))) (Val (VBool false)) lenv in let globals = IdentSet.of_list @@ List.map fst @@ Bindings.bindings (Eval.Env.readGlobals env) in lenv, globals @@ -1479,8 +1492,8 @@ let build_env (env: Eval.Env.t): env = Assumes variable is named _PC and its represented as a bitvector. *) let setPC (env: Eval.Env.t) (lenv,g: env) (address: Z.t): env = let loc = Unknown in - let pc = Ident "_PC" in - let width = (match Eval.Env.getVar loc env pc with + let pc = "_PC" in + let width = (match Eval.Env.getVar loc env (Ident pc) with | VUninitialized ty -> width_of_type loc ty | VBits b -> b.n | _ -> unsupported loc @@ "Initial env contains PC with unexpected type") in From 1c15244ec04b79387f77fb39ac09513f33441712 Mon Sep 17 00:00:00 2001 From: Nicholas Coughlin Date: Fri, 31 Mar 2023 11:05:13 +1000 Subject: [PATCH 5/5] Optimise monadic functions --- libASL/dis.ml | 71 ++++++++++++++++++++++++++++--------------------- libASL/monad.ml | 35 ------------------------ libASL/rws.ml | 40 ++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 66 deletions(-) diff --git a/libASL/dis.ml b/libASL/dis.ml index c5c65ecf..a34f832a 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -324,9 +324,9 @@ module DisEnv = struct open Let - let getVar (loc: l) (x: ident): (ty * sym) rws = - let+ (_,v) = gets (LocalEnv.resolveGetVar loc x) in - v + let getVar (loc: l) (x: ident): (ty * sym) rws = fun env s -> + let (_,v) = LocalEnv.resolveGetVar loc x s in + (v,s,empty) let uninit (t: ty) (env: Eval.Env.t): value = try @@ -351,27 +351,26 @@ module DisEnv = struct | true -> v1) in out) l r - let join_locals (l: LocalEnv.t) (r: LocalEnv.t): unit rws = - let* env = read in + let join_locals (l: LocalEnv.t) (r: LocalEnv.t): unit rws = fun env s -> assert (l.returnSymbols = r.returnSymbols); assert (l.indent = r.indent); assert (l.trace = r.trace); let locals' = List.map2 (merge_bindings env) l.locals r.locals in - put { + let s : LocalEnv.t = { locals = locals'; returnSymbols = l.returnSymbols; numSymbols = max l.numSymbols r.numSymbols; indent = l.indent; trace = l.trace; - } - + } in + ((),s,empty) let getFun (loc: l) (x: ident): Eval.fun_sig option rws = reads (fun env -> Eval.Env.getFunOpt loc env x) - let nextVarName (prefix: string): ident rws = - let+ num = stateful LocalEnv.incNumSymbols in - Ident (prefix ^ string_of_int num) + let nextVarName (prefix: string): ident rws = fun env s -> + let num, s = LocalEnv.incNumSymbols s in + (Ident (prefix ^ string_of_int num),s,empty) let indent: string rws = let+ i = gets (fun l -> l.indent) in @@ -436,13 +435,24 @@ end type 'a rws = 'a DisEnv.rws -let (let@) = DisEnv.Let.(let*) -let (and@) = DisEnv.Let.(and*) -let (let+) = DisEnv.Let.(let+) -let (and+) = DisEnv.Let.(and+) +let (let@) x f = fun env s -> + let (r,s,w) = x env s in + let (r',s,w') = (f r) env s in + (r',s,append w w') + +let (let+) x f = fun env s -> + let (r,s,w) = x env s in + (f r,s,w) -let (>>) = DisEnv.(>>) -let (>>=) = DisEnv.(>>=) +let (>>) x f = fun env s -> + let (_,s,w) = x env s in + let (r',s,w') = f env s in + (r',s,append w w') + +let (>>=) x f = fun env s -> + let (r,s,w) = x env s in + let (r',s,w') = (f r) env s in + (r',s,append w w') (** Convert value to a simple expression containing that value, so we can print it or use it symbolically *) @@ -462,7 +472,6 @@ let is_expr (v: sym): bool = | Exp _ -> true let declare_var (loc: l) (t: ty) (i: ident): var rws = - let@ env = DisEnv.read in let@ uninit = DisEnv.mkUninit t in let@ var = DisEnv.stateful (LocalEnv.addLocalVar loc i (Val uninit) t) in @@ -694,8 +703,8 @@ and dis_pattern (loc: l) (v: sym) (x: AST.pattern): sym rws = let+ v' = dis_expr loc e in sym_eq loc v v' | Pat_Range(lo, hi) -> - let+ lo' = dis_expr loc lo - and+ hi' = dis_expr loc hi in + let@ lo' = dis_expr loc lo in + let+ hi' = dis_expr loc hi in sym_and_bool loc (sym_le_int loc lo' v) (sym_le_int loc v hi') ) @@ -706,13 +715,13 @@ and dis_slice (loc: l) (x: slice): (sym * sym) rws = let+ i' = dis_expr loc i in (i', Val (VInt Z.one)) | Slice_HiLo(hi, lo) -> - let+ hi' = dis_expr loc hi - and+ lo' = dis_expr loc lo in + let@ hi' = dis_expr loc hi in + let+ lo' = dis_expr loc lo in let wd' = sym_add_int loc (sym_sub_int loc hi' lo') (Val (VInt Z.one)) in (lo', wd') | Slice_LoWd(lo, wd) -> - let+ lo' = dis_expr loc lo - and+ wd' = dis_expr loc wd in + let@ lo' = dis_expr loc lo in + let+ wd' = dis_expr loc wd in (lo', wd') ) @@ -796,10 +805,10 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws = let+ vs = DisEnv.traverse (fun f -> dis_load loc (Expr_Field(e,f))) fs in sym_concat loc vs | Expr_Slices(e, ss) -> - let@ e' = dis_expr loc e - and@ ss' = DisEnv.traverse (dis_slice loc) ss in + let@ e' = dis_expr loc e in + let+ ss' = DisEnv.traverse (dis_slice loc) ss in let vs = List.map (fun (i,w) -> sym_extract_bits loc e' i w) ss' in - DisEnv.pure (sym_concat loc vs) + sym_concat loc vs | Expr_In(e, p) -> let@ e' = dis_expr loc e in let@ p' = dis_pattern loc e' p in @@ -907,7 +916,7 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws (match rty with | Some rty -> let@ () = DisEnv.modify LocalEnv.addLevel in - let@ () = DisEnv.sequence_ @@ List.map2 (fun arg e -> + let@ () = DisEnv.traverse2_ (fun arg e -> declare_const loc type_integer arg e ) targs tes in @@ -930,7 +939,7 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws assert (List.length targs == List.length tes); (* Assign targs := tes *) - let@ () = DisEnv.sequence_ @@ List.map2 (fun arg e -> + let@ () = DisEnv.traverse2_ (fun arg e -> declare_const loc type_integer arg e ) targs tes in @@ -938,10 +947,10 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws assert (List.length atys == List.length es); (* Assign args := es *) - let@ () = DisEnv.sequence_ (Utils.map3 (fun (ty, _) arg e -> + let@ () = DisEnv.traverse3_ (fun (ty, _) arg e -> let@ ty' = dis_type loc ty in declare_const loc ty' arg e - ) atys args es) in + ) atys args es in (* Create return variable (if necessary). This is in the inner scope to allow for type parameters. *) diff --git a/libASL/monad.ml b/libASL/monad.ml index 5a58f14c..8106fbf7 100644 --- a/libASL/monad.ml +++ b/libASL/monad.ml @@ -42,41 +42,6 @@ module Make (M : S) = struct let (and*) = (and+) end - open Let - - (* higher-order functions and transformations *) - - (** Performs a list of computations in sequence, resulting in a list - of their results. *) - let rec sequence (xs: 'a m list): 'a list m = - match xs with - | (x::xs) -> - let+ x = x - and+ xs = sequence xs in - (x :: xs) - | [] -> pure [] - - (** Performs a list of computations in sequence and discard their results - (but retains their monad effects). *) - let sequence_ (xs : 'a m list): unit m = - let+ _ = sequence xs in () - - (** Uses the given function to create a list of computations which are - then run sequentially. Results in a list of their results. *) - let traverse (f: 'a -> 'b m) (x: 'a list): 'b list m = - sequence (List.map f x) - - let traverse2 (f: 'a -> 'b -> 'c m) (x: 'a list) (y: 'b list): 'c list m = - sequence (List.map2 f x y) - - (** Uses the given function to create a list of computations which are - then run sequentually. Discards their results. *) - let traverse_ (f: 'a -> 'b m) (x: 'a list): unit m = - let+ _ = sequence (List.map f x) in () - - let traverse2_ (f: 'a -> 'b -> 'c m) (x: 'a list) (y: 'b list): unit m = - let+ _ = sequence (List.map2 f x y) in () - (** A nil computation. Does nothing and returns nothing of interest. *) let unit: unit m = pure () diff --git a/libASL/rws.ml b/libASL/rws.ml index c81351d4..2a0766af 100644 --- a/libASL/rws.ml +++ b/libASL/rws.ml @@ -122,6 +122,46 @@ module RWSBase (T : S) = struct let bt = Printexc.get_raw_backtrace () in (Error (e, bt), s, mempty) + let rec traverse (f: 'a -> 'b rws) (x: 'a list) (r: r) (s: s) = + match x with + | [] -> ([],s,mempty) + | x::xs -> + let (i,s,w) = (f x) r s in + let (is,s,w') = traverse f xs r s in + (i::is,s,mappend w w') + + let rec traverse_r (w: w) (f: 'a -> 'b rws) (x: 'a list) (r: r) (s: s) = + match x with + | [] -> ((),s,w) + | x::xs -> + let (_,s',w') = (f x) r s in + traverse_r (mappend w w') f xs r s' + + let traverse_ (f: 'a -> 'b rws) (x: 'a list): unit rws = + traverse_r mempty f x + + let rec traverse2_r (w: w) (f: 'a -> 'b -> 'c rws) (x: 'a list) (y: 'b list) (r: r) (s: s) = + match x, y with + | [], [] -> ((),s,w) + | x::xs, y::ys -> + let (_,s',w') = (f x y) r s in + traverse2_r (mappend w w') f xs ys r s' + | _ -> invalid_arg "traverse2_" + + let traverse2_ (f: 'a -> 'b -> 'c rws) (x: 'a list) (y: 'b list): unit rws = + traverse2_r mempty f x y + + let rec traverse3_r (w: w) (f: 'a -> 'b -> 'c -> 'd rws) (x: 'a list) (y: 'b list) (z: 'c list) (r: r) (s: s) = + match x, y, z with + | [], [], [] -> ((),s,w) + | x::xs, y::ys, z::zs -> + let (_,s',w') = (f x y z) r s in + traverse3_r (mappend w w') f xs ys zs r s' + | _ -> invalid_arg "traverse3_" + + let traverse3_ (f: 'a -> 'b -> 'c -> 'd rws) (x: 'a list) (y: 'b list) (z: 'c list): unit rws = + traverse3_r mempty f x y z + end (** Constructs a RWS monad using the given signature. *)