diff --git a/libASL/dis.ml b/libASL/dis.ml index 86837da7..9514ebca 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -279,13 +279,33 @@ module LocalEnv = struct end +type tree = + Node of stmt list + | Branch of tree * tree + +let empty = Node [] +let single x = Node x +let append x y = + match x, y with + | Node [], _ -> y + | _, Node [] -> x + | Node [x], Node y -> Node (x::y) + | _ -> Branch (x, y) + +let rec flatten x acc = + match x with + | Branch (x,y) -> + let acc = flatten y acc in + flatten x acc + | Node i -> i@acc + module DisEnv = struct include Rws.Make(struct type r = Eval.Env.t - type w = stmt list + type w = tree type s = LocalEnv.t - let mempty = [] - let mappend = (@) + let mempty = empty + let mappend = (append) end) open Let @@ -362,6 +382,8 @@ module DisEnv = struct let warn s = debug 0 ("WARNING: " ^ s) + let write i = fun e x -> ((),x,Node i) + let scope (loc: l) (name: string) (arg: string) (pp: 'a -> string) (x: 'a rws): 'a rws = (* logging header. looks like: +- dis_expr --> 1 + 1. *) log (Printf.sprintf "\u{256d}\u{2500} %s --> %s" name arg) >> @@ -373,6 +395,7 @@ module DisEnv = struct (* run computation but obtain state and writer to output in debugging. *) let* (result,s',w') = locally (catcherror x) in + let w' = flatten w' [] in let x' = (match result with | Error ((DisTrace _) as e, bt) -> raise e | Error (exn, bt) -> Printexc.raise_with_backtrace (DisTrace (trace, exn)) bt @@ -394,6 +417,7 @@ module DisEnv = struct else unit in pure x' + end type 'a rws = 'a DisEnv.rws @@ -518,12 +542,14 @@ let rec sym_if (loc: l) (t: ty) (test: sym rws) (tcase: sym rws) (fcase: sym rws (* Evaluate true branch statements. *) let@ (tenv,tstmts) = DisEnv.locally_ (tcase >>= assign_var loc tmp) in + let tstmts = flatten tstmts [] in (* Propagate incremented counter to env'. *) let@ env' = DisEnv.gets (fun env -> LocalEnv.sequence_merge env tenv) in (* Execute false branch statements with env'. *) let@ (fenv,fstmts) = DisEnv.locally_ (DisEnv.put env' >> fcase >>= assign_var loc tmp) in let@ () = DisEnv.join_locals tenv fenv in + let fstmts = flatten fstmts [] in match ite_body tstmts tmp, ite_body fstmts tmp with | Some (Val _ as te), Some fe | Some te, Some (Val _ as fe) -> @@ -545,9 +571,11 @@ and unit_if (loc: l) (test: sym rws) (tcase: unit rws) (fcase: unit rws): unit r | Val _ -> failwith ("Split on non-boolean value") | Exp e -> let@ (tenv,tstmts) = DisEnv.locally_ tcase in + let tstmts = flatten tstmts [] in let@ env' = DisEnv.gets (fun env -> LocalEnv.sequence_merge env tenv) in let@ (fenv,fstmts) = DisEnv.locally_ (DisEnv.put env' >> fcase) in + let fstmts = flatten fstmts [] in let@ () = DisEnv.join_locals tenv fenv in DisEnv.write [Stmt_If(e, tstmts, [], fstmts, loc)]) @@ -1373,6 +1401,7 @@ and dis_decode_alt' (loc: AST.l) (DecoderAlt_Alt (ps, b)) (vs: value list) (op: let@ () = dis_stmts exec in DisEnv.modify (LocalEnv.popLevel) ) in + let stmts = flatten stmts [] in if !debug_level >= 2 then begin Printf.printf "-----------\n"; @@ -1404,6 +1433,7 @@ type env = (LocalEnv.t * IdentSet.t) let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: value): stmt list = let DecoderCase_Case (_,_,loc) = decode in let ((),lenv',stmts) = (dis_decode_case loc decode op) env lenv in + let stmts = flatten stmts [] in let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts in let stmts' = Transforms.RedundantSlice.do_transform stmts' in let stmts' = Transforms.StatefulIntToBits.run stmts' in