From aceb49cb068cabd512ec40696ced1c0e6beed7fb Mon Sep 17 00:00:00 2001 From: Nicholas Coughlin Date: Fri, 31 Mar 2023 11:05:13 +1000 Subject: [PATCH] Optimise monadic functions --- libASL/dis.ml | 71 ++++++++++++++++++++++++++++--------------------- libASL/monad.ml | 35 ------------------------ libASL/rws.ml | 40 ++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 66 deletions(-) diff --git a/libASL/dis.ml b/libASL/dis.ml index d851b4a7..fd8a6db2 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -324,9 +324,9 @@ module DisEnv = struct open Let - let getVar (loc: l) (x: ident): (ty * sym) rws = - let+ (_,v) = gets (LocalEnv.resolveGetVar loc x) in - v + let getVar (loc: l) (x: ident): (ty * sym) rws = fun env s -> + let (_,v) = LocalEnv.resolveGetVar loc x s in + (v,s,empty) let uninit (t: ty) (env: Eval.Env.t): value = try @@ -351,27 +351,26 @@ module DisEnv = struct | true -> v1) in out) l r - let join_locals (l: LocalEnv.t) (r: LocalEnv.t): unit rws = - let* env = read in + let join_locals (l: LocalEnv.t) (r: LocalEnv.t): unit rws = fun env s -> assert (l.returnSymbols = r.returnSymbols); assert (l.indent = r.indent); assert (l.trace = r.trace); let locals' = List.map2 (merge_bindings env) l.locals r.locals in - put { + let s : LocalEnv.t = { locals = locals'; returnSymbols = l.returnSymbols; numSymbols = max l.numSymbols r.numSymbols; indent = l.indent; trace = l.trace; - } - + } in + ((),s,empty) let getFun (loc: l) (x: ident): Eval.fun_sig option rws = reads (fun env -> Eval.Env.getFunOpt loc env x) - let nextVarName (prefix: string): ident rws = - let+ num = stateful LocalEnv.incNumSymbols in - Ident (prefix ^ string_of_int num) + let nextVarName (prefix: string): ident rws = fun env s -> + let num, s = LocalEnv.incNumSymbols s in + (Ident (prefix ^ string_of_int num),s,empty) let indent: string rws = let+ i = gets (fun l -> l.indent) in @@ -436,13 +435,24 @@ end type 'a rws = 'a DisEnv.rws -let (let@) = DisEnv.Let.(let*) -let (and@) = DisEnv.Let.(and*) -let (let+) = DisEnv.Let.(let+) -let (and+) = DisEnv.Let.(and+) +let (let@) x f = fun env s -> + let (r,s,w) = x env s in + let (r',s,w') = (f r) env s in + (r',s,append w w') + +let (let+) x f = fun env s -> + let (r,s,w) = x env s in + (f r,s,w) -let (>>) = DisEnv.(>>) -let (>>=) = DisEnv.(>>=) +let (>>) x f = fun env s -> + let (_,s,w) = x env s in + let (r',s,w') = f env s in + (r',s,append w w') + +let (>>=) x f = fun env s -> + let (r,s,w) = x env s in + let (r',s,w') = (f r) env s in + (r',s,append w w') (** Convert value to a simple expression containing that value, so we can print it or use it symbolically *) @@ -462,7 +472,6 @@ let is_expr (v: sym): bool = | Exp _ -> true let declare_var (loc: l) (t: ty) (i: ident): var rws = - let@ env = DisEnv.read in let@ uninit = DisEnv.mkUninit t in let@ var = DisEnv.stateful (LocalEnv.addLocalVar loc i (Val uninit) t) in @@ -694,8 +703,8 @@ and dis_pattern (loc: l) (v: sym) (x: AST.pattern): sym rws = let+ v' = dis_expr loc e in sym_eq loc v v' | Pat_Range(lo, hi) -> - let+ lo' = dis_expr loc lo - and+ hi' = dis_expr loc hi in + let@ lo' = dis_expr loc lo in + let+ hi' = dis_expr loc hi in sym_and_bool loc (sym_le_int loc lo' v) (sym_le_int loc v hi') ) @@ -706,13 +715,13 @@ and dis_slice (loc: l) (x: slice): (sym * sym) rws = let+ i' = dis_expr loc i in (i', Val (VInt Z.one)) | Slice_HiLo(hi, lo) -> - let+ hi' = dis_expr loc hi - and+ lo' = dis_expr loc lo in + let@ hi' = dis_expr loc hi in + let+ lo' = dis_expr loc lo in let wd' = sym_add_int loc (sym_sub_int loc hi' lo') (Val (VInt Z.one)) in (lo', wd') | Slice_LoWd(lo, wd) -> - let+ lo' = dis_expr loc lo - and+ wd' = dis_expr loc wd in + let@ lo' = dis_expr loc lo in + let+ wd' = dis_expr loc wd in (lo', wd') ) @@ -796,10 +805,10 @@ and dis_expr' (loc: l) (x: AST.expr): sym rws = let+ vs = DisEnv.traverse (fun f -> dis_load loc (Expr_Field(e,f))) fs in sym_concat loc vs | Expr_Slices(e, ss) -> - let@ e' = dis_expr loc e - and@ ss' = DisEnv.traverse (dis_slice loc) ss in + let@ e' = dis_expr loc e in + let+ ss' = DisEnv.traverse (dis_slice loc) ss in let vs = List.map (fun (i,w) -> sym_extract_bits loc e' i w) ss' in - DisEnv.pure (sym_concat loc vs) + sym_concat loc vs | Expr_In(e, p) -> let@ e' = dis_expr loc e in let@ p' = dis_pattern loc e' p in @@ -907,7 +916,7 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws (match rty with | Some rty -> let@ () = DisEnv.modify LocalEnv.addLevel in - let@ () = DisEnv.sequence_ @@ List.map2 (fun arg e -> + let@ () = DisEnv.traverse2_ (fun arg e -> declare_const loc type_integer arg e ) targs tes in @@ -930,7 +939,7 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws assert (List.length targs == List.length tes); (* Assign targs := tes *) - let@ () = DisEnv.sequence_ @@ List.map2 (fun arg e -> + let@ () = DisEnv.traverse2_ (fun arg e -> declare_const loc type_integer arg e ) targs tes in @@ -938,10 +947,10 @@ and dis_call' (loc: l) (f: ident) (tes: sym list) (es: sym list): sym option rws assert (List.length atys == List.length es); (* Assign args := es *) - let@ () = DisEnv.sequence_ (Utils.map3 (fun (ty, _) arg e -> + let@ () = DisEnv.traverse3_ (fun (ty, _) arg e -> let@ ty' = dis_type loc ty in declare_const loc ty' arg e - ) atys args es) in + ) atys args es in (* Create return variable (if necessary). This is in the inner scope to allow for type parameters. *) diff --git a/libASL/monad.ml b/libASL/monad.ml index 5a58f14c..8106fbf7 100644 --- a/libASL/monad.ml +++ b/libASL/monad.ml @@ -42,41 +42,6 @@ module Make (M : S) = struct let (and*) = (and+) end - open Let - - (* higher-order functions and transformations *) - - (** Performs a list of computations in sequence, resulting in a list - of their results. *) - let rec sequence (xs: 'a m list): 'a list m = - match xs with - | (x::xs) -> - let+ x = x - and+ xs = sequence xs in - (x :: xs) - | [] -> pure [] - - (** Performs a list of computations in sequence and discard their results - (but retains their monad effects). *) - let sequence_ (xs : 'a m list): unit m = - let+ _ = sequence xs in () - - (** Uses the given function to create a list of computations which are - then run sequentially. Results in a list of their results. *) - let traverse (f: 'a -> 'b m) (x: 'a list): 'b list m = - sequence (List.map f x) - - let traverse2 (f: 'a -> 'b -> 'c m) (x: 'a list) (y: 'b list): 'c list m = - sequence (List.map2 f x y) - - (** Uses the given function to create a list of computations which are - then run sequentually. Discards their results. *) - let traverse_ (f: 'a -> 'b m) (x: 'a list): unit m = - let+ _ = sequence (List.map f x) in () - - let traverse2_ (f: 'a -> 'b -> 'c m) (x: 'a list) (y: 'b list): unit m = - let+ _ = sequence (List.map2 f x y) in () - (** A nil computation. Does nothing and returns nothing of interest. *) let unit: unit m = pure () diff --git a/libASL/rws.ml b/libASL/rws.ml index c81351d4..2a0766af 100644 --- a/libASL/rws.ml +++ b/libASL/rws.ml @@ -122,6 +122,46 @@ module RWSBase (T : S) = struct let bt = Printexc.get_raw_backtrace () in (Error (e, bt), s, mempty) + let rec traverse (f: 'a -> 'b rws) (x: 'a list) (r: r) (s: s) = + match x with + | [] -> ([],s,mempty) + | x::xs -> + let (i,s,w) = (f x) r s in + let (is,s,w') = traverse f xs r s in + (i::is,s,mappend w w') + + let rec traverse_r (w: w) (f: 'a -> 'b rws) (x: 'a list) (r: r) (s: s) = + match x with + | [] -> ((),s,w) + | x::xs -> + let (_,s',w') = (f x) r s in + traverse_r (mappend w w') f xs r s' + + let traverse_ (f: 'a -> 'b rws) (x: 'a list): unit rws = + traverse_r mempty f x + + let rec traverse2_r (w: w) (f: 'a -> 'b -> 'c rws) (x: 'a list) (y: 'b list) (r: r) (s: s) = + match x, y with + | [], [] -> ((),s,w) + | x::xs, y::ys -> + let (_,s',w') = (f x y) r s in + traverse2_r (mappend w w') f xs ys r s' + | _ -> invalid_arg "traverse2_" + + let traverse2_ (f: 'a -> 'b -> 'c rws) (x: 'a list) (y: 'b list): unit rws = + traverse2_r mempty f x y + + let rec traverse3_r (w: w) (f: 'a -> 'b -> 'c -> 'd rws) (x: 'a list) (y: 'b list) (z: 'c list) (r: r) (s: s) = + match x, y, z with + | [], [], [] -> ((),s,w) + | x::xs, y::ys, z::zs -> + let (_,s',w') = (f x y z) r s in + traverse3_r (mappend w w') f xs ys zs r s' + | _ -> invalid_arg "traverse3_" + + let traverse3_ (f: 'a -> 'b -> 'c -> 'd rws) (x: 'a list) (y: 'b list) (z: 'c list): unit rws = + traverse3_r mempty f x y z + end (** Constructs a RWS monad using the given signature. *)