Skip to content

Commit

Permalink
Add tree structure for writer
Browse files Browse the repository at this point in the history
Avoid repeated appends by building a tree
and flattening later
  • Loading branch information
ncough authored and katrinafyi committed Jan 16, 2024
1 parent 999efb2 commit 3fc7aee
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions libASL/dis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,33 @@ module LocalEnv = struct

end

type tree =
Node of stmt list
| Branch of tree * tree

let empty = Node []
let single x = Node x
let append x y =
match x, y with
| Node [], _ -> y
| _, Node [] -> x
| Node [x], Node y -> Node (x::y)
| _ -> Branch (x, y)

let rec flatten x acc =
match x with
| Branch (x,y) ->
let acc = flatten y acc in
flatten x acc
| Node i -> i@acc

module DisEnv = struct
include Rws.Make(struct
type r = Eval.Env.t
type w = stmt list
type w = tree
type s = LocalEnv.t
let mempty = []
let mappend = (@)
let mempty = empty
let mappend = (append)
end)

open Let
Expand Down Expand Up @@ -362,6 +382,8 @@ module DisEnv = struct

let warn s = debug 0 ("WARNING: " ^ s)

let write i = fun e x -> ((),x,Node i)

let scope (loc: l) (name: string) (arg: string) (pp: 'a -> string) (x: 'a rws): 'a rws =
(* logging header. looks like: +- dis_expr --> 1 + 1. *)
log (Printf.sprintf "\u{256d}\u{2500} %s --> %s" name arg) >>
Expand All @@ -373,6 +395,7 @@ module DisEnv = struct

(* run computation but obtain state and writer to output in debugging. *)
let* (result,s',w') = locally (catcherror x) in
let w' = flatten w' [] in
let x' = (match result with
| Error ((DisTrace _) as e, bt) -> raise e
| Error (exn, bt) -> Printexc.raise_with_backtrace (DisTrace (trace, exn)) bt
Expand All @@ -394,6 +417,7 @@ module DisEnv = struct
else unit
in
pure x'

end

type 'a rws = 'a DisEnv.rws
Expand Down Expand Up @@ -518,12 +542,14 @@ let rec sym_if (loc: l) (t: ty) (test: sym rws) (tcase: sym rws) (fcase: sym rws
(* Evaluate true branch statements. *)
let@ (tenv,tstmts) = DisEnv.locally_
(tcase >>= assign_var loc tmp) in
let tstmts = flatten tstmts [] in
(* Propagate incremented counter to env'. *)
let@ env' = DisEnv.gets (fun env -> LocalEnv.sequence_merge env tenv) in
(* Execute false branch statements with env'. *)
let@ (fenv,fstmts) = DisEnv.locally_
(DisEnv.put env' >> fcase >>= assign_var loc tmp) in
let@ () = DisEnv.join_locals tenv fenv in
let fstmts = flatten fstmts [] in
match ite_body tstmts tmp, ite_body fstmts tmp with
| Some (Val _ as te), Some fe
| Some te, Some (Val _ as fe) ->
Expand All @@ -545,9 +571,11 @@ and unit_if (loc: l) (test: sym rws) (tcase: unit rws) (fcase: unit rws): unit r
| Val _ -> failwith ("Split on non-boolean value")
| Exp e ->
let@ (tenv,tstmts) = DisEnv.locally_ tcase in
let tstmts = flatten tstmts [] in

let@ env' = DisEnv.gets (fun env -> LocalEnv.sequence_merge env tenv) in
let@ (fenv,fstmts) = DisEnv.locally_ (DisEnv.put env' >> fcase) in
let fstmts = flatten fstmts [] in

let@ () = DisEnv.join_locals tenv fenv in
DisEnv.write [Stmt_If(e, tstmts, [], fstmts, loc)])
Expand Down Expand Up @@ -1373,6 +1401,7 @@ and dis_decode_alt' (loc: AST.l) (DecoderAlt_Alt (ps, b)) (vs: value list) (op:
let@ () = dis_stmts exec in
DisEnv.modify (LocalEnv.popLevel)
) in
let stmts = flatten stmts [] in

if !debug_level >= 2 then begin
Printf.printf "-----------\n";
Expand Down Expand Up @@ -1404,6 +1433,7 @@ type env = (LocalEnv.t * IdentSet.t)
let dis_decode_entry (env: Eval.Env.t) ((lenv,globals): env) (decode: decode_case) (op: value): stmt list =
let DecoderCase_Case (_,_,loc) = decode in
let ((),lenv',stmts) = (dis_decode_case loc decode op) env lenv in
let stmts = flatten stmts [] in
let stmts' = Transforms.RemoveUnused.remove_unused globals @@ stmts in
let stmts' = Transforms.RedundantSlice.do_transform stmts' in
let stmts' = Transforms.StatefulIntToBits.run stmts' in
Expand Down

0 comments on commit 3fc7aee

Please sign in to comment.