From 4899c72c1b3004c547f9db3ba6ce65d02cd490e5 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 22 May 2024 13:47:51 +1000 Subject: [PATCH] Merged backends (#80) * Add scala backend for BASIL * Add C++ Backend for LLVM * add aslBackwardsVisitor * aslVisitor: change vstmt to return stmt list BREAKING! This change affects the signature of the Asl_visitor.visit_stmt method. For compatibility, a visit_stmt_single method is provided with equivalent behaviour to the old visit_stmt. There is also an added helper function to convert visitActions on single statements to visitActions on a list of statements. Both of these compatibility helpers WILL THROW if used with a visitor that returns non-singleton statement lists. This gives the user the flexibility to insert new statements or delete statements entirely. On the other hand, post-functions in ChangeDoChildrenPost will need to handle lists of functions as well. This follows the original CIL visitor: https://people.eecs.berkeley.edu/~necula/cil/api/Cil.cilVisitor.html * fix backwards visitor and rearrange code it is no longer a good idea for the backwards and forwards visitors to have a subtyping relation. * support -x 0 to print encoding name. (#78) this is very useful when looking for the name of an encoding, without cluttering the output with the disassembly trace. the default debug_level has been lowered to -1 to support -x 0 as a non-default level. we cannot print by default since that would clutter stdout when used as a library. Co-authored-by: rina --- .gitignore | 3 + bin/asli.ml | 5 +- libASL/asl_utils.ml | 6 +- libASL/asl_visitor.ml | 648 +++++++++++++---------- libASL/cpp_backend.ml | 553 +++++++++++++++++++ libASL/cpu.ml | 5 +- libASL/cpu.mli | 1 + libASL/dis.ml | 4 +- libASL/dune | 6 +- libASL/offline_opt.ml | 4 +- libASL/scala_backend.ml | 787 ++++++++++++++++++++++++++++ libASL/symbolic_lifter.ml | 6 +- libASL/transforms.ml | 23 +- libASL/utils.ml | 9 + offlineASL-cpp/.gitignore | 3 + offlineASL-cpp/aslp_lifter.hpp | 1 + offlineASL-cpp/aslp_lifter_impl.hpp | 1 + offlineASL-cpp/build.sh | 19 + offlineASL-cpp/dune | 6 + offlineASL-cpp/interface.hpp | 168 ++++++ scalaOfflineASL/utils.scala | 505 ++++++++++++++++++ 21 files changed, 2464 insertions(+), 299 deletions(-) create mode 100644 libASL/cpp_backend.ml create mode 100644 libASL/scala_backend.ml create mode 100644 offlineASL-cpp/.gitignore create mode 100644 offlineASL-cpp/aslp_lifter.hpp create mode 100644 offlineASL-cpp/aslp_lifter_impl.hpp create mode 100755 offlineASL-cpp/build.sh create mode 100644 offlineASL-cpp/dune create mode 100644 offlineASL-cpp/interface.hpp create mode 100644 scalaOfflineASL/utils.scala diff --git a/.gitignore b/.gitignore index 184db933..b9334234 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ scripts/result.txt scripts/output.txt scripts/cntlm.bir scripts/total.txt + +.cache/ +build/ diff --git a/bin/asli.ml b/bin/asli.ml index e6e38437..0f01751b 100644 --- a/bin/asli.ml +++ b/bin/asli.ml @@ -26,7 +26,7 @@ let opt_no_default_aarch64 = ref false let opt_print_aarch64_dir = ref false let opt_verbose = ref false -let opt_debug_level = ref 0 +let opt_debug_level = ref (-1) let () = Printexc.register_printer @@ -60,6 +60,7 @@ let help_msg = [ let gen_backends = [ ("ocaml", (Cpu.Ocaml, "offlineASL")); ("cpp", (Cpu.Cpp, "offlineASL-cpp")); + ("scala", (Cpu.Scala, "offlineASL-cpp")); ] let flags = [ @@ -325,7 +326,7 @@ let rec repl (tcenv: TC.Env.t) (cpu: Cpu.cpu): unit = ) let options = Arg.align ([ - ( "-x", Arg.Set_int opt_debug_level, " Debugging output"); + ( "-x", Arg.Set_int opt_debug_level, " Partial evaluation debugging (requires debug level argument >= 0)"); ( "-v", Arg.Set opt_verbose, " Verbose output"); ( "--no-aarch64", Arg.Set opt_no_default_aarch64 , " Disable bundled AArch64 semantics"); ( "--aarch64-dir", Arg.Set opt_print_aarch64_dir, " Print directory of bundled AArch64 semantics"); diff --git a/libASL/asl_utils.ml b/libASL/asl_utils.ml index e3f335c9..a282ebc1 100644 --- a/libASL/asl_utils.ml +++ b/libASL/asl_utils.ml @@ -226,7 +226,7 @@ let fv_stmts stmts = let fv_stmt stmt = let fvs = new freevarClass in - ignore (visit_stmt (fvs :> aslVisitor) stmt); + ignore (visit_stmt_single (fvs :> aslVisitor) stmt); fvs#result let fv_decl decl = @@ -295,7 +295,7 @@ end let locals_of_stmts stmts = let lc = new localsClass in - ignore (Visitor.mapNoCopy (visit_stmt (lc :> aslVisitor)) stmts); + ignore @@ Asl_visitor.visit_stmts lc stmts; lc#locals let locals_of_decl decl = @@ -423,7 +423,7 @@ let subst_type (s: expr Bindings.t) (x: ty): ty = let subst_stmt (s: expr Bindings.t) (x: stmt): stmt = let subst = new substClass s in - visit_stmt subst x + visit_stmt_single subst x (** More flexible substitution class - takes a function instead diff --git a/libASL/asl_visitor.ml b/libASL/asl_visitor.ml index 3d9711a2..efbf5481 100644 --- a/libASL/asl_visitor.ml +++ b/libASL/asl_visitor.ml @@ -32,7 +32,7 @@ class type aslVisitor = object method vtype : ty -> ty visitAction method vlvar : ident -> ident visitAction method vlexpr : lexpr -> lexpr visitAction - method vstmt : stmt -> stmt visitAction + method vstmt : stmt -> stmt list visitAction method vs_elsif : s_elsif -> s_elsif visitAction method valt : alt -> alt visitAction method vcatcher : catcher -> catcher visitAction @@ -49,6 +49,20 @@ class type aslVisitor = object method leave_scope : unit -> unit end +(** Converts a visitAction on single values to an action on lists. + The generated visitAction will throw if given a non-singleton list. *) +let singletonVisitAction (a: 'a visitAction) : 'a list visitAction = + let listpost post : 'a list -> 'a list = function + | [x] -> [post x] + | xs -> + let len = string_of_int @@ List.length xs in + failwith @@ "this ChangeDoChildrenPost handles single values only, but was given a list of " ^ len ^ " items" + in match a with + | ChangeTo x -> ChangeTo [x] + | ChangeDoChildrenPost(x, post) -> ChangeDoChildrenPost([x], listpost post) + | DoChildren -> DoChildren + | SkipChildren -> SkipChildren + (****************************************************************) (** {2 ASL visitor functions} *) @@ -65,55 +79,71 @@ end different. *) -let rec visit_exprs (vis: #aslVisitor) (xs: expr list): expr list = - mapNoCopy (visit_expr vis) xs +let arg_of_sformal (sf: sformal): (ty * ident) = + match sf with + | Formal_In (ty, id) + | Formal_InOut (ty, id) -> (ty, id) + +let arg_of_ifield (IField_Field (id, _, wd)): (ty * ident) = + (Type_Bits (Expr_LitInt (string_of_int wd)), id) + +let args_of_encoding (Encoding_Block (_, _, fs, _, _, _, _, _)): (ty * ident) list = + List.map arg_of_ifield fs + +(** a base class for treeVisitors transforming the AST. + the method visit_stmts is left abstract for subclasses + to implement. *) +class virtual aslTreeVisitor (vis: #aslVisitor) = object(self) - and visit_var (vis: #aslVisitor) (x: ident): ident = + method visit_exprs (xs: expr list): expr list = + mapNoCopy (self#visit_expr) xs + + method visit_var (x: ident): ident = let aux (_: #aslVisitor) (x: ident): ident = x in doVisit vis (vis#vvar x) aux x - and visit_lvar (vis: #aslVisitor) (x: ident): ident = + method visit_lvar (x: ident): ident = let aux (_: #aslVisitor) (x: ident): ident = x in doVisit vis (vis#vlvar x) aux x - and visit_e_elsif (vis: #aslVisitor) (x: e_elsif): e_elsif = - let aux (vis: #aslVisitor) (x: e_elsif): e_elsif = + method visit_e_elsif (x: e_elsif): e_elsif = + let aux (_: #aslVisitor) (x: e_elsif): e_elsif = (match x with | E_Elsif_Cond(c, e) -> - let c' = visit_expr vis c in - let e' = visit_expr vis e in + let c' = self#visit_expr c in + let e' = self#visit_expr e in if c == c' && e == e' then x else E_Elsif_Cond(c', e') ) in doVisit vis (vis#ve_elsif x) aux x - and visit_slice (vis: #aslVisitor) (x: slice): slice = - let aux (vis: #aslVisitor) (x: slice): slice = + method visit_slice (x: slice): slice = + let aux (_: #aslVisitor) (x: slice): slice = (match x with | Slice_Single(e) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Slice_Single e' | Slice_HiLo(hi, lo) -> - let hi' = visit_expr vis hi in - let lo' = visit_expr vis lo in + let hi' = self#visit_expr hi in + let lo' = self#visit_expr lo in if hi == hi' && lo == lo' then x else Slice_HiLo(hi', lo') | Slice_LoWd(lo, wd) -> - let lo' = visit_expr vis lo in - let wd' = visit_expr vis wd in + let lo' = self#visit_expr lo in + let wd' = self#visit_expr wd in if lo == lo' && wd == wd' then x else Slice_LoWd(lo', wd') ) in doVisit vis (vis#vslice x) aux x - and visit_patterns (vis: #aslVisitor) (xs: pattern list): pattern list = - mapNoCopy (visit_pattern vis) xs + method visit_patterns (xs: pattern list): pattern list = + mapNoCopy (self#visit_pattern) xs - and visit_pattern (vis: #aslVisitor) (x: pattern): pattern = - let aux (vis: #aslVisitor) (x: pattern): pattern = + method visit_pattern (x: pattern): pattern = + let aux (_: #aslVisitor) (x: pattern): pattern = ( match x with | Pat_LitInt(_) -> x | Pat_LitHex(_) -> x @@ -122,75 +152,75 @@ let rec visit_exprs (vis: #aslVisitor) (xs: expr list): expr list = | Pat_Const(_) -> x | Pat_Wildcard -> x | Pat_Tuple(ps) -> - let ps' = visit_patterns vis ps in + let ps' = self#visit_patterns ps in if ps == ps' then x else Pat_Tuple ps' | Pat_Set(ps) -> - let ps' = visit_patterns vis ps in + let ps' = self#visit_patterns ps in if ps == ps' then x else Pat_Set ps' | Pat_Single(e) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Pat_Single(e') | Pat_Range(lo, hi) -> - let lo' = visit_expr vis lo in - let hi' = visit_expr vis hi in + let lo' = self#visit_expr lo in + let hi' = self#visit_expr hi in if lo == lo' && hi == hi' then x else Pat_Range(lo', hi') ) in doVisit vis (vis#vpattern x) aux x - and visit_expr (vis: #aslVisitor) (x: expr): expr = - let aux (vis: #aslVisitor) (x: expr): expr = + method visit_expr (x: expr): expr = + let aux (_: #aslVisitor) (x: expr): expr = (match x with | Expr_If(ty, c, t, els, e) -> - let ty = visit_type vis ty in - let c' = visit_expr vis c in - let t' = visit_expr vis t in - let els' = mapNoCopy (visit_e_elsif vis) els in - let e' = visit_expr vis e in + let ty = self#visit_type ty in + let c' = self#visit_expr c in + let t' = self#visit_expr t in + let els' = mapNoCopy (self#visit_e_elsif) els in + let e' = self#visit_expr e in if c == c' && t == t' && els == els' && e == e' then x else Expr_If(ty, c', t', els', e') | Expr_Binop(a, op, b) -> - let a' = visit_expr vis a in - let b' = visit_expr vis b in + let a' = self#visit_expr a in + let b' = self#visit_expr b in if a == a' && b == b' then x else Expr_Binop(a', op, b') | Expr_Field(e, f) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Expr_Field(e', f) | Expr_Fields(e, fs) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Expr_Fields(e', fs) | Expr_Slices(e, ss) -> - let e' = visit_expr vis e in - let ss' = mapNoCopy (visit_slice vis) ss in + let e' = self#visit_expr e in + let ss' = mapNoCopy (self#visit_slice) ss in if e == e' && ss == ss' then x else Expr_Slices(e', ss') | Expr_In(e, p) -> - let e' = visit_expr vis e in - let p' = visit_pattern vis p in + let e' = self#visit_expr e in + let p' = self#visit_pattern p in if e == e' && p == p' then x else Expr_In(e', p') | Expr_Var(v) -> - let v' = visit_var vis v in + let v' = self#visit_var v in if v == v' then x else Expr_Var(v') | Expr_Parens(e) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Expr_Parens e' | Expr_TApply(f, tes, es) -> - let tes' = visit_exprs vis tes in - let es' = visit_exprs vis es in + let tes' = self#visit_exprs tes in + let es' = self#visit_exprs es in if tes == tes' && es == es' then x else Expr_TApply(f, tes', es') | Expr_Tuple(es) -> - let es' = visit_exprs vis es in + let es' = self#visit_exprs es in if es == es' then x else Expr_Tuple es' | Expr_Unop(op, e) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Expr_Unop(op, e') | Expr_Unknown(t) -> - let t' = visit_type vis t in + let t' = self#visit_type t in if t == t' then x else Expr_Unknown t' | Expr_ImpDef(t, os) -> - let t' = visit_type vis t in + let t' = self#visit_type t in if t == t' then x else Expr_ImpDef(t', os) | Expr_Array(a, e) -> - let a' = visit_expr vis a in - let e' = visit_expr vis e in + let a' = self#visit_expr a in + let e' = self#visit_expr e in if a == a' && e == e' then x else Expr_Array(a', e') | Expr_LitInt _ -> x | Expr_LitHex _ -> x @@ -203,141 +233,134 @@ let rec visit_exprs (vis: #aslVisitor) (xs: expr list): expr list = doVisit vis (vis#vexpr x) aux x - and visit_types (vis: #aslVisitor) (xs: ty list): ty list = - mapNoCopy (visit_type vis) xs + method visit_types (xs: ty list): ty list = + mapNoCopy (self#visit_type) xs - and visit_type (vis: #aslVisitor) (x: ty): ty = - let aux (vis: #aslVisitor) (x: ty): ty = + method visit_type (x: ty): ty = + let aux (_: #aslVisitor) (x: ty): ty = ( match x with | Type_Constructor(_) -> x | Type_Bits(n) -> - let n' = visit_expr vis n in + let n' = self#visit_expr n in if n == n' then x else Type_Bits(n') | Type_App(tc, es) -> - let es' = visit_exprs vis es in + let es' = self#visit_exprs es in if es == es' then x else Type_App(tc, es') | Type_OfExpr(e) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Type_OfExpr(e') | Type_Register(wd, fs) -> let fs' = mapNoCopy (fun ((ss, f) as r) -> - let ss' = mapNoCopy (visit_slice vis) ss in + let ss' = mapNoCopy (self#visit_slice) ss in if ss == ss' then r else (ss', f) ) fs in if fs == fs' then x else Type_Register(wd, fs') | Type_Array(Index_Enum(tc), ety) -> - let ety' = visit_type vis ety in + let ety' = self#visit_type ety in if ety == ety' then x else Type_Array(Index_Enum(tc), ety') | Type_Array(Index_Range(lo, hi), ety) -> - let lo' = visit_expr vis lo in - let hi' = visit_expr vis hi in - let ety' = visit_type vis ety in + let lo' = self#visit_expr lo in + let hi' = self#visit_expr hi in + let ety' = self#visit_type ety in if lo == lo' && hi == hi' && ety == ety' then x else Type_Array(Index_Range(lo',hi'),ety') | Type_Tuple(tys) -> - let tys' = visit_types vis tys in + let tys' = self#visit_types tys in if tys == tys' then x else Type_Tuple(tys') ) in doVisit vis (vis#vtype x) aux x -let rec visit_lexprs (vis: #aslVisitor) (xs: lexpr list): lexpr list = - mapNoCopy (visit_lexpr vis) xs + method visit_lexprs (xs: lexpr list): lexpr list = + mapNoCopy (self#visit_lexpr) xs - and visit_lexpr (vis: #aslVisitor) (x: lexpr): lexpr = - let aux (vis: #aslVisitor) (x: lexpr): lexpr = + method visit_lexpr (x: lexpr): lexpr = + let aux (_: #aslVisitor) (x: lexpr): lexpr = ( match x with | LExpr_Wildcard -> x | LExpr_Var(v) -> - let v' = visit_lvar vis v in + let v' = self#visit_lvar v in if v == v' then x else LExpr_Var(v') | LExpr_Field(e, f) -> - let e' = visit_lexpr vis e in + let e' = self#visit_lexpr e in if e == e' then x else LExpr_Field(e', f) | LExpr_Fields(e, fs) -> - let e' = visit_lexpr vis e in + let e' = self#visit_lexpr e in if e == e' then x else LExpr_Fields(e', fs) | LExpr_Slices(e, ss) -> - let e' = visit_lexpr vis e in - let ss' = mapNoCopy (visit_slice vis) ss in + let e' = self#visit_lexpr e in + let ss' = mapNoCopy (self#visit_slice) ss in if e == e' && ss == ss' then x else LExpr_Slices(e', ss') | LExpr_BitTuple(es) -> - let es' = mapNoCopy (visit_lexpr vis) es in + let es' = mapNoCopy (self#visit_lexpr) es in if es == es' then x else LExpr_BitTuple es' | LExpr_Tuple(es) -> - let es' = mapNoCopy (visit_lexpr vis) es in + let es' = mapNoCopy (self#visit_lexpr) es in if es == es' then x else LExpr_Tuple es' | LExpr_Array(a, e) -> - let a' = visit_lexpr vis a in - let e' = visit_expr vis e in + let a' = self#visit_lexpr a in + let e' = self#visit_expr e in if a == a' && e == e' then x else LExpr_Array(a', e') | LExpr_Write(f, tes, es) -> - let f' = visit_var vis f in - let tes' = visit_exprs vis tes in - let es' = visit_exprs vis es in + let f' = self#visit_var f in + let tes' = self#visit_exprs tes in + let es' = self#visit_exprs es in if f == f' && tes == tes' && es == es' then x else LExpr_Write(f, tes', es') | LExpr_ReadWrite(f, g, tes, es) -> - let f' = visit_var vis f in - let g' = visit_var vis g in - let tes' = visit_exprs vis tes in - let es' = visit_exprs vis es in + let f' = self#visit_var f in + let g' = self#visit_var g in + let tes' = self#visit_exprs tes in + let es' = self#visit_exprs es in if f == f' && g == g' && tes == tes' && es == es' then x else LExpr_ReadWrite(f, g, tes', es') ) in doVisit vis (vis#vlexpr x) aux x -let with_locals (ls: ((ty * ident) list)) (vis: #aslVisitor) (f: #aslVisitor -> 'a): 'a = - vis#enter_scope ls; - let result = f vis in - vis#leave_scope (); - result -(* todo: should probably make this more like cil visitor and allow - * visit_stmt to generate a list of statements and provide a mechanism to emit - * statements to be inserted before/after the statement being transformed - *) -let rec visit_stmts (vis: #aslVisitor) (xs: stmt list): stmt list = - vis#enter_scope []; - let stmts' = mapNoCopy (visit_stmt vis) xs in + method virtual visit_stmts : stmt list -> stmt list + + method with_locals : 'a 'b. (ty * ident) list -> ('a -> 'b) -> 'a -> 'b = fun ls f x -> + vis#enter_scope ls; + let result = f x in vis#leave_scope (); - stmts' + result - and visit_stmt (vis: #aslVisitor) (x: stmt): stmt = - let aux (vis: #aslVisitor) (x: stmt): stmt = + method visit_stmt (x: stmt): stmt list = + let aux (_: #aslVisitor) (x: stmt): stmt = (match x with | Stmt_VarDeclsNoInit (ty, vs, loc) -> - let ty' = visit_type vis ty in - let vs' = mapNoCopy (visit_lvar vis) vs in + let ty' = self#visit_type ty in + let vs' = mapNoCopy (self#visit_lvar) vs in if ty == ty' && vs == vs' then x else Stmt_VarDeclsNoInit (ty', vs', loc) | Stmt_VarDecl (ty, v, i, loc) -> - let ty' = visit_type vis ty in - let v' = visit_lvar vis v in - let i' = visit_expr vis i in + let ty' = self#visit_type ty in + let v' = self#visit_lvar v in + let i' = self#visit_expr i in if ty == ty' && v == v' && i == i' then x else Stmt_VarDecl (ty', v', i', loc) | Stmt_ConstDecl (ty, v, i, loc) -> - let ty' = visit_type vis ty in - let v' = visit_lvar vis v in - let i' = visit_expr vis i in + let ty' = self#visit_type ty in + let v' = self#visit_lvar v in + let i' = self#visit_expr i in if ty == ty' && v == v' && i == i' then x else Stmt_ConstDecl (ty', v', i', loc) | Stmt_Assign (l, r, loc) -> - let l' = visit_lexpr vis l in - let r' = visit_expr vis r in + let l' = self#visit_lexpr l in + let r' = self#visit_expr r in if l == l' && r == r' then x else Stmt_Assign (l', r', loc) | Stmt_TCall (f, tes, args, loc) -> - let f' = visit_var vis f in - let tes' = visit_exprs vis tes in - let args' = visit_exprs vis args in + let f' = self#visit_var f in + let tes' = self#visit_exprs tes in + let args' = self#visit_exprs args in if f == f' && tes == tes' && args == args' then x else Stmt_TCall (f', tes', args', loc) | Stmt_FunReturn (e, loc) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Stmt_FunReturn (e', loc) | Stmt_ProcReturn (_) -> x | Stmt_Assert (e, loc) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Stmt_Assert (e', loc) | Stmt_Unpred (_) -> x | Stmt_ConstrainedUnpred(_) -> x | Stmt_ImpDef (v, loc) -> - let v' = visit_var vis v in + let v' = self#visit_var v in if v == v' then x else Stmt_ImpDef (v', loc) | Stmt_Undefined (_) -> x | Stmt_ExceptionTaken (_) -> x @@ -345,371 +368,448 @@ let rec visit_stmts (vis: #aslVisitor) (xs: stmt list): stmt list = | Stmt_Dep_ImpDef (_, _) -> x | Stmt_Dep_Undefined (_) -> x | Stmt_See (e, loc) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Stmt_See (e', loc) | Stmt_Throw (v, loc) -> - let v' = visit_var vis v in + let v' = self#visit_var v in if v == v' then x else Stmt_Throw (v', loc) | Stmt_DecodeExecute (i, e, loc) -> - let e' = visit_expr vis e in + let e' = self#visit_expr e in if e == e' then x else Stmt_DecodeExecute (i, e', loc) | Stmt_If (c, t, els, e, loc) -> - let c' = visit_expr vis c in - let t' = visit_stmts vis t in - let els' = mapNoCopy (visit_s_elsif vis) els in - let e' = visit_stmts vis e in + let c' = self#visit_expr c in + let t' = self#visit_stmts t in + let els' = mapNoCopy (self#visit_s_elsif) els in + let e' = self#visit_stmts e in if c == c' && t == t' && els == els' && e == e' then x else Stmt_If (c', t', els', e', loc) | Stmt_Case (e, alts, ob, loc) -> - let e' = visit_expr vis e in - let alts' = mapNoCopy (visit_alt vis) alts in - let ob' = mapOptionNoCopy (visit_stmts vis) ob in + let e' = self#visit_expr e in + let alts' = mapNoCopy (self#visit_alt) alts in + let ob' = mapOptionNoCopy (self#visit_stmts) ob in if e == e' && alts == alts' && ob == ob' then x else Stmt_Case (e', alts', ob', loc) | Stmt_For (v, f, dir, t, b, loc) -> - let v' = visit_lvar vis v in - let f' = visit_expr vis f in - let t' = visit_expr vis t in + let v' = self#visit_lvar v in + let f' = self#visit_expr f in + let t' = self#visit_expr t in let ty_v' = (Type_Constructor(Ident "integer"), v') in - let b' = with_locals [ty_v'] vis visit_stmts b in + let b' = self#with_locals [ty_v'] self#visit_stmts b in if v == v' && f == f' && t == t' && b == b' then x else Stmt_For (v', f', dir, t', b', loc) | Stmt_While (c, b, loc) -> - let c' = visit_expr vis c in - let b' = visit_stmts vis b in + let c' = self#visit_expr c in + let b' = self#visit_stmts b in if c == c' && b == b' then x else Stmt_While (c', b', loc) | Stmt_Repeat (b, c, loc) -> - let b' = visit_stmts vis b in - let c' = visit_expr vis c in + let b' = self#visit_stmts b in + let c' = self#visit_expr c in if b == b' && c == c' then x else Stmt_Repeat (b', c', loc) | Stmt_Try (b, v, cs, ob, loc) -> - let b' = visit_stmts vis b in - let v' = visit_lvar vis v in + let b' = self#visit_stmts b in + let v' = self#visit_lvar v in let ty_v' = (Type_Constructor(Ident "__Exception"), v') in - let cs' = mapNoCopy (with_locals [ty_v'] vis visit_catcher) cs in - let ob' = mapOptionNoCopy (with_locals [ty_v'] vis visit_stmts) ob in + let cs' = mapNoCopy (self#with_locals [ty_v'] self#visit_catcher) cs in + let ob' = mapOptionNoCopy (self#with_locals [ty_v'] self#visit_stmts) ob in if b == b' && v == v' && cs == cs' && ob == ob' then x else Stmt_Try (b', v', cs', ob', loc) ) in - doVisit vis (vis#vstmt x) aux x + doVisitList vis (vis#vstmt x) aux x - and visit_s_elsif (vis: #aslVisitor) (x: s_elsif): s_elsif = - let aux (vis: #aslVisitor) (x: s_elsif): s_elsif = + method visit_s_elsif (x: s_elsif): s_elsif = + let aux (_: #aslVisitor) (x: s_elsif): s_elsif = (match x with | S_Elsif_Cond(c, b) -> - let c' = visit_expr vis c in - let b' = visit_stmts vis b in + let c' = self#visit_expr c in + let b' = self#visit_stmts b in if c == c' && b == b' then x else S_Elsif_Cond(c', b') ) in doVisit vis (vis#vs_elsif x) aux x - and visit_alt (vis: #aslVisitor) (x: alt): alt = - let aux (vis: #aslVisitor) (x: alt): alt = + method visit_alt (x: alt): alt = + let aux (_: #aslVisitor) (x: alt): alt = (match x with | Alt_Alt(ps, oc, b) -> - let ps' = visit_patterns vis ps in - let oc' = mapOptionNoCopy (visit_expr vis) oc in - let b' = visit_stmts vis b in + let ps' = self#visit_patterns ps in + let oc' = mapOptionNoCopy (self#visit_expr) oc in + let b' = self#visit_stmts b in if ps == ps' && oc == oc' && b == b' then x else Alt_Alt(ps', oc', b') ) in doVisit vis (vis#valt x) aux x - and visit_catcher (vis: #aslVisitor) (x: catcher): catcher = - let aux (vis: #aslVisitor) (x: catcher): catcher = + method visit_catcher (x: catcher): catcher = + let aux (_: #aslVisitor) (x: catcher): catcher = (match x with | Catcher_Guarded(c, b) -> - let c' = visit_expr vis c in - let b' = visit_stmts vis b in + let c' = self#visit_expr c in + let b' = self#visit_stmts b in if c == c' && b == b' then x else Catcher_Guarded(c', b') ) in doVisit vis (vis#vcatcher x) aux x -let visit_mapfield (vis: #aslVisitor) (x: mapfield): mapfield = - let aux (vis: #aslVisitor) (x: mapfield): mapfield = + method visit_mapfield (x: mapfield): mapfield = + let aux (_: #aslVisitor) (x: mapfield): mapfield = (match x with | MapField_Field (v, p) -> - let v' = visit_var vis v in - let p' = visit_pattern vis p in + let v' = self#visit_var v in + let p' = self#visit_pattern p in if v == v' && p == p' then x else MapField_Field (v', p') ) in doVisit vis (vis#vmapfield x) aux x -let visit_sformal (vis: #aslVisitor) (x: sformal): sformal = - let aux (vis: #aslVisitor) (x: sformal): sformal = + method visit_sformal (x: sformal): sformal = + let aux (_: #aslVisitor) (x: sformal): sformal = (match x with | Formal_In (ty, v) -> - let ty' = visit_type vis ty in - let v' = visit_lvar vis v in + let ty' = self#visit_type ty in + let v' = self#visit_lvar v in if ty == ty' && v == v' then x else Formal_In (ty', v') | Formal_InOut(ty, v) -> - let ty' = visit_type vis ty in - let v' = visit_lvar vis v in + let ty' = self#visit_type ty in + let v' = self#visit_lvar v in if ty == ty' && v == v' then x else Formal_InOut (ty', v') ) in doVisit vis (vis#vsformal x) aux x -let rec visit_dpattern (vis: #aslVisitor) (x: decode_pattern): decode_pattern = - let aux (vis: #aslVisitor) (x: decode_pattern): decode_pattern = + method visit_dpattern (x: decode_pattern): decode_pattern = + let aux (_: #aslVisitor) (x: decode_pattern): decode_pattern = (match x with | DecoderPattern_Bits _ -> x | DecoderPattern_Mask _ -> x | DecoderPattern_Wildcard _ -> x | DecoderPattern_Not p -> - let p' = visit_dpattern vis p in + let p' = self#visit_dpattern p in if p == p' then x else DecoderPattern_Not p' ) in doVisit vis (vis#vdpattern x) aux x -let visit_encoding (vis: #aslVisitor) (x: encoding): encoding = - let aux (vis: #aslVisitor) (x: encoding): encoding = + method visit_encoding (x: encoding): encoding = + let aux (_: #aslVisitor) (x: encoding): encoding = (match x with | Encoding_Block (nm, iset, fs, op, e, ups, b, loc) -> - let e' = visit_expr vis e in - let b' = visit_stmts vis b in + let e' = self#visit_expr e in + let b' = self#visit_stmts b in if e == e' && b == b' then x else Encoding_Block (nm, iset, fs, op, e, ups, b', loc) ) in doVisit vis (vis#vencoding x) aux x -let rec visit_decode_case (vis: #aslVisitor) (x: decode_case): decode_case = - let aux (vis: #aslVisitor) (x: decode_case): decode_case = + method visit_decode_case (x: decode_case): decode_case = + let aux (_: #aslVisitor) (x: decode_case): decode_case = (match x with | DecoderCase_Case (ss, alts, loc) -> - let alts' = mapNoCopy (visit_decode_alt vis) alts in + let alts' = mapNoCopy (self#visit_decode_alt) alts in if alts == alts' then x else DecoderCase_Case (ss, alts', loc) ) in doVisit vis (vis#vdcase x) aux x - and visit_decode_alt (vis: #aslVisitor) (x: decode_alt): decode_alt = - let aux (vis: #aslVisitor) (x: decode_alt): decode_alt = + method visit_decode_alt (x: decode_alt): decode_alt = + let aux (_: #aslVisitor) (x: decode_alt): decode_alt = (match x with | DecoderAlt_Alt (ps, b) -> - let ps' = mapNoCopy (visit_dpattern vis) ps in - let b' = visit_decode_body vis b in + let ps' = mapNoCopy (self#visit_dpattern) ps in + let b' = self#visit_decode_body b in if ps == ps' && b == b' then x else DecoderAlt_Alt (ps', b') ) in doVisit vis (vis#vdalt x) aux x - and visit_decode_body (vis: #aslVisitor) (x: decode_body): decode_body = - let aux (vis: #aslVisitor) (x: decode_body): decode_body = + method visit_decode_body (x: decode_body): decode_body = + let aux (_: #aslVisitor) (x: decode_body): decode_body = (match x with | DecoderBody_UNPRED _ -> x | DecoderBody_UNALLOC _ -> x | DecoderBody_NOP _ -> x | DecoderBody_Encoding _ -> x | DecoderBody_Decoder (fs, c, loc) -> - let c' = visit_decode_case vis c in + let c' = self#visit_decode_case c in if c == c' then x else DecoderBody_Decoder (fs, c', loc) ) in doVisit vis (vis#vdbody x) aux x -let visit_arg (vis: #aslVisitor) (x: (ty * ident)): (ty * ident) = + method visit_arg (x: (ty * ident)): (ty * ident) = (match x with | (ty, v) -> - let ty' = visit_type vis ty in - let v' = visit_var vis v in + let ty' = self#visit_type ty in + let v' = self#visit_var v in if ty == ty' && v == v' then x else (ty', v') ) -let visit_args (vis: #aslVisitor) (xs: (ty * ident) list): (ty * ident) list = - mapNoCopy (visit_arg vis) xs + method visit_args (xs: (ty * ident) list): (ty * ident) list = + mapNoCopy (self#visit_arg) xs -let arg_of_sformal (sf: sformal): (ty * ident) = - match sf with - | Formal_In (ty, id) - | Formal_InOut (ty, id) -> (ty, id) - -let arg_of_ifield (IField_Field (id, _, wd)): (ty * ident) = - (Type_Bits (Expr_LitInt (string_of_int wd)), id) - -let args_of_encoding (Encoding_Block (_, _, fs, _, _, _, _, _)): (ty * ident) list = - List.map arg_of_ifield fs - -let visit_decl (vis: #aslVisitor) (x: declaration): declaration = - let aux (vis: #aslVisitor) (x: declaration): declaration = + method visit_decl (x: declaration): declaration = + let aux (_: #aslVisitor) (x: declaration): declaration = (match x with | Decl_BuiltinType (v, loc) -> - let v' = visit_var vis v in + let v' = self#visit_var v in if v == v' then x else Decl_BuiltinType (v', loc) | Decl_Forward (v, loc) -> - let v' = visit_var vis v in + let v' = self#visit_var v in if v == v' then x else Decl_Forward (v', loc) | Decl_Record (v, fs, loc) -> - let v' = visit_var vis v in - let fs' = visit_args vis fs in + let v' = self#visit_var v in + let fs' = self#visit_args fs in if v == v' && fs == fs' then x else Decl_Record (v', fs', loc) | Decl_Typedef (v, ty, loc) -> - let v' = visit_var vis v in - let ty' = visit_type vis ty in + let v' = self#visit_var v in + let ty' = self#visit_type ty in if v == v' && ty == ty' then x else Decl_Typedef (v', ty', loc) | Decl_Enum (v, es, loc) -> - let v' = visit_var vis v in - let es' = mapNoCopy (visit_var vis) es in + let v' = self#visit_var v in + let es' = mapNoCopy (self#visit_var) es in if v == v' && es == es' then x else Decl_Enum (v', es', loc) | Decl_Var (ty, v, loc) -> - let ty' = visit_type vis ty in - let v' = visit_var vis v in + let ty' = self#visit_type ty in + let v' = self#visit_var v in if ty == ty' && v == v' then x else Decl_Var (ty', v', loc) | Decl_Const (ty, v, e, loc) -> - let ty' = visit_type vis ty in - let v' = visit_var vis v in - let e' = visit_expr vis e in + let ty' = self#visit_type ty in + let v' = self#visit_var v in + let e' = self#visit_expr e in if ty == ty' && v == v' && e == e' then x else Decl_Const (ty', v', e', loc) | Decl_BuiltinFunction (ty, f, args, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in - let args' = visit_args vis args in + let ty' = self#visit_type ty in + let f' = self#visit_var f in + let args' = self#visit_args args in if ty == ty' && f == f' && args == args' then x else Decl_BuiltinFunction (ty', f', args', loc) | Decl_FunType (ty, f, args, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in - let args' = visit_args vis args in + let ty' = self#visit_type ty in + let f' = self#visit_var f in + let args' = self#visit_args args in if ty == ty' && f == f' && args == args' then x else Decl_FunType (ty', f', args', loc) | Decl_FunDefn (ty, f, args, b, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in - let args' = visit_args vis args in - let b' = with_locals args' vis visit_stmts b in + let ty' = self#visit_type ty in + let f' = self#visit_var f in + let args' = self#visit_args args in + let b' = self#with_locals args' self#visit_stmts b in if ty == ty' && f == f' && args == args' && b == b' then x else Decl_FunDefn (ty', f', args', b', loc) | Decl_ProcType (f, args, loc) -> - let f' = visit_var vis f in - let args' = visit_args vis args in + let f' = self#visit_var f in + let args' = self#visit_args args in if f == f' && args == args' then x else Decl_ProcType (f', args', loc) | Decl_ProcDefn (f, args, b, loc) -> - let f' = visit_var vis f in - let args' = visit_args vis args in - let b' = with_locals args' vis visit_stmts b in + let f' = self#visit_var f in + let args' = self#visit_args args in + let b' = self#with_locals args' self#visit_stmts b in if f == f' && args == args' && b == b' then x else Decl_ProcDefn (f', args', b', loc) | Decl_VarGetterType (ty, f, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in + let ty' = self#visit_type ty in + let f' = self#visit_var f in if ty == ty' && f == f' then x else Decl_VarGetterType (ty', f', loc) | Decl_VarGetterDefn (ty, f, b, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in - let b' = visit_stmts vis b in + let ty' = self#visit_type ty in + let f' = self#visit_var f in + let b' = self#visit_stmts b in if ty == ty' && f == f' && b == b' then x else Decl_VarGetterDefn (ty', f', b', loc) | Decl_ArrayGetterType (ty, f, args, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in - let args' = visit_args vis args in + let ty' = self#visit_type ty in + let f' = self#visit_var f in + let args' = self#visit_args args in if ty == ty' && f == f' && args == args' then x else Decl_ArrayGetterType (ty', f', args', loc) | Decl_ArrayGetterDefn (ty, f, args, b, loc) -> - let ty' = visit_type vis ty in - let f' = visit_var vis f in - let args' = visit_args vis args in - let b' = with_locals args' vis visit_stmts b in + let ty' = self#visit_type ty in + let f' = self#visit_var f in + let args' = self#visit_args args in + let b' = self#with_locals args' self#visit_stmts b in if ty == ty' && f == f' && args == args' && b == b' then x else Decl_ArrayGetterDefn (ty', f', args', b', loc) | Decl_VarSetterType (f, ty, v, loc) -> - let f' = visit_var vis f in - let ty' = visit_type vis ty in - let v' = visit_var vis v in + let f' = self#visit_var f in + let ty' = self#visit_type ty in + let v' = self#visit_var v in if f == f' && ty == ty' && v == v' then x else Decl_VarSetterType (f', ty', v', loc) | Decl_VarSetterDefn (f, ty, v, b, loc) -> - let f' = visit_var vis f in - let ty' = visit_type vis ty in - let v' = visit_var vis v in - let b' = with_locals [(ty', v')] vis visit_stmts b in + let f' = self#visit_var f in + let ty' = self#visit_type ty in + let v' = self#visit_var v in + let b' = self#with_locals [(ty', v')] self#visit_stmts b in if f == f' && ty == ty' && v == v' && b == b' then x else Decl_VarSetterDefn (f', ty', v', b', loc) | Decl_ArraySetterType (f, args, ty, v, loc) -> - let f' = visit_var vis f in - let args' = mapNoCopy (visit_sformal vis) args in - let ty' = visit_type vis ty in - let v' = visit_var vis v in + let f' = self#visit_var f in + let args' = mapNoCopy (self#visit_sformal) args in + let ty' = self#visit_type ty in + let v' = self#visit_var v in if f == f' && args == args' && ty == ty' && v == v' then x else Decl_ArraySetterType (f', args', ty', v', loc) | Decl_ArraySetterDefn (f, args, ty, v, b, loc) -> - let f' = visit_var vis f in - let args' = mapNoCopy (visit_sformal vis) args in - let ty' = visit_type vis ty in - let v' = visit_var vis v in + let f' = self#visit_var f in + let args' = mapNoCopy (self#visit_sformal) args in + let ty' = self#visit_type ty in + let v' = self#visit_var v in let lvars = List.map arg_of_sformal args' @ [(ty', v')] in - let b' = with_locals lvars vis visit_stmts b in + let b' = self#with_locals lvars self#visit_stmts b in if f == f' && args == args' && ty == ty' && v == v' && b == b' then x else Decl_ArraySetterDefn (f', args', ty', v', b', loc) | Decl_InstructionDefn (d, es, opd, c, ex, loc) -> - let d' = visit_var vis d in - let es' = mapNoCopy (visit_encoding vis) es in + let d' = self#visit_var d in + let es' = mapNoCopy (self#visit_encoding) es in let lvars = List.concat (List.map args_of_encoding es) in - let opd' = mapOptionNoCopy (with_locals lvars vis visit_stmts) opd in - let ex' = with_locals lvars vis visit_stmts ex in + let opd' = mapOptionNoCopy (self#with_locals lvars self#visit_stmts) opd in + let ex' = self#with_locals lvars self#visit_stmts ex in if d == d' && es == es' && opd == opd' && ex == ex' then x else Decl_InstructionDefn (d', es', opd', c, ex', loc) | Decl_DecoderDefn (d, dc, loc) -> - let d' = visit_var vis d in - let dc' = visit_decode_case vis dc in + let d' = self#visit_var d in + let dc' = self#visit_decode_case dc in if d == d' && dc == dc' then x else Decl_DecoderDefn (d', dc', loc) | Decl_Operator1 (op, vs, loc) -> - let vs' = mapNoCopy (visit_var vis) vs in + let vs' = mapNoCopy (self#visit_var) vs in if vs == vs' then x else Decl_Operator1 (op, vs', loc) | Decl_Operator2 (op, vs, loc) -> - let vs' = mapNoCopy (visit_var vis) vs in + let vs' = mapNoCopy (self#visit_var) vs in if vs == vs' then x else Decl_Operator2 (op, vs', loc) | Decl_NewEventDefn(v, args, loc) -> - let v' = visit_var vis v in - let args' = visit_args vis args in + let v' = self#visit_var v in + let args' = self#visit_args args in if v == v' && args == args' then x else Decl_NewEventDefn(v', args', loc) | Decl_EventClause(v, b, loc) -> - let v' = visit_var vis v in - let b' = visit_stmts vis b in + let v' = self#visit_var v in + let b' = self#visit_stmts b in if v == v' && b == b' then x else Decl_EventClause(v', b', loc) | Decl_NewMapDefn(ty, v, args, b, loc) -> - let ty' = visit_type vis ty in - let v' = visit_var vis v in - let args' = visit_args vis args in - let b' = with_locals args' vis visit_stmts b in + let ty' = self#visit_type ty in + let v' = self#visit_var v in + let args' = self#visit_args args in + let b' = self#with_locals args' self#visit_stmts b in if v == v' && args == args' && b == b' then x else Decl_NewMapDefn(ty', v', args', b', loc) | Decl_MapClause(v, fs, oc, b, loc) -> - let v' = visit_var vis v in - let fs' = mapNoCopy (visit_mapfield vis) fs in - let oc' = mapOptionNoCopy (visit_expr vis) oc in - let b' = visit_stmts vis b in + let v' = self#visit_var v in + let fs' = mapNoCopy (self#visit_mapfield) fs in + let oc' = mapOptionNoCopy (self#visit_expr) oc in + let b' = self#visit_stmts b in if v == v' && fs == fs' && oc == oc' && b == b' then x else Decl_MapClause(v', fs', oc', b', loc) | Decl_Config(ty, v, e, loc) -> - let ty' = visit_type vis ty in - let v' = visit_var vis v in - let e' = visit_expr vis e in + let ty' = self#visit_type ty in + let v' = self#visit_var v in + let e' = self#visit_expr e in if ty == ty' && v == v' && e == e' then x else Decl_Config(ty', v', e', loc) ) in doVisit vis (vis#vdecl x) aux x +end + +class aslForwardsVisitor (vis: #aslVisitor) = object(self) + inherit aslTreeVisitor vis + + method visit_stmts (xs: stmt list): stmt list = + vis#enter_scope []; + let stmts' = List.concat_map (self#visit_stmt) xs in + vis#leave_scope (); + stmts' +end + +(** visit statement lists in a backwards order. + i.e., enter_scope is called before the final statement in a block and + exit_scope is called after the initial statement. *) +class aslBackwardsVisitor (vis: #aslVisitor) = object(self) + inherit aslTreeVisitor vis + + method visit_stmts (xs: stmt list): stmt list = + vis#enter_scope []; + (* reverse resultant statements as blocks, to avoid reversing + lists returned by the visitAction. *) + let stmts' = List.rev @@ List.map (self#visit_stmt) (List.rev xs) in + vis#leave_scope (); + List.concat stmts' +end + + +(* convenience methods to visit with the ordinary aslForwardsVisitor. *) + +let visit_exprs (vis: #aslVisitor) : expr list -> expr list = (new aslForwardsVisitor vis)#visit_exprs + +let visit_var (vis: #aslVisitor) : ident -> ident = (new aslForwardsVisitor vis)#visit_var + +let visit_lvar (vis: #aslVisitor) : ident -> ident = (new aslForwardsVisitor vis)#visit_lvar + +let visit_e_elsif (vis: #aslVisitor) : e_elsif -> e_elsif = (new aslForwardsVisitor vis)#visit_e_elsif + +let visit_slice (vis: #aslVisitor) : slice -> slice = (new aslForwardsVisitor vis)#visit_slice + +let visit_patterns (vis: #aslVisitor) : pattern list -> pattern list = (new aslForwardsVisitor vis)#visit_patterns + +let visit_pattern (vis: #aslVisitor) : pattern -> pattern = (new aslForwardsVisitor vis)#visit_pattern + +let visit_expr (vis: #aslVisitor) : expr -> expr = (new aslForwardsVisitor vis)#visit_expr + +let visit_types (vis: #aslVisitor) : ty list -> ty list = (new aslForwardsVisitor vis)#visit_types + +let visit_type (vis: #aslVisitor) : ty -> ty = (new aslForwardsVisitor vis)#visit_type + +let visit_lexprs (vis: #aslVisitor) : lexpr list -> lexpr list = (new aslForwardsVisitor vis)#visit_lexprs + +let visit_lexpr (vis: #aslVisitor) : lexpr -> lexpr = (new aslForwardsVisitor vis)#visit_lexpr + +let visit_stmts (vis: #aslVisitor) : stmt list -> stmt list = (new aslForwardsVisitor vis)#visit_stmts + +let visit_stmt (vis: #aslVisitor) : stmt -> stmt list = (new aslForwardsVisitor vis)#visit_stmt + +let visit_s_elsif (vis: #aslVisitor) : s_elsif -> s_elsif = (new aslForwardsVisitor vis)#visit_s_elsif + +let visit_alt (vis: #aslVisitor) : alt -> alt = (new aslForwardsVisitor vis)#visit_alt + +let visit_catcher (vis: #aslVisitor) : catcher -> catcher = (new aslForwardsVisitor vis)#visit_catcher + +let visit_mapfield (vis: #aslVisitor) : mapfield -> mapfield = (new aslForwardsVisitor vis)#visit_mapfield + +let visit_sformal (vis: #aslVisitor) : sformal -> sformal = (new aslForwardsVisitor vis)#visit_sformal + +let visit_dpattern (vis: #aslVisitor) : decode_pattern -> decode_pattern = (new aslForwardsVisitor vis)#visit_dpattern + +let visit_encoding (vis: #aslVisitor) : encoding -> encoding = (new aslForwardsVisitor vis)#visit_encoding + +let visit_decode_case (vis: #aslVisitor) : decode_case -> decode_case = (new aslForwardsVisitor vis)#visit_decode_case + +let visit_decode_alt (vis: #aslVisitor) : decode_alt -> decode_alt = (new aslForwardsVisitor vis)#visit_decode_alt + +let visit_decode_body (vis: #aslVisitor) : decode_body -> decode_body = (new aslForwardsVisitor vis)#visit_decode_body + +let visit_arg (vis: #aslVisitor) : (ty * ident) -> (ty * ident) = (new aslForwardsVisitor vis)#visit_arg + +let visit_args (vis: #aslVisitor) : (ty * ident) list -> (ty * ident) list = (new aslForwardsVisitor vis)#visit_args + +let visit_decl (vis: #aslVisitor) : declaration -> declaration = (new aslForwardsVisitor vis)#visit_decl + +let visit_stmt_single (vis: #aslVisitor) : stmt -> stmt = + fun s -> match visit_stmt vis s with + | [x] -> x + | _ -> failwith "visit_stmt_single requires exactly one returned statement" (****************************************************************) diff --git a/libASL/cpp_backend.ml b/libASL/cpp_backend.ml new file mode 100644 index 00000000..4b3b9143 --- /dev/null +++ b/libASL/cpp_backend.ml @@ -0,0 +1,553 @@ +open Asl_ast +open Asl_utils + +(**************************************************************** + * Write State + ****************************************************************) + +type st = { + mutable depth : int; + mutable skip_seq : bool; + file: string; + oc : out_channel; + mutable ref_vars : IdentSet.t; + + (* symbols declared within semantics, e.g. non-global variables and other generated functions *) + mutable genvars : ident list; +} + +let inc_depth st = + st.depth <- st.depth + 2 + +let dec_depth st = + st.depth <- st.depth - 2 + +let is_ref_var v st = + IdentSet.mem v st.ref_vars + +let clear_ref_vars st = + st.ref_vars <- IdentSet.empty + +let add_ref_var v st = + st.ref_vars <- IdentSet.add v st.ref_vars + +(**************************************************************** + * String Utils + ****************************************************************) + +let replace s = + let s = + String.fold_left (fun acc c -> + if c = '.' then acc ^ "_" + else if c = '#' then acc ^ "HASH" + else acc ^ (String.make 1 c)) "" s in + s + +let name_of_ident v = + let s = (match v with + | Ident n -> "v_" ^ n + | FIdent (n,0) -> "f_" ^ n + | FIdent (n,i) -> "f_" ^ n ^ "_" ^ (string_of_int i)) in + replace s + +let prefixed_name_of_ident st v = + let name = name_of_ident v in + match v with + (* non-generated functions and variables are translated to methods on an interface object. *) + | FIdent _ when not (List.mem v st.genvars) -> "iface." ^ name + | Ident _ when not (List.mem v st.genvars) -> "iface." ^ name ^ "()" + | _ -> name + +let rec name_of_lexpr l = + match l with + | LExpr_Var v -> name_of_ident v + | LExpr_Field (l, f) -> + let l = name_of_lexpr l in + let f = name_of_ident f in + l ^ "." ^ f + | LExpr_Wildcard -> "_" + | _ -> failwith @@ "name_of_lexpr: " ^ (pp_lexpr l) + +(**************************************************************** + * File IO + ****************************************************************) + +let write_preamble opens ?(header = true) ?(exports = []) st = + Printf.fprintf st.oc "/* AUTO-GENERATED LIFTER FILE */\n\n"; + if header then Printf.fprintf st.oc "#pragma once\n"; + List.iter (fun s -> + Printf.fprintf st.oc "#include <%s>\n" s) opens; + List.iter (fun s -> + Printf.fprintf st.oc "#include <%s> // IWYU pragma: export\n" s) exports; + Printf.fprintf st.oc "\n"; + Printf.fprintf st.oc "namespace aslp {\n\n" + +let write_epilogue fid st = + Printf.fprintf st.oc "\n} // namespace aslp" + +let write_line s st = + let padding = String.concat "" (List.init st.depth (fun _ -> " ")) in + output_string st.oc padding; + output_string st.oc s + +let write_seq st = + if st.skip_seq then + st.skip_seq <- false + else Printf.fprintf st.oc ";\n" + +let write_nl st = + Printf.fprintf st.oc "\n" + +(**************************************************************** + * Expr Printing + ****************************************************************) + +let rec prints_expr e st = + match e with + (* Boolean Expressions *) + | Expr_Var(Ident "TRUE") -> "true" + | Expr_Var(Ident "FALSE") -> "false" + | Expr_TApply(FIdent("and_bool", 0), [], [a;b]) -> + Printf.sprintf "(%s) && (%s)" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("or_bool", 0), [], [a;b]) -> + Printf.sprintf "(%s) || (%s)" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("implies_bool", 0), [], [a;b]) -> + Printf.sprintf "(%s) ? (%s) : true" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("not_bool", 0), [], [a]) -> + "! (" ^ prints_expr a st ^ ")" + + (* State Accesses *) + | Expr_Var(v) -> + let n = prefixed_name_of_ident st v in + if is_ref_var v st then "" ^ n else n + | Expr_Field(e, f) -> + prints_expr e st ^ "." ^ name_of_ident f + | Expr_Array(a,i) -> + Printf.sprintf "List.nth (%s) (%s)" (prints_expr a st) (prints_expr i st) + + (* Int Expressions *) + | Expr_LitInt i -> + Printf.sprintf "iface.bigint_lit(\"%s\")" i + | Expr_TApply(FIdent("add_int", 0), [], [a;b]) -> + Printf.sprintf "(%s) + (%s)" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("sub_int", 0), [], [a;b]) -> + Printf.sprintf "(%s) - (%s)" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("mul_int", 0), [], [a;b]) -> + Printf.sprintf "(%s) * (%s)" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("frem_int", 0), [], [a;b]) -> + Printf.sprintf "(%s) %% (%s)" (prints_expr a st) (prints_expr b st) + + (* Other operations *) + | Expr_LitBits b -> + let len = String.length b in + Printf.sprintf "iface.bits_lit(%dU, \"%s\")" len b + | Expr_Slices(e,[Slice_LoWd(i,w)]) -> + let e = prints_expr e st in + let i = prints_expr i st in + let w = prints_expr w st in + Printf.sprintf "iface.extract_bits(%s, /*lo*/ %s, /*wd*/ %s)" e i w + | Expr_TApply(f, targs, args) -> + let f = prefixed_name_of_ident st f in + (* let args = List.map (fun e -> prints_expr e st) (targs @ args) in *) + let args = List.map (fun e -> prints_expr e st) ([] @ args) in + f ^ "(" ^ (String.concat ", " args) ^ ")" + + | Expr_LitString s -> "\"" ^ s ^ "\"" + | Expr_Tuple(es) -> "std::make_tuple(" ^ (String.concat "," (List.map (fun e -> prints_expr e st) es)) ^ ")" + | Expr_Unknown(ty) -> default_value ty st + + | _ -> failwith @@ "prints_expr: " ^ pp_expr e + +and default_value t st = + match t with + | Type_Bits w -> + Printf.sprintf "iface.bits_zero(%s)" (prints_expr w st) + | Type_Constructor (Ident "boolean") -> "true" + | Type_Constructor (Ident "integer") -> "iface.bigint_zero()" + | Type_Constructor (Ident "rt_label") -> "typename Traits::rt_label{}" + | Type_Constructor (Ident "rt_expr") -> "typename Traits::rt_expr{}" + | Type_Array(Index_Range(lo, hi),ty) -> + let lo = prints_expr lo st in + let hi = prints_expr hi st in + let d = default_value ty st in + Printf.sprintf "std::vector{(%s)-(%s), %s}" hi lo d + | _ -> failwith @@ "Unknown type for default value: " ^ (pp_type t) + +let prints_type t = + match t with + | Type_Constructor (Ident "boolean") -> "bool" + | Type_Bits _ -> "typename Traits::bits" + | Type_Constructor (Ident "integer") -> "typename Traits::bigint" + | Type_Constructor (Ident "rt_label") -> "typename Traits::rt_label" + | Type_Constructor (Ident "rt_expr") -> "typename Traits::rt_expr" + | Type_Constructor (Ident "rt_sym") -> "typename Traits::rt_lexpr" + | _ -> failwith @@ Asl_utils.pp_type t + +let prints_ret_type = Option.fold ~none:"void" ~some:prints_type + +(**************************************************************** + * Prim Printing + ****************************************************************) + +let write_fun_return e st = + let s = Printf.sprintf "return (%s)" e in + write_line s st + +let write_proc_return st = + write_line "return" st + +let write_assert s st = + let s = Printf.sprintf "assert(%s)" s in + write_line s st + +let write_unsupported st = + write_line {|throw std::runtime_error{"aslp_lifter: unsupported! " + std::string{__func__} + " @ " + std::string{__FILE__} + ":" + std::to_string(__LINE__)}|} st + +let write_call f targs args st = + let f = prefixed_name_of_ident st f in + let args = [] @ args in + let call = f ^ "(" ^ (String.concat ", " args) ^ ")" in + write_line call st + +let write_ref ty v e st = + (* let t = prints_type ty in *) + let t = "auto" in + let name = prefixed_name_of_ident st v in + let s = Printf.sprintf "%s %s = %s" t name e in + write_line s st; + add_ref_var v st + +let write_let ty v e st = + (* let t = prints_type ty in *) + let t = "auto" in + let v = prefixed_name_of_ident st v in + let s = Printf.sprintf "const %s %s = %s" t v e in + write_line s st + +let write_if_start c st = + let s = Printf.sprintf "if (%s) {\n" c in + write_line s st + +let write_if_elsif c st = + let s = Printf.sprintf "} else if (%s) {\n" c in + write_line s st + +let write_if_else st = + write_line "} else {\n" st + +let write_if_end st = + write_line "} // if\n" st; + st.skip_seq <- true + +(**************************************************************** + * Stmt Printing + ****************************************************************) + +let rec write_lexpr v st = + match v with + | LExpr_Wildcard -> + "std::ignore" + + | LExpr_Var v -> + name_of_ident v + + | LExpr_Array (LExpr_Var v, i) -> + let i = prints_expr i st in + let v = name_of_ident v in + Printf.sprintf "%s.at(%s)" v i + + | LExpr_Field (l, f) -> + let v = name_of_lexpr l in + Printf.sprintf "%s" v + + | LExpr_Tuple (ls) -> + let vars = List.map (fun l -> write_lexpr l st) ls in + let v = String.concat "," vars in + Printf.sprintf "std::tie(%s)" v + + | _ -> failwith @@ "write_assign: " ^ (pp_lexpr v) + +let rec write_stmt s st = + match s with + | Stmt_VarDeclsNoInit(ty, vs, loc) -> + let e = default_value ty st in + st.genvars <- vs @ st.genvars; + List.iter (fun v -> write_ref ty v e st) vs + + | Stmt_VarDecl(ty, v, e, loc) -> + let e = prints_expr e st in + st.genvars <- v :: st.genvars; + write_ref ty v e st + + | Stmt_ConstDecl(ty, v, e, loc) -> + let e = prints_expr e st in + st.genvars <- v :: st.genvars; + write_let ty v e st + + | Stmt_Assign(l, r, loc) -> + let e = prints_expr r st in + let l = write_lexpr l st in + write_line (Printf.sprintf "%s = %s" l e) st + + | Stmt_TCall(f, tes, es, loc) -> + let tes = List.map (fun e -> prints_expr e st) tes in + let es = List.map (fun e -> prints_expr e st) es in + write_call f tes es st + + | Stmt_FunReturn(e, loc) -> + write_fun_return (prints_expr e st) st + + | Stmt_ProcReturn(loc) -> + write_proc_return st + + | Stmt_Assert(e, loc) -> + write_assert (prints_expr e st) st + + | Stmt_Throw _ -> + write_unsupported st + + | Stmt_If(c, t, els, f, loc) -> + let rec iter = function + | S_Elsif_Cond(c,b)::xs -> + write_if_elsif (prints_expr c st) st; + write_stmts b st; + iter xs + | [] -> () in + write_if_start (prints_expr c st) st; + write_stmts t st; + iter els; + if f <> [] then (write_if_else st; write_stmts f st); + write_if_end st + + | _ -> failwith @@ "write_stmt: " ^ (pp_stmt s); + +and write_stmts s st = + inc_depth st; + match s with + | [] -> + write_proc_return st; + write_seq st; + dec_depth st + | x::xs -> + write_stmt x st; + write_seq st; + List.iter (fun s -> + write_stmt s st; + write_seq st; + ) xs; + dec_depth st; + assert (not st.skip_seq) + +(* XXX: assumes all function arguments are bits. *) +let build_args prefix targs args = + let inner = String.concat " " @@ + List.map + (fun s -> prefix ^ "bits " ^ name_of_ident s) + (targs@args) in + "(" ^ inner ^ ")" + +let typenames = ["bits"; "bigint"; "rt_expr"; "rt_lexpr"; "rt_label"] +let template_header = "template \n" +let template_args = "" + +(** tuple of return type, function name, function arguments (parenthesised) *) +type cpp_fun_sig = { + rty: string; + prefix: string; + name: ident; + args: string; + file: string; +} + +let write_fn name (ret_tyo,_,targs,args,_,body) st : cpp_fun_sig = + clear_ref_vars st; + let oldvars = st.genvars in + st.genvars <- (targs @ args) @ oldvars; + + let prefix = "aslp_lifter" ^ template_args ^ "::" in + let fname = name_of_ident name in + let args = build_args "typename Traits::" targs args in + let ret = prints_ret_type ret_tyo in + + write_line template_header st; + Printf.fprintf st.oc "%s %s%s%s {\n" ret prefix fname args; + write_stmts body st; + Printf.fprintf st.oc "\n} // %s\n\n" fname; + + st.genvars <- oldvars; + { rty = ret; prefix; name; args; file = st.file; } + +(**************************************************************** + * Directory Setup + ****************************************************************) + +let init_st (genfns: cpp_fun_sig list) prefix file = + let genvars = List.map (fun {name;_} -> name) genfns in + let path = Filename.concat prefix file in + Utils.mkdir_p (Filename.dirname path); + let oc = open_out path in + + { depth = 0; skip_seq = false; file; oc; ref_vars = IdentSet.empty ; + genvars; } + +(* prefix used to access all generated header files. *) +let export_prefix = "aslp/generated" +(* directory for generated template headers. *) +let gen_dir = "include" +(* directory for generated source files for explicit instantiation. *) +let instantiate_dir = "src/generated" +(* headers required by all files. note aslp/interface.hpp is NOT generated. *) +let stdlib_deps = ["cassert"; "tuple"; "variant"; "vector"; "stdexcept"; "aslp/interface.hpp"] +(* headers required by instruction semantics files. + includes forward declaration of lifter class. *) +let global_deps = stdlib_deps @ [export_prefix^"/aslp_lifter.hpp"] + + +(** Write an instruction file, containing just the behaviour of one instructions *) +let write_instr_file fn fnsig prefix dir = + let m = name_of_FIdent fn in + let path = dir ^ "/" ^ m ^ ".hpp" in + let st = init_st [] prefix path in + write_preamble global_deps st; + let gen = write_fn fn fnsig st in + write_epilogue () st; + close_out st.oc; + gen + +(* Write the test file, containing all decode tests *) +let write_test_file tests prefix dir = + let m = "decode_tests" in + let path = dir ^ "/" ^ m ^ ".hpp" in + let st = init_st [] prefix path in + write_preamble global_deps st; + let gens = List.map (fun (i,s) -> write_fn i s st) @@ Bindings.bindings tests in + write_epilogue () st; + close_out st.oc; + gens + +(* Write the decoder file *) +let write_decoder_file fn fnsig genfns prefix dir = + let m = name_of_FIdent fn in + let path = dir ^ "/" ^ m ^ ".hpp" in + let st = init_st genfns prefix path in + write_preamble (global_deps@[export_prefix^"/decode_tests.hpp"]) st; + let gen = write_fn fn fnsig st in + write_epilogue fn st; + close_out st.oc; + gen + + +(* Write the public-facing header file. For compilation speed, this declares but does not define. *) +let write_header_file fn fnsig semfns testfns prefix dir = + let name = "aslp_lifter" in + let path = dir ^ "/" ^ name ^ ".hpp" in + let st = init_st [] prefix path in + write_preamble stdlib_deps st; + + write_line template_header st; + write_line "class aslp_lifter {\n" st; + + inc_depth st; + write_line ("public: using interface = lifter_interface" ^ template_args ^ ";\n") st; + write_line "private: interface& iface;\n" st; + write_line "public:\n" st; + write_line "aslp_lifter(interface& iface) : iface{iface} { }\n" st; + + write_line "/* generated semantics */\n" st; + List.iter + (fun {rty; name; args; _} -> write_line (rty ^ " " ^ name_of_ident name ^ args ^ ";\n") st) + semfns; + write_line "/* generated decode test conditions */\n" st; + List.iter + (fun {rty; name; args; _} -> write_line (rty ^ " " ^ name_of_ident name ^ args ^ ";\n") st) + testfns; + + dec_depth st; + write_line "};\n" st; + + write_epilogue fn st; + close_out st.oc; + (name, semfns @ testfns) + +(** Writes the template implementation file. If needed, this can be used to instantiate the entire lifter. + However, it is fairly slow. *) +let write_impl_file allfns prefix dir = + let name = "aslp_lifter_impl" in + let path = dir ^ "/" ^ name ^ ".hpp" in + let st = init_st [] prefix path in + let exports = Utils.nub @@ List.map (fun {file;_} -> file) allfns in + write_preamble stdlib_deps ~exports st; + + write_epilogue () st; + close_out st.oc; + name + +(* Creates a directory of explicit instantiations, supporting parallel compilation. *) +let write_explicit_instantiations cppfuns prefix dir = + let write_instantiation file (cppfuns : cpp_fun_sig list) = + let dep = file in + let file = Filename.(chop_extension (basename file)) in + let path = dir ^ "/" ^ file ^ ".cpp" in + let st = init_st [] prefix path in + + write_preamble ~header:false stdlib_deps ~exports:[dep] st; + + write_line "#ifdef ASLP_LIFTER_INSTANTIATE\n" st; + write_line "using Traits = ASLP_LIFTER_INSTANTIATE;\n" st; + List.iter + (fun {rty; name; args; _} -> + let fname = name_of_ident name in + let s = Printf.sprintf "template %s %s%s::%s%s;\n" rty "aslp_lifter" template_args fname args in + write_line s st) + cppfuns; + write_line "#endif\n" st; + + write_epilogue () st; + close_out st.oc; + cppfuns + in + (* group by the .hpp file where each template is defined. *) + let files = Utils.nub @@ List.map (fun x -> x.file) cppfuns in + List.map + (fun file -> + write_instantiation file (List.filter (fun x -> x.file = file) cppfuns)) + files + +(* Installs non-generated support headers, e.g. interface definition *) +let install_headers prefix dir = + let res = List.nth Res.Sites.aslfiles 0 ^ "/include/aslp" in + let files = Array.to_list @@ Sys.readdir res in + let files = List.filter (String.ends_with ~suffix:".hpp") files in + List.map + (fun f -> + let i = open_in_bin (res ^ "/" ^ f) in + let size = in_channel_length i in + let path = (prefix ^ "/" ^ dir ^ "/" ^ Filename.basename f) in + let o = open_out_bin path in + output_string o (really_input_string i size); + close_in i; + close_out o; + path) + files + +(* Write all of the above, expecting Utils.ml to already be present in dir *) +let run dfn dfnsig tests fns root = + + let genprefix = root ^ "/" ^ gen_dir in + let instprefix = root ^ "/" ^ instantiate_dir in + + let semfns = Bindings.fold (fun fn fnsig acc -> (write_instr_file fn fnsig genprefix export_prefix)::acc) fns [] in + let testfns = write_test_file tests genprefix export_prefix in + let allfns = semfns @ testfns in + + let dfn = write_decoder_file dfn dfnsig allfns genprefix export_prefix in + let allfns = dfn :: allfns in + + let _header = write_header_file dfn dfnsig (dfn :: semfns) testfns genprefix export_prefix in + let _explicits = write_explicit_instantiations allfns instprefix "." in + + let _impl = write_impl_file allfns genprefix export_prefix in + + let _headers = install_headers genprefix (export_prefix ^ "/..") in + + () diff --git a/libASL/cpu.ml b/libASL/cpu.ml index e734855b..638c9173 100644 --- a/libASL/cpu.ml +++ b/libASL/cpu.ml @@ -12,6 +12,7 @@ open Asl_utils type gen_backend = | Ocaml | Cpp + | Scala type gen_function = AST.ident -> Eval.fun_sig -> Eval.fun_sig Bindings.t -> Eval.fun_sig Bindings.t -> string -> unit @@ -71,7 +72,9 @@ let mkCPU (env : Eval.Env.t) (denv: Dis.env): cpu = let run_gen_backend : gen_function = match backend with | Ocaml -> Ocaml_backend.run - | Cpp -> failwith "cpp backend not yet implemented" in + | Cpp -> Cpp_backend.run + | Scala -> Scala_backend.run + in (* Build backend program *) run_gen_backend decoder_id decoder_fnsig tests instrs dir diff --git a/libASL/cpu.mli b/libASL/cpu.mli index a3a71a12..54801332 100644 --- a/libASL/cpu.mli +++ b/libASL/cpu.mli @@ -8,6 +8,7 @@ type gen_backend = | Ocaml | Cpp + | Scala type gen_function = Asl_ast.ident -> Eval.fun_sig -> Eval.fun_sig Asl_utils.Bindings.t -> Eval.fun_sig Asl_utils.Bindings.t -> string -> unit diff --git a/libASL/dis.ml b/libASL/dis.ml index d4e9fd06..d98a5068 100644 --- a/libASL/dis.ml +++ b/libASL/dis.ml @@ -1427,8 +1427,8 @@ and dis_decode_alt' (loc: AST.l) (DecoderAlt_Alt (ps, b)) (vs: value list) (op: let@ enc_match = dis_encoding enc op in if enc_match then begin (* todo: should evaluate ConditionHolds to decide whether to execute body *) - if !debug_level >= 1 then begin - Printf.printf "Dissasm: %s\n" (pprint_ident inst); + if !debug_level >= 0 then begin + Printf.printf "Disasm: %s\n" (pprint_ident inst); end; let@ (lenv',stmts) = DisEnv.locally_ ( diff --git a/libASL/dune b/libASL/dune index 1298b6ac..8e38b4e0 100644 --- a/libASL/dune +++ b/libASL/dune @@ -22,8 +22,12 @@ (modules asl_ast asl_parser asl_parser_pp asl_utils asl_visitor cpu dis elf eval lexer lexersupport loadASL monad primops rws symbolic tcheck testing transforms utils value visitor res symbolic_lifter decoder_program call_graph req_analysis - offline_transform ocaml_backend dis_tc + offline_transform + dis_tc offline_opt + ocaml_backend + cpp_backend + scala_backend ) (libraries pprint zarith z3 str pcre mlbdd dune-site)) diff --git a/libASL/offline_opt.ml b/libASL/offline_opt.ml index 9e837d3f..87f2af3b 100644 --- a/libASL/offline_opt.ml +++ b/libASL/offline_opt.ml @@ -289,10 +289,10 @@ module CopyProp = struct method! vstmt = function (* Transform runtime variable decls into expression decls *) | Stmt_ConstDecl(t, v, Expr_TApply(f, [], args), loc) when is_var_decl f && candidate_var v st -> - ChangeDoChildrenPost(Stmt_VarDeclsNoInit(Offline_transform.rt_expr_ty, [v], loc), fun e -> e) + ChangeDoChildrenPost([Stmt_VarDeclsNoInit(Offline_transform.rt_expr_ty, [v], loc)], fun e -> e) (* Transform stores into assigns *) | Stmt_TCall(f, [], [Expr_Var v; e], loc) when is_var_store f && candidate_var v st -> - ChangeDoChildrenPost(Stmt_Assign(LExpr_Var v, e, loc), fun e -> e) + ChangeDoChildrenPost([Stmt_Assign(LExpr_Var v, e, loc)], fun e -> e) | _ -> DoChildren end diff --git a/libASL/scala_backend.ml b/libASL/scala_backend.ml new file mode 100644 index 00000000..4c069b9a --- /dev/null +++ b/libASL/scala_backend.ml @@ -0,0 +1,787 @@ + +open Visitor + +open Asl_utils + +open AST +open Asl_visitor +open Value + +(* For splitting up functions we use type to indicate which parameters are passed by reference. *) + +let mutable_decl = "class Mutable[T](var v: T)" + +type var_type = + | Mutable of ty + | Immutable of ty + | Unit + | Infer (* Omit the type def on scala side *) + +type sc_fun_sig = { + rt: var_type ; + arg_types: (var_type * ident) list; + targs: ident list; + args: ident list; + body: stmt list; +} + +module DefSet = Set.Make(struct + type t = (ident * ty * bool) + let compare (a,b,e) (c,d,f) = (match (Stdlib.compare a c) with + | 0 -> (match (Stdlib.compare b d ) with + | 0 -> (Stdlib.compare e f ) + | s -> s) + | s -> s) +end) + +let compose a b f = b (a f) + +let (let@) x f = fun s -> + let (s,r) = x s in + (f r) s +let (let+) x f = fun s -> + let (s,r) = x s in + (s,f r) + + +module LocMap = Map.Make(struct + type t = l + let compare = Stdlib.compare +end) + +class find_defs = object (self) + inherit Asl_visitor.nopAslVisitor + val mutable defs = LocMap.empty + + method add_dep loc i = defs <- LocMap.add loc i defs + method get_deps = defs + + method! vstmt = function + | Stmt_VarDeclsNoInit(ty, [v], loc) -> self#add_dep loc (v, ty); SkipChildren + | Stmt_VarDecl(ty, v, e, loc) -> self#add_dep loc (v ,ty); SkipChildren + | Stmt_ConstDecl(ty, v, e, loc) -> self#add_dep loc (v , ty) ; SkipChildren + | _ -> DoChildren +end + + +type stvarinfo = { + ident : ident; + var : expr; + typ : var_type; +} + +let state_var = { var = Expr_Var (Ident "st"); typ = Immutable (Type_Constructor (Ident "LiftState")); ident = Ident "st" } + +type st = { + mutable indent: int; + mutable skip_seq: bool; + oc : out_channel; + + (* variables that are access thru mutable.v field *) + mutable mutable_vars : var_type Transforms.ScopedBindings.t; + + (* New functions generated by splitting functions *) + mutable extra_functions : sc_fun_sig Bindings.t; +} + +let define (st) (t: var_type) (v:ident) = Stack.push (Stack.pop st.mutable_vars |> Bindings.add v t) st.mutable_vars +let push_scope st = Stack.push Bindings.empty (st.mutable_vars) +let pop_scope st = Stack.pop (st.mutable_vars) |> ignore + +let var_mutable (v:ident) (st) = + let find_def = Transforms.ScopedBindings.find_binding in + match (find_def st.mutable_vars v) with + | Some Mutable _ -> true + | Some Unit -> false + | Some Infer -> true + | Some Immutable _ -> false + | None -> true (* Globals are mutable by default *) + +let global_imports = ["util.Logger"] +let global_opens = ["ir"] + +let uniq_counter : int ref = ref 0 + +let new_index _ : int = uniq_counter := !uniq_counter + 1 ; !uniq_counter +let new_name pref = Ident ( pref ^ "_" ^ (string_of_int (new_index ()))) +let new_indexs (b: string) : string = b ^ (string_of_int (new_index ())) + +class stmt_counter = object(this) + inherit Asl_visitor.nopAslVisitor + val mutable stmt_count: int = 0 + val mutable expr_count: int = 0 + + method !vstmt s = stmt_count <- stmt_count + 1; DoChildren + + method !vexpr s = expr_count <- expr_count + 1; DoChildren + + method count (s:stmt) : int = stmt_count <- 0; (visit_stmt this s) |> ignore; stmt_count + + method expr_count (e:expr) : int = expr_count <- 0; (visit_expr this e) |> ignore ; expr_count + method gexpr_count = expr_count +end + +let sl_complexity(sl:stmt list) : int = let s = new stmt_counter in visit_stmts s (sl) |> ignore ; s#gexpr_count +let count_stmts_list (s:stmt list) : int list = List.map ((new stmt_counter)#count) s +let count_stmts (s:stmt) : int = (new stmt_counter)#count s + +(* Shallow inspection of an expression to guess its type. *) +let infer_type e : ty option = + let tint = Some (Type_Constructor (Ident "integer")) in + let tbool = Some (Type_Constructor (Ident "boolean")) in + let tbits = fun b -> Some (Type_Bits b) in + match e with + (* Boolean Expressions *) + | Expr_Var(Ident "TRUE") -> tbool + | Expr_Var(Ident "FALSE") -> tbool + | Expr_TApply(FIdent("and_bool", 0), [], [a;b]) -> tbool + | Expr_TApply(FIdent("or_bool", 0), [], [a;b]) -> tbool + | Expr_TApply(FIdent("implies_bool", 0), [], [a;b]) -> tbool + | Expr_TApply(FIdent("not_bool", 0), [], [a]) -> tbool + + (* Int Expressions using Z *) + | Expr_LitInt i -> tint + | Expr_TApply(FIdent("add_int", 0), [], [a;b]) -> tint + | Expr_TApply(FIdent("sub_int", 0), [], [a;b]) -> tint + | Expr_TApply(FIdent("mul_int", 0), [], [a;b]) -> tint + | Expr_TApply(FIdent("frem_int", 0), [], [a;b]) -> tint + + (* Other operations *) + | Expr_LitBits b -> tbits (Expr_LitInt (b)) + | Expr_Slices(e,[Slice_LoWd(i,w)]) -> tbits (w) + | _ -> None + + +let prints_arg_type (t: var_type) : string = + let rec ctype t = + match t with + | (Type_Bits _) -> "BV" + | (Type_Constructor (Ident "integer")) -> "BigInt" + | (Type_Constructor (Ident "boolean")) -> "Boolean" + | (Type_Tuple l) -> ": (" ^ (String.concat "," (List.map ctype (l)) ) ^ ")" + | Type_Constructor (Ident "rt_label") -> "RTLabel" + | Type_Constructor (Ident "rt_sym") -> "RTSym" + | Type_Constructor (Ident "rt_expr") -> "Expr" + | Type_Constructor (Ident e) -> e + | t -> failwith @@ "Unknown arg type: " ^ (pp_type t) + in + match t with + | Mutable v -> Printf.sprintf "Mutable[%s]" (ctype v) + | Immutable v -> ctype v + | Infer -> "" + | Unit -> "Unit" + + +(**************************************************************** + * String Utils + ****************************************************************) + +let inc_depth st = st.indent <- st.indent + 2 +let dec_depth st = st.indent <- st.indent - 2 + +let replace s = + String.fold_left (fun acc c -> + if c = '.' then acc ^ "_" + else if c = '#' then acc ^ "HASH" + else acc ^ (String.make 1 c)) "" s + +let plain_ident v : string = + let s = (match v with + | Ident n -> n + | FIdent (n,0) -> n + | FIdent (n,i) -> n ^ "_" ^ (string_of_int i)) in + replace s + +let name_of_ident v : string = + let s = (match v with + | Ident n -> "v_" ^ n + | FIdent (n,0) -> "f_" ^ n + | FIdent (n,i) -> "f_" ^ n ^ "_" ^ (string_of_int i)) in + replace s + +let rec name_of_lexpr l = + match l with + | LExpr_Var v -> name_of_ident v + | LExpr_Field (l, f) -> + let l = name_of_lexpr l in + let f = name_of_ident f in + l ^ "." ^ f + | LExpr_Wildcard -> "_" + | _ -> failwith @@ "name_of_lexpr: " ^ (pp_lexpr l) + +(* Expr printing *) + + +let rec prints_expr ?(deref:bool=true) e (st:st) = + match e with + (* Boolean Expressions *) + | Expr_Var(Ident "TRUE") -> "true" + | Expr_Var(Ident "FALSE") -> "false" + | Expr_TApply(FIdent("and_bool", 0), [], [a;b]) -> + Printf.sprintf "((%s) && (%s))" (prints_expr a st ~deref) (prints_expr b st ~deref) + | Expr_TApply(FIdent("or_bool", 0), [], [a;b]) -> + Printf.sprintf "((%s) || (%s))" (prints_expr a st ~deref) (prints_expr b st ~deref) + | Expr_TApply(FIdent("implies_bool", 0), [], [a;b]) -> + Printf.sprintf "((!(%s)) || (%s))" (prints_expr a st ~deref) (prints_expr b st ~deref) + | Expr_TApply(FIdent("not_bool", 0), [], [a]) -> + Printf.sprintf " (!(%s))" (prints_expr a st ~deref) + + (* State Accesses *) + | Expr_Var(v) -> (if (deref && (var_mutable v st)) then ((name_of_ident v) ^ ".v" ) else (name_of_ident v)) + | Expr_Field(e, f) -> + prints_expr e st ^ "." ^ name_of_ident f + | Expr_Array(a,i) -> + Printf.sprintf "(%s).get(%s)" (prints_expr a st) (prints_expr i st) + + (* Int Expressions using Z *) + | Expr_LitInt i -> "BigInt(" ^ i ^ ")" + | Expr_TApply(FIdent("add_int", 0), [], [a;b]) -> + Printf.sprintf "((%s) + (%s))" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("sub_int", 0), [], [a;b]) -> + Printf.sprintf "((%s) - (%s))" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("mul_int", 0), [], [a;b]) -> + Printf.sprintf "((%s) * (%s))" (prints_expr a st) (prints_expr b st) + | Expr_TApply(FIdent("frem_int", 0), [], [a;b]) -> + let x = (prints_expr a st) in let y = (prints_expr b st) in + Printf.sprintf "((%s) - ( (%s) * ((%s) / (%s))))" x y x y + + (* Other operations *) + | Expr_TApply(FIdent("as_ref", 0), [], [a]) -> Printf.sprintf "%s" (prints_expr a st ~deref:false) + (*as_ref is only output by backend to idicate not to deref pointers*) + | Expr_LitBits b -> Printf.sprintf "mkBits(v_st, %d, BigInt(\"%s\", 2))" (String.length b) b + | Expr_Slices(e,[Slice_LoWd(i,w)]) -> + let e = prints_expr e st in + let i = prints_expr i st in + let w = prints_expr w st in + let stv = prints_expr state_var.var st in + Printf.sprintf "bvextract(%s,%s,%s,%s)" stv e i w + | Expr_TApply(f, targs, args) -> + let deref = not (Bindings.mem f st.extra_functions) in + let f = name_of_ident f in + let args = List.map (fun e -> prints_expr ~deref:deref e st) (state_var.var::targs@args) in + f ^ "(" ^ (String.concat ", " (args)) ^ ")" + + | Expr_LitString s -> "\"" ^ s ^ "\"" + | Expr_Tuple(es) -> "(" ^ (String.concat "," (List.map (fun e -> prints_expr e st) es)) ^ ")" + | Expr_Unknown(ty) -> default_value ty st (* Sound? *) + + | _ -> failwith @@ "prints_expr: " ^ pp_expr e + +and default_value t st = + match t with + | Type_Bits w -> Printf.sprintf "mkBits(v_st, %s, BigInt(0))" (prints_expr w st) + | Type_Constructor (Ident "boolean") -> "true" + | Type_Constructor (Ident "integer") -> "BigInt(0)" + | Type_Constructor (Ident "rt_label") -> "rTLabelDefault" + | Type_Constructor (Ident "rt_sym") -> "rTSymDefault" + | Type_Constructor (Ident "rt_expr") -> "rTExprDefault" + | Type_Constructor (Ident "Unit") -> "()" + | Type_Constructor (Ident "Any") -> "null" + | Type_Array(Index_Range(lo, hi),ty) -> + let lo = prints_expr lo st in + let hi = prints_expr hi st in + let d = default_value ty st in + Printf.sprintf "Range.Exclusive((%s), (%s)).map(%s).toList" lo hi d + | _ -> failwith @@ "Unknown type for default value: " ^ (pp_type t) + + + +(* End expr printing *) + + +let write_line s st = + let padding = String.concat "" (List.init st.indent (fun _ -> " ")) in + Printf.fprintf st.oc "%s%s" padding s + +let write_seq st = + if st.skip_seq then + st.skip_seq <- false + else Printf.fprintf st.oc "\n" + +let write_nl st = + Printf.fprintf st.oc "\n" + +(**************************************************************** + * Prim Printing + ****************************************************************) + +let write_fun_return e st = + let s = Printf.sprintf "%s" e in + write_line s st + +let write_proc_return st = + write_line "/*proc return */ ()" st + +let write_assert s st = + let s = Printf.sprintf "assert (%s)" s in + write_line s st + +let write_unsupported st = + write_line "throw Exception(\"not supported\")" st + +let write_call f (targs : typeid list) (args: typeid list) st = + let f = name_of_ident f in + let args = (prints_expr state_var.var st)::(targs@args) in + let call = f ^ " (" ^ (String.concat "," args) ^ ")" in + write_line call st + +let write_ref v ty e st = + let name = name_of_ident v in + let s = Printf.sprintf "val %s = %s(%s)\n" name (prints_arg_type ty) e in + st.skip_seq <- true; + write_line s st + +let write_let v ty e st = + let v = name_of_ident v in + let s = Printf.sprintf "val %s : %s = %s \n" v (prints_arg_type ty) e in + st.skip_seq <- true; + write_line s st + +let write_if_start c st = + let s = Printf.sprintf "if (%s) then {\n" c in + write_line s st + +let write_if_elsif c st = + write_nl st; + let s = Printf.sprintf "} else if (%s) then {\n" c in + write_line s st + +let write_if_else st = + write_nl st; + write_line "} else {\n" st + +let write_if_end st = + write_nl st; + write_line "}" st + +(**************************************************************** + * Stmt Printing + ****************************************************************) + + +let prints_lexpr v st = + match v with + | LExpr_Wildcard -> "_" + | LExpr_Var v -> name_of_ident v + | LExpr_Array (LExpr_Var v, i) -> name_of_ident v + | LExpr_Field (l, f) -> name_of_lexpr l + | LExpr_Tuple (ls) -> "(" ^ String.concat "," (List.map name_of_lexpr ls) ^ ")" + | _ -> failwith @@ "pritns_lexpr: " ^ (pp_lexpr v) + +let rec expr_of_lexpr v = + match v with + | LExpr_Var v -> Expr_Var v + | LExpr_Array (LExpr_Var v, i) -> Expr_Array (Expr_Var v, i) + | LExpr_Field (l, f) -> Expr_Field (expr_of_lexpr l , f) + | LExpr_Tuple (ls) -> Expr_Tuple (List.map expr_of_lexpr ls) + | _ -> failwith @@ "expr_of_lexpr: " ^ (pp_lexpr v) + +let rec write_assign v e st = + match v with + | LExpr_Wildcard -> + failwith @@ "write_assign: " ^ (pp_lexpr v); + (*let s = Printf.sprintf "val _ = %s \n" e in + st.skip_seq <- true; + write_line s st*) + + | LExpr_Var v -> + let v = (if (var_mutable v st) then ((name_of_ident v) ^ ".v" ) else (name_of_ident v)) in + let s = Printf.sprintf "%s = %s" v e in + write_line s st + + | LExpr_Array (LExpr_Var v, i) -> + let i = prints_expr i st in + let v = name_of_ident v in + let s = Printf.sprintf "%s = list_update (%s, %s, %s)" v v i e in + write_line s st + + | LExpr_Field (l, f) -> + let v = name_of_lexpr l in + let s = Printf.sprintf "%s = %s" v e in + write_line s st + + | LExpr_Tuple (ls) -> + let vars = List.init (List.length ls) (fun i -> "tmp" ^ (string_of_int (new_index ()))) in + let v = "(" ^ String.concat "," vars ^ ")" in + let s = Printf.sprintf "val %s = %s \n" v e in + st.skip_seq <- true; + write_line s st; + List.iter2 (fun l e -> + write_seq st; + write_assign l e st + ) ls vars + + | _ -> failwith @@ "write_assign: " ^ (pp_lexpr v) + + +module FunctionSplitter = struct + open Transforms.ScopedBindings + + module StmtSet = Set.Make(struct + type t = stmt + let compare = Stdlib.compare + end) + + type split_ctx = { + (* return type of containing function *) + return_type: var_type; + (* to look up the type of parameters *) + bindings: var_type Bindings.t ; + } + + class find_returns = object(this) + inherit Asl_visitor.nopAslVisitor + val mutable returns = StmtSet.empty + val mutable fun_returns = false + val mutable proc_returns = false + + method! vstmt s = + match s with + | Stmt_FunReturn _ -> (fun_returns <- true); (returns <- (StmtSet.add s returns)); DoChildren + | Stmt_ProcReturn _ -> (proc_returns <- true); (returns <- (StmtSet.add s returns)); DoChildren + | _ -> DoChildren + + method any_rets () = fun_returns || proc_returns + end + +(* Replaces a statement list with a a call to a function containing the same statements. + + If the statement list contains return statements then a return statement is returned. + (we are assuming no early returns). + + We are also assuming that the statement list is an entire scope, i.e. no definitions + escape from the block. + + Typically the statement list sl here is the code block in an if/elseif/else branch. + + returns (call stmt, ident * fun_sig) + *) + let stmt_list_to_function (c: split_ctx) (sl:stmt list)(*: stmt *) = + let in_return_context sl = let v = new find_returns in + visit_stmts v sl |> ignore ; v#any_rets () in + let swap (a,b) = (b,a) in + let param_types = List.filter (fun (t, i) -> i <> state_var.ident) (List.map swap (Bindings.bindings c.bindings)) in + let param_types = (state_var.typ,state_var.ident)::param_types in + let returning = if (in_return_context sl) then c.return_type else Unit in + let fname = new_name "split_fun" in + let targs = [] in + let args = List.map snd (List.tl param_types) in + let funsig : sc_fun_sig = {rt=returning; arg_types=param_types; targs=targs; args=args; body=sl} in + let new_funsig = (fname , funsig) in + let call_params = List.map (fun i -> Expr_TApply ((FIdent("as_ref", 0)), [], [Expr_Var i])) args in + let new_stmt = (match returning with + | Immutable x -> Some None + | Mutable x -> Some None + | Infer -> Some None + | Unit -> None) + |> (function + | Some _ -> Stmt_FunReturn ((Expr_TApply (fname, targs, call_params)), Unknown) + | None -> Stmt_TCall (fname,targs,call_params, Unknown)) + in (new_stmt, new_funsig) + + (* Create a function containing a single return statement of the given expression. All variables used in the expression + become parameters to the function. *) + let expr_to_function (c: split_ctx) (e: expr) = + let fname = new_name "split_expr" in + let returning = Infer in + let params = List.map (fun e -> Option.map (fun v -> v, e) (Bindings.find_opt e c.bindings)) (IdentSet.elements (fv_expr e)) in + let params = List.concat_map Option.to_list params in + let params = (state_var.typ,state_var.ident)::List.filter (fun (t, i) -> i <> state_var.ident) params in + let targs = [] in + let args = List.map snd (List.tl params) in (* chop off state var since its always added to calls*) + (* as_ref here is a bit of a hack *) + let call_params = List.map (fun i -> Expr_TApply ((FIdent("as_ref", 0)), [], [Expr_Var i])) args in + let body = [Stmt_FunReturn (e, Unknown)] in + let funsig = {rt=returning; arg_types=params; targs=targs; args=args; body=body} in + let callexpr = Expr_TApply (fname, [], call_params) in + (callexpr, (fname, funsig)) + + + (* When a scope block has many statements *) + let outline_function_on_size_threshold ctx sl thresh = + let sum x = List.fold_left (fun a b -> (a + b)) 0 x in + let branch_weight = compose count_stmts_list sum in + let stmt_weights = sl |> List.map (fun s -> match s with + | Stmt_If(c, t, els, f, loc) -> (s, [branch_weight t] + @ (List.map (branch_weight) (List.map (function | S_Elsif_Cond (e, sl) -> sl) els)) + @ [branch_weight f]) + | c -> (s, [count_stmts c] ) + ) in + let total = List.map (compose snd sum) stmt_weights |> sum in + if (total < thresh) then (sl, []) else + let (f,s) = stmt_list_to_function ctx sl + in ([f], [s]) + + (* When a function has many statements at the same level. *) + let chunk_outline_stmtlist_on_size_threshold ctx sl thresh = + (* Find all the definitions in the block, move them to the beginning of the block, split the + block into a number of chunks and outline those chunks as a function each. + + Not needed with agressive expression outlining. *) + () + + (* When an expression has many subexpressions *) + let outline_expr_on_size_threshold ctx (thresh:int) (e:expr) = + let c = new stmt_counter in + let c = c#expr_count e in + if (c > thresh) then let (e,f) = expr_to_function ctx e in (e, [f]) else (e, []) + + class branch_outliner (funsig: sc_fun_sig) (outline_thresh:int) = object(this) + inherit Asl_visitor.nopAslVisitor + + val scoped_bindings : var_type Transforms.ScopedBindings.t = let x = Transforms.ScopedBindings.init () in + List.iter (fun (t,i) -> add_bind x i t) (funsig.arg_types); + push_scope x (); + x + + val mutable extra_funs : sc_fun_sig Bindings.t = Bindings.empty + + method add_fun (bs: (ident * sc_fun_sig) list) = extra_funs <- Bindings.union + (fun i a b -> if a == b then Some a else failwith ("(branch_outliner) split function names be distinct " ^ (name_of_ident i))) extra_funs (Bindings.of_seq (List.to_seq bs)) + + method split_sl sl = + let ctx = {return_type = funsig.rt; bindings = current_scope_bindings scoped_bindings} in + let t,nf = outline_function_on_size_threshold ctx sl outline_thresh + in this#add_fun nf; + t + + method split_exp e = + let ctx = {return_type = funsig.rt; bindings = current_scope_bindings scoped_bindings} in + let ne,nf = outline_expr_on_size_threshold ctx outline_thresh e + in this#add_fun nf; + ne + + + method! enter_scope ss = push_scope scoped_bindings () + method! leave_scope ss = pop_scope scoped_bindings () + + method! vstmt s = + match s with + | Stmt_VarDeclsNoInit(ty, vs, loc) -> + List.iter (fun f -> add_bind scoped_bindings f (Mutable ty)) vs; + DoChildren + | Stmt_VarDecl(ty, v, i, loc) -> + add_bind scoped_bindings v (Mutable ty) ; + DoChildren + | Stmt_ConstDecl(ty, v, i, loc) -> + add_bind scoped_bindings v (Immutable ty) ; + DoChildren + | Stmt_If (c, t, els, e, loc) -> + ChangeDoChildrenPost ([Stmt_If (c, t, els, e, loc)], (function + | [Stmt_If (c, t, els, e, loc)] -> + let c' = visit_expr this c in + (* visit each statement list and then maybe outline it *) + let t' = visit_stmts this t in + let t' = this#split_sl t' in + let els' = mapNoCopy (visit_s_elsif this ) els in + let e' = visit_stmts this e in + let e' = this#split_sl e' in + [Stmt_If (c', t', els', e', loc)] + | _ -> [s] + )) + (* Statements with child scopes that shouldn't appear towards the end of transform pipeline *) + | Stmt_Case _ -> failwith "(FixRedefinitions) case not expected" + | Stmt_For _ -> failwith "(FixRedefinitions) for not expected" + | Stmt_While _ -> failwith "(FixRedefinitions) while not expected" + | Stmt_Repeat _ -> failwith "(FixRedefinitions) repeat not expected" + | Stmt_Try _ -> failwith "(FixRedefinitions) try not expected" + | _ -> DoChildren + + method! vs_elsif e = let c,e = match e with + | S_Elsif_Cond (c, sl) -> c,sl in + let sl = visit_stmts this e in + let sl = this#split_sl sl in + ChangeTo (S_Elsif_Cond (c,sl)) + + method! vexpr e = ChangeTo (this#split_exp e) + (*the generated code is too large with this, type inference seems to do ok without *) + (*method! vexpr e = match e with + | Expr_TApply _ -> DoChildren (* to preserve typing of fun returns *) + | e -> ChangeTo (this#split_exp e) *) + + method split_function (x:unit) = let sl = visit_stmts this (funsig.body) in (sl , extra_funs) + + end + +end + + +let rec write_stmt ?(primitive:bool=false) s st = + match s with + | Stmt_VarDeclsNoInit(ty, vs, loc) -> + List.iter (define st (Mutable ty)) vs ; + let e = default_value ty st in + List.iter (fun v -> write_ref v (Mutable ty) e st) vs + + | Stmt_VarDecl(ty, v, e, loc) -> + define st (Mutable ty) v; + let e = prints_expr e st in + write_ref v (Mutable ty) e st + + | Stmt_ConstDecl(ty, v, e, loc) -> + define st (Immutable ty) v; + let e = prints_expr e st in + write_let v (Immutable ty) e st + + | Stmt_Assign(l, r, loc) -> + let e = prints_expr r st in + write_assign l e st + + | Stmt_TCall(f, tes, es, loc) -> + let tes = List.map (fun e -> prints_expr e st) tes in + let es = List.map (fun e -> prints_expr e st) es in + write_call f tes es st + + | Stmt_FunReturn(e, loc) -> + write_fun_return (prints_expr e st) st + + | Stmt_ProcReturn(loc) -> + write_proc_return st + + | Stmt_Assert(e, loc) -> + write_assert (prints_expr e st) st + + | Stmt_Throw _ -> + write_unsupported st + + | Stmt_If(c, t, els, f, loc) -> + let rec iter = function + | S_Elsif_Cond(c,b)::xs -> + write_if_elsif (prints_expr c st) st; + write_stmts b st; + iter xs + | [] -> () in + write_if_start (prints_expr c st) st; + write_stmts t st; + iter els; + if f <> [] then (write_if_else st; write_stmts f st); + write_if_end st + + | _ -> failwith @@ "write_stmt: " ^ (pp_stmt s); + +and write_stmts ?(primitive:bool=false) s st = + inc_depth st; + push_scope st; + match s with + | [] -> + write_proc_return st; + dec_depth st + | x::xs -> + write_stmt x st; + List.iter (fun s -> + write_seq st; + write_stmt s st + ) xs; + dec_depth st; + assert (not st.skip_seq) + ; pop_scope st + + +let write_preamble imports opens st = + Printf.fprintf st.oc "/* AUTO-GENERATED ASLp LIFTER FILE */\npackage aslloader\n"; + List.iter (fun n -> + Printf.fprintf st.oc "import %s\n" n) imports; + List.iter (fun n -> + Printf.fprintf st.oc "import %s._\n" n) opens; + Printf.fprintf st.oc "\n" + + +let write_epilogue (fid: AST.ident) st = + Printf.fprintf st.oc "class %s { + + def decode(l: LiftState, insn: BV) : Any = { + %s(l, insn) + } +}" (plain_ident fid) (name_of_ident fid) + +open AST + + +let init_b (u:unit) = Transforms.ScopedBindings.init () +let init_st : st = { indent = 0; skip_seq=false; oc=stdout; mutable_vars = init_b (); extra_functions = Bindings.empty} +let rinit_st oc st : st = {st with indent = 0; skip_seq=false; oc=oc; mutable_vars = init_b ()} + +let build_args (tys: ((var_type * ident) list)) targs args = + let targs = List.map (fun t -> name_of_ident t) targs in + (*let args = List.map (fun t -> name_of_ident t) args in *) + let ta = List.map (fun (t, i) -> (name_of_ident i) ^ ": " ^ (prints_arg_type t)) tys in + "(" ^ (String.concat "," (targs@ta)) ^ ")" + +let print_write_fn (name: AST.ident) ((ret_tyo: var_type), (argtypes: ((var_type * ident) list)), targs, args, _, body) st = + let open Transforms.ScopedBindings in + (*update_funcalls {name=(name_of_ident name); + targs=(List.map (fun _ -> Some (Type_Constructor (Ident "integer"))) targs); + args = List.map (fun (t, _) -> Some t) (argtypes)} (st.output_fun) ; *) + let wargs = build_args argtypes targs args in + let ret = prints_arg_type ret_tyo in + let ret = if (ret <> "") then (": " ^ ret) else "" in + (* bind params *) + push_scope st.mutable_vars () ; + List.iter (fun (a,b) -> add_bind st.mutable_vars b a) argtypes ; + + Printf.fprintf st.oc "def %s %s %s = {\n" (name_of_ident name) wargs ret; + write_stmts body st; + Printf.fprintf st.oc "\n}\n"; + + (* unbind params *) + pop_scope st.mutable_vars () + +let write_fn (name: AST.ident) (fn:sc_fun_sig) st = + if ((sl_complexity fn.body) < 1000) then (print_write_fn name (fn.rt, fn.arg_types, fn.targs, fn.args, (), fn.body) st) else + let ol = new FunctionSplitter.branch_outliner fn 5 in + let (nb,to_add) = ol#split_function () in + print_write_fn name (fn.rt, fn.arg_types, fn.targs, fn.args, (), nb) st; + st.extra_functions <- to_add ; + Bindings.iter (fun n b -> print_write_fn n (b.rt, b.arg_types, b.targs, b.args, (), b.body) st) to_add + +let lift_fsig (fs: Eval.fun_sig) : sc_fun_sig = + let assigned = assigned_vars_of_stmts (fnsig_get_body fs) in + let params = List.map (fun (t,i) -> if (IdentSet.mem i assigned) then Mutable t,i else Immutable t, i) (fnsig_get_typed_args fs) in + let rt = match (fnsig_get_rt fs) with + | Some x -> Immutable x + | None -> Unit + in {rt=rt; arg_types=(state_var.typ,state_var.ident)::params; targs=(fnsig_get_targs fs); args=(fnsig_get_args fs); body=(fnsig_get_body fs)} + + +(* Write an instruction file, containing just the behaviour of one instructions *) +let write_instr_file fn fnsig dir st = + let m = name_of_FIdent fn in + let path = dir ^ "/" ^ m ^ ".scala" in + let oc = open_out path in + let st = rinit_st oc st in + write_preamble global_imports global_opens st; + write_fn fn fnsig st; + close_out oc; + name_of_FIdent fn + + +(* Write the decoder file - should depend on all of the above *) +let write_decoder_file fn fnsig deps dir st = + let m = "aslpOffline" in + let path = dir ^ "/" ^ m ^ ".scala" in + let oc = open_out path in + let st = rinit_st oc st in + write_preamble global_imports (global_opens) st; + write_fn fn fnsig st; + write_epilogue fn st; + close_out oc; + m + +(* Write the test file, containing all decode tests *) +let write_test_file tests dir st = + let m = "decode_tests" in + let path = dir ^ "/" ^ m ^".scala" in + let oc = open_out path in + let st = rinit_st oc st in + write_preamble global_imports (global_opens) st; + Bindings.iter (fun i s -> write_fn i (lift_fsig s) st) tests; + close_out oc; + m + +let run (dfn : ident) (dfnsig : ty option * 'a * ident list * ident list * 'b * stmt list) (tests : (ty option * 'a * ident list * ident list * 'b * stmt list) Bindings.t) (fns : (ty option * 'a * ident list * ident list * 'b * stmt list) Bindings.t) (dir : typeid) = + let st = init_st in + let files = Bindings.fold (fun fn fnsig acc -> (write_instr_file fn (lift_fsig fnsig) dir st)::acc) fns [] in + let files = (write_test_file tests dir st)::files in + write_decoder_file dfn (lift_fsig dfnsig) files dir st |> ignore ; + () + diff --git a/libASL/symbolic_lifter.ml b/libASL/symbolic_lifter.ml index 74a4dd11..fa51d26a 100644 --- a/libASL/symbolic_lifter.ml +++ b/libASL/symbolic_lifter.ml @@ -78,7 +78,7 @@ module RemoveUnsupported = struct inherit Asl_visitor.nopAslVisitor method! vstmt e = - (match e with + singletonVisitAction (match e with | Stmt_Assert(e, loc) -> if contains_unsupported e unsupported env then ChangeTo (assert_false loc) else DoChildren @@ -202,7 +202,7 @@ module Cleanup = struct (RemoveUnsupported.assert_false loc) else e | _ -> e) in - ChangeDoChildrenPost(e, reduce) + singletonVisitAction @@ ChangeDoChildrenPost(e, reduce) end let rec trim_post_term stmts = @@ -272,7 +272,7 @@ module DecoderCleanup = struct (RemoveUnsupported.assert_false loc) else e | _ -> e) in - ChangeDoChildrenPost(e, reduce) + singletonVisitAction @@ ChangeDoChildrenPost(e, reduce) end let run unsupported dsig = diff --git a/libASL/transforms.ml b/libASL/transforms.ml index cb06d8b0..70a4551b 100644 --- a/libASL/transforms.ml +++ b/libASL/transforms.ml @@ -223,7 +223,7 @@ module RefParams = struct inherit Asl_visitor.nopAslVisitor method! vstmt = function - | Stmt_ProcReturn _ -> ChangeTo s + | Stmt_ProcReturn _ -> ChangeTo [s] | Stmt_FunReturn _ -> failwith "unexpected function return in ref param conversion." | _ -> DoChildren end @@ -300,8 +300,8 @@ module RefParams = struct val mutable n = 0; - method! vstmt (s: stmt): stmt visitAction = - match s with + method! vstmt (s: stmt): stmt list visitAction = + singletonVisitAction @@ match s with | Stmt_Assign (LExpr_Write (setter, targs, args), r, loc) -> (match Bindings.find_opt setter ref_params with | None -> DoChildren @@ -843,7 +843,7 @@ module StatefulIntToBits = struct (merge t f,acc@[Stmt_If(e, tstmts, [], fstmts, loc)]) | _ -> (* Otherwise, we have no statement nesting *) - let stmt = Asl_visitor.visit_stmt v stmt in + let stmt = Asl_visitor.visit_stmt_single v stmt in let (st,stmt) = (match stmt with (* Match integer writes *) @@ -1489,8 +1489,8 @@ module RedundantSlice = struct | Some (Type_Array(_ix,ty)) -> Some ty | _ -> None - method! vstmt (s: stmt): stmt visitAction = - ChangeDoChildrenPost(s, fun s -> self#update_lvar_types s; s) + method! vstmt (s: stmt): stmt list visitAction = + singletonVisitAction @@ ChangeDoChildrenPost(s, fun s -> self#update_lvar_types s; s) method! vexpr (e: expr): expr visitAction = ChangeDoChildrenPost(e, fun e -> @@ -1581,7 +1581,7 @@ module CommonSubExprElim = struct in result - method! vstmt (s: stmt): stmt visitAction = + method! vstmt (s: stmt): stmt list visitAction = let () = match s with | Stmt_ConstDecl(_, Ident(n), _, Unknown) when (Str.string_match (Str.regexp "Cse") n 0) -> do_replace <- false @@ -1714,11 +1714,11 @@ module CaseSimp = struct inherit Asl_visitor.nopAslVisitor (* Assumes x is pure, as it is referenced within a branch condition *) - method! vstmt (s: stmt): stmt visitAction = + method! vstmt (s: stmt): stmt list visitAction = match match_outer s with | Some (x, r, w, loc, res) when is_total w res -> (match List.find_opt (fun (test,_) -> StringMap.for_all test res) fn_guess with - | Some (_,fn) -> ChangeTo (fn r x w loc) + | Some (_,fn) -> ChangeTo [fn r x w loc] | _ -> DoChildren) | _ -> DoChildren @@ -1781,6 +1781,7 @@ module type ScopedBindings = sig val add_bind : 'elt t -> ident -> 'elt -> unit val find_binding : 'elt t -> ident -> 'elt option val current_scope_bindings : 'elt t -> 'elt Bindings.t + val init: unit -> 'elt t end module ScopedBindings : ScopedBindings = struct @@ -1789,6 +1790,7 @@ module ScopedBindings : ScopedBindings = struct let pop_scope (b:'elt t) (_:unit) : unit = Stack.pop_opt b |> ignore let add_bind (b:'elt t) k v : unit = Stack.push (Bindings.add k v (Stack.pop b)) b let find_binding (b:'elt t) (i) : 'a option = Seq.find_map (fun s -> Bindings.find_opt i s) (Stack.to_seq b) + let init (u:unit) : 'elt t = let s = Stack.create () in Stack.push (Bindings.empty) s; s (** returns a flattened view of bindings accessible from the current (innermost) scope. *) @@ -1829,7 +1831,7 @@ module FixRedefinitions = struct | None -> {name=i; index=0} method! vstmt s = - match s with + singletonVisitAction @@ match s with | Stmt_VarDeclsNoInit(ty, vs, loc) -> let ns = List.map this#incr_binding vs in List.iter this#add_bind ns; DoChildren @@ -1866,7 +1868,6 @@ module FixRedefinitions = struct (match (this#existing_binding e) with | Some e -> ChangeTo (ident_for_v e) | None -> SkipChildren) - end let run (g: IdentSet.t) (s:stmt list) : stmt list = diff --git a/libASL/utils.ml b/libASL/utils.ml index e71887df..cc0917c1 100644 --- a/libASL/utils.ml +++ b/libASL/utils.ml @@ -7,6 +7,15 @@ (** Generic utility functions *) +let rec mkdir_p p = + let open Filename in + if Sys.file_exists p then + () + else + (* make parents, then make final directory. *) + (mkdir_p (dirname p); Sys.mkdir p 0o755) + + (**************************************************************** * Pretty-printer related ****************************************************************) diff --git a/offlineASL-cpp/.gitignore b/offlineASL-cpp/.gitignore new file mode 100644 index 00000000..07b5252f --- /dev/null +++ b/offlineASL-cpp/.gitignore @@ -0,0 +1,3 @@ +generated +!.gitkeep +build diff --git a/offlineASL-cpp/aslp_lifter.hpp b/offlineASL-cpp/aslp_lifter.hpp new file mode 100644 index 00000000..2a197681 --- /dev/null +++ b/offlineASL-cpp/aslp_lifter.hpp @@ -0,0 +1 @@ +#include "generated/aslp_lifter.hpp" // IWYU pragma: export diff --git a/offlineASL-cpp/aslp_lifter_impl.hpp b/offlineASL-cpp/aslp_lifter_impl.hpp new file mode 100644 index 00000000..a7dfaaab --- /dev/null +++ b/offlineASL-cpp/aslp_lifter_impl.hpp @@ -0,0 +1 @@ +#include "generated/aslp_lifter_impl.hpp" // IWYU pragma: export diff --git a/offlineASL-cpp/build.sh b/offlineASL-cpp/build.sh new file mode 100755 index 00000000..f062d62e --- /dev/null +++ b/offlineASL-cpp/build.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +set -o pipefail + +GEN_DIR="${GEN_DIR:-$(dirname $0)/../../aslp-lifter-cpp}" +set -xe + +[[ -d "$GEN_DIR" ]] +if [[ -z "$CXX" ]]; then + export CXX=$(command -v clang++) +fi + +echo ":gen A64 .+ cpp $GEN_DIR/subprojects/aslp-lifter" | dune exec asli + +cd $GEN_DIR +rm -rf build + +meson setup build $MESONFLAGS +meson compile -C build diff --git a/offlineASL-cpp/dune b/offlineASL-cpp/dune new file mode 100644 index 00000000..8a10dd75 --- /dev/null +++ b/offlineASL-cpp/dune @@ -0,0 +1,6 @@ +(install + (section (site (asli aslfiles))) + (files + (interface.hpp as include/aslp/interface.hpp) + (aslp_lifter.hpp as include/aslp/aslp_lifter.hpp) + (aslp_lifter_impl.hpp as include/aslp/aslp_lifter_impl.hpp))) diff --git a/offlineASL-cpp/interface.hpp b/offlineASL-cpp/interface.hpp new file mode 100644 index 00000000..690fc19e --- /dev/null +++ b/offlineASL-cpp/interface.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include + +namespace aslp { + +template +concept lifter_traits = requires(T) +{ + typename T::bits; + typename T::bigint; + typename T::rt_expr; + typename T::rt_lexpr; + typename T::rt_label; +}; + +template +class lifter_interface : public Traits { +public: + // bits which are known at lift-time + using typename Traits::bits; + // bigints which are known at lift-time + using typename Traits::bigint; + + // runtime-expression type, i.e. the type of values produced by the semantics + using typename Traits::rt_expr; + using typename Traits::rt_lexpr; + // runtime-label, supports switching blocks during semantics generation + using typename Traits::rt_label; + + // TODO: split lift-time interface from run-time interface + // TODO: more flexible method of adding const-lvalue qualifiers +public: + virtual ~lifter_interface() = default; + + virtual rt_lexpr v_PSTATE_C() = 0; + virtual rt_lexpr v_PSTATE_Z() = 0; + virtual rt_lexpr v_PSTATE_V() = 0; + virtual rt_lexpr v_PSTATE_N() = 0; + virtual rt_lexpr v__PC() = 0; + virtual rt_lexpr v__R() = 0; + virtual rt_lexpr v__Z() = 0; + virtual rt_lexpr v_SP_EL0() = 0; + virtual rt_lexpr v_FPSR() = 0; + virtual rt_lexpr v_FPCR() = 0; + + virtual rt_lexpr v_PSTATE_BTYPE() = 0; + virtual rt_lexpr v_BTypeCompatible() = 0; + virtual rt_lexpr v___BranchTaken() = 0; + virtual rt_lexpr v_BTypeNext() = 0; + virtual rt_lexpr v___ExclusiveLocal() = 0; + + virtual bits bits_lit(unsigned width, std::string_view str) = 0; + virtual bits bits_zero(unsigned width) = 0; + virtual bigint bigint_lit(std::string_view str) = 0; + virtual bigint bigint_zero() = 0; + virtual bits extract_bits(const bits &val, bigint lo, bigint wd) = 0; + + virtual bool f_eq_bits(const bits &x, const bits &y) = 0; + virtual bool f_ne_bits(const bits &x, const bits &y) = 0; + virtual bits f_add_bits(const bits &x, const bits &y) = 0; + virtual bits f_sub_bits(const bits &x, const bits &y) = 0; + virtual bits f_mul_bits(const bits &x, const bits &y) = 0; + virtual bits f_and_bits(const bits &x, const bits &y) = 0; + virtual bits f_or_bits(const bits &x, const bits &y) = 0; + virtual bits f_eor_bits(const bits &x, const bits &y) = 0; + virtual bits f_not_bits(const bits &x) = 0; + virtual bool f_slt_bits(const bits &x, const bits &y) = 0; + virtual bool f_sle_bits(const bits &x, const bits &y) = 0; + virtual bits f_zeros_bits(bigint n) = 0; + virtual bits f_ones_bits(bigint n) = 0; + virtual bits f_replicate_bits(const bits &x, bigint n) = 0; + virtual bits f_append_bits(const bits &x, const bits &y) = 0; + virtual bits f_ZeroExtend(const bits &x, bigint wd) = 0; + virtual bits f_SignExtend(const bits &x, bigint wd) = 0; + virtual bits f_lsl_bits(const bits &x, const bits &y) = 0; + virtual bits f_lsr_bits(const bits &x, const bits &y) = 0; + virtual bits f_asr_bits(const bits &x, const bits &y) = 0; + virtual bigint f_cvt_bits_uint(const bits &x) = 0; + + virtual rt_lexpr f_decl_bv(std::string_view name, bigint width) = 0; + virtual rt_lexpr f_decl_bool(std::string_view name) = 0; + + virtual void f_switch_context(rt_label label) = 0; + virtual rt_label f_true_branch(std::tuple) = 0; + virtual rt_label f_false_branch(std::tuple) = 0; + virtual rt_label f_merge_branch(std::tuple) = 0; + virtual std::tuple f_gen_branch(rt_expr cond) = 0; + + virtual void f_gen_assert(rt_expr cond) = 0; + virtual rt_expr f_gen_bit_lit(bits bits) = 0; + virtual rt_expr f_gen_bool_lit(bool b) = 0; + virtual rt_expr f_gen_int_lit(bigint i) = 0; + virtual rt_expr f_gen_load(rt_lexpr ptr) = 0; + virtual void f_gen_store(rt_lexpr var, rt_expr exp) = 0; + virtual rt_expr f_gen_array_load(rt_lexpr array, bigint index) = 0; // XXX unsure of array ptr type. + virtual void f_gen_array_store(rt_lexpr array, bigint index, rt_expr exp) = 0; + virtual void f_gen_Mem_set(rt_expr ptr, rt_expr width, rt_expr acctype, rt_expr exp) = 0; + virtual rt_expr f_gen_Mem_read(rt_expr ptr, rt_expr width, rt_expr acctype) = 0; + virtual void f_gen_AArch64_MemTag_set(rt_expr address, rt_expr acctype, rt_expr value) = 0; + + virtual void f_AtomicStart() = 0; + virtual void f_AtomicEnd() = 0; + + virtual rt_expr f_gen_cvt_bits_uint(rt_expr bits) = 0; + virtual rt_expr f_gen_cvt_bool_bv(rt_expr e) = 0; + + virtual rt_expr f_gen_not_bool(rt_expr e) = 0; + virtual rt_expr f_gen_and_bool(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_or_bool(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_eq_enum(rt_expr x, rt_expr y) = 0; + + virtual rt_expr f_gen_not_bits(rt_expr x) = 0; + virtual rt_expr f_gen_eq_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_ne_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_or_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_eor_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_and_bits(rt_expr x, rt_expr y) = 0; + + virtual rt_expr f_gen_add_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_sub_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_sdiv_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_sle_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_slt_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_mul_bits(rt_expr x, rt_expr y) = 0; + + virtual rt_expr f_gen_append_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_lsr_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_lsl_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_asr_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_replicate_bits(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_ZeroExtend(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_SignExtend(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_slice(rt_expr e, bigint lo, bigint wd) = 0; + + virtual rt_expr f_gen_FPCompare(rt_expr x, rt_expr y, rt_expr signalnan, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPCompareEQ(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPCompareGE(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPCompareGT(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPAdd(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPSub(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMulAdd(rt_expr addend, rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMulAddH(rt_expr addend, rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMulX(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMul(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPDiv(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMin(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMinNum(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMax(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPMaxNum(rt_expr x, rt_expr y, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPRecpX(rt_expr x, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPSqrt(rt_expr x, rt_expr fpcr) = 0; + virtual rt_expr f_gen_FPRecipEstimate(rt_expr x, rt_expr fpcr) = 0; + virtual rt_expr f_gen_BFAdd(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_BFMul(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_FPConvertBF(rt_expr x, rt_expr fpcr, rt_expr rounding) = 0; + virtual rt_expr f_gen_FPRecipStepFused(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_FPRSqrtStepFused(rt_expr x, rt_expr y) = 0; + virtual rt_expr f_gen_FPToFixed(rt_expr x, rt_expr fbits, rt_expr isunsigned, rt_expr fpcr, rt_expr rounding) = 0; + virtual rt_expr f_gen_FixedToFP(rt_expr x, rt_expr fbits, rt_expr isunsigned, rt_expr fpcr, rt_expr rounding) = 0; + virtual rt_expr f_gen_FPConvert(rt_expr x, rt_expr fpcr, rt_expr rounding) = 0; + virtual rt_expr f_gen_FPRoundInt(rt_expr x, rt_expr fpcr, rt_expr rounding, rt_expr isexact) = 0; + virtual rt_expr f_gen_FPRoundIntN(rt_expr x, rt_expr fpcr, rt_expr rounding, rt_expr intsize) = 0; + virtual rt_expr f_gen_FPToFixedJS_impl(rt_expr x, rt_expr fpcr, rt_expr is64) = 0; // from override.asl + +}; + +} // namespace aslp diff --git a/scalaOfflineASL/utils.scala b/scalaOfflineASL/utils.scala new file mode 100644 index 00000000..8e10c73f --- /dev/null +++ b/scalaOfflineASL/utils.scala @@ -0,0 +1,505 @@ +package aslloader + +import util.Logger +import ir._ +import analysis.BitVectorEval._ +import collection.mutable.ArrayBuffer +import collection.mutable + +type RTSym = Expr +type RTLabel = String +type BV = BitVecLiteral +def rTLabelDefault = "undef" +def rTSymDefault = null +def rTExprDefault = null + +import ir.dsl._ + + +object Counter: + var counter = 0 + +case class BranchInfo(val branch: Option[String], val guard: Expr, val branchTaken: Boolean, pcAssigned: Option[Expr]) +class LiftState(val entry: String = "block") { + + val endian = Endian.LittleEndian + val memory = Memory("mem", 64, 8) + + var current_pos: String = entry + + + val controlFlow: mutable.Map[String, EventuallyJump] = mutable.Map() + val blocks: mutable.LinkedHashMap[String, ArrayBuffer[Statement]] = mutable.LinkedHashMap((entry -> ArrayBuffer.empty)) + val branches: mutable.Map[String, (String, String, String)] = mutable.Map() + + + var current_guard: BranchInfo = BranchInfo(None, TrueLiteral, false, None) + + // maps block ids to guards + // We push the current guard forward as blocks are appended. We maintain this mapping so we + // can update the current guard when switch_ctx is called. + val block_guard: mutable.Map[String, BranchInfo] = mutable.Map() + + def pcAssigns = block_guard.filter { + case (k,v) => v.pcAssigned.isDefined + } + + def new_name(p: Option[String] = None) = { + Counter.counter += 1 + entry + "_" + p.map(_ + "_").getOrElse("") + Counter.counter + } + + def merge_state(other: LiftState) = { + controlFlow.addAll(other.controlFlow) + blocks.addAll(other.blocks) + branches.addAll(other.branches) + current_pos = other.current_pos + } + + def escaping_jumps = controlFlow + .collect { + case (_, EventuallyGoto(tgts)) => tgts.map((t: DelayNameResolve) => t.ident) + case (_, EventuallyCall(t, Some(ft))) => List(ft.ident) + case (_, EventuallyIndirectCall(_, Some(ft))) => List(ft.ident) + } + .flatMap(_.toList) + .filter((t) => !blocks.keySet.contains(t)) + .toSet + + def push_block(p: Option[String] = None): String = { + val n = new_name(p) + + blocks(n) = ArrayBuffer.empty + n + } + + def implicit_set_pc(address: Long, label: Option[String] = None) = { + val la = LocalAssign(Register("_PC", BitVecType(64)), BitVecLiteral(BigInt(address), 64), label) + blocks(current_pos).append(la) + } + + def push_stmt(s: Statement) = { + s match { + case LocalAssign(Register("BranchTaken", BoolType), TrueLiteral, _) => current_guard = BranchInfo(current_guard.branch, current_guard.guard, true, current_guard.pcAssigned) + case LocalAssign(Register("_PC", BitVecType(64)), addr, _) => current_guard = BranchInfo(current_guard.branch, current_guard.guard, current_guard.branchTaken, Some(addr)) + case _ => () + } + + block_guard(current_pos) = current_guard + blocks(current_pos).append(s) + } + + def switch_ctx(c: String) = { + require(blocks.keySet.contains(c)) + current_pos = c + current_guard = block_guard(current_pos) + } + + def gen_branch(cond: Expr) = { + val branch_id = new_name(Some("branch")) + + val true_branch = push_block(Some("true")) + val false_branch = push_block(Some("false")) + val merge_block = push_block(Some("join")) + blocks(true_branch).append(Assume(cond)) + blocks(false_branch).append(Assume(UnaryExpr(BoolNOT, cond))) + + block_guard(true_branch) = BranchInfo(Some(branch_id), cond, false, None) + block_guard(false_branch) = BranchInfo(Some(branch_id), UnaryExpr(BoolNOT, cond), false, None) + block_guard(merge_block) = BranchInfo(None, TrueLiteral, false, None) + current_guard = BranchInfo(Some(branch_id), current_guard.guard, current_guard.branchTaken, current_guard.pcAssigned) + + controlFlow(current_pos) = goto(true_branch, false_branch) + controlFlow(true_branch) = goto(merge_block) + controlFlow(false_branch) = goto(merge_block) + branches.addOne((branch_id -> (true_branch, false_branch, merge_block))) + switch_ctx (merge_block) + (branch_id, true_branch, false_branch, merge_block) + } + + def replace_jmp(c: EventuallyJump) = controlFlow(current_pos) = c + + def add_call(from: String, c: EventuallyJump) : Unit = { + controlFlow.get(from) match { + case None => controlFlow(current_pos) = c + case Some(EventuallyGoto(List(x))) => { + c match { + case EventuallyCall(c, None) => EventuallyCall(c, Some(x)) + case EventuallyCall(_, Some(f)) => add_call(f.ident, c) + case EventuallyIndirectCall(c, None) => EventuallyIndirectCall(c, Some(x)) + case EventuallyIndirectCall(_, Some(f)) => add_call(f.ident, c) + case EventuallyGoto(cs) => (EventuallyGoto(cs ++ List(x))) + case _ => throw Exception(s"Existing jump ${EventuallyGoto(List(x))} adding $c") + } + + } + case Some(l) => throw Exception(s"Existing jump $l") + } + } + + def add_call(c: EventuallyJump) : Unit = { + add_call(current_pos, c) + } + + def add_goto(l: String) = { + controlFlow.get(current_pos) match { + case Some(EventuallyGoto(ts)) => controlFlow(current_pos) = EventuallyGoto(ts ++ List(DelayNameResolve(l))) + case Some(EventuallyCall(ts, None)) => controlFlow(current_pos) = EventuallyCall(ts, Some(DelayNameResolve(l))) + case None => controlFlow(current_pos) = EventuallyGoto(List(DelayNameResolve(l))) + case Some(l) => throw RuntimeException(s"Cannot add goto target to call $l") + } + } + + def toIR(): List[EventuallyBlock] = + blocks.map((n, stmts) => block(n, (stmts.toList ++ List(controlFlow.getOrElse(n, ret))))).toList +} + +object Lifter { + + def liftOpcode(op: Int, sp: BitVecLiteral): List[EventuallyBlock] = { + var liftState = LiftState() + val dec = A64_decoder() + val r = dec.decode(liftState, BitVecLiteral(BigInt(op), 32), sp) + liftState.toIR() + } + + + def liftOpcode(op: Int): List[EventuallyBlock] = { + var liftState = LiftState() + val dec = A64_decoder() + val r = dec.decode(liftState, BitVecLiteral(BigInt(op), 32), BitVecLiteral(BigInt(0), 64)) + liftState.toIR() + } + +} + +class Mutable[T](var v: T) + +def extract(x: BigInt, sz: BigInt) = x % (BigInt(2).pow((sz + 1).toInt)) + +def mkBits(st: LiftState, n: BigInt, y: BigInt): BitVecLiteral = { + require(n >= 0) + BitVecLiteral(extract(y, n), n.toInt) +} + +def zero_extend_to(s: BigInt, x: BitVecLiteral) = { + require(s > x.size) + BitVecLiteral(x.value, s.toInt) +} + +def gen_zero_extend_to(s: BigInt, x: Expr) = { + x.getType match { + case BitVecType(sz) if sz == s.toInt => x + case BitVecType(sz) => ZeroExtend((s - sz).toInt, x) + case _ => throw Exception("Type mismatch gen_zero_extend_to") + } +} + +def bvextract(st: LiftState, e: BitVecLiteral, lo: BigInt, width: BigInt): BitVecLiteral = + smt_extract((lo + width - 1).toInt, lo.toInt, e) + +def f_eq_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): Boolean = (smt_bveq(x, y) == TrueLiteral) + +def f_ne_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): Boolean = (smt_bveq(x, y) == FalseLiteral) + +def f_add_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): BitVecLiteral = (smt_bvadd(x, y)) + +def f_sub_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): BitVecLiteral = (smt_bvsub(x, y)) + +def f_mul_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): BitVecLiteral = (smt_bvmul(x, y)) + +def f_and_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): BitVecLiteral = (smt_bvand(x, y)) + +def f_or_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): BitVecLiteral = (smt_bvor(x, y)) + +def f_eor_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): BitVecLiteral = (smt_bvxor(x, y)) + +def f_not_bits(st: LiftState, t: BigInt, x: BitVecLiteral): BitVecLiteral = (smt_bvnot(x)) + +def f_slt_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): Boolean = + (TrueLiteral == (smt_bvslt(x, y))) + +def f_sle_bits(st: LiftState, t: BigInt, x: BitVecLiteral, y: BitVecLiteral): Boolean = + (TrueLiteral == (smt_bvsle(x, y))) + +def f_zeros_bits(st: LiftState, w: BigInt): BitVecLiteral = BitVecLiteral(0, w.toInt) + +def f_ones_bits(st: LiftState, w: BigInt): BitVecLiteral = BitVecLiteral(BigInt(2).pow(w.toInt) - 1, w.toInt) + +def f_ZeroExtend(st: LiftState, t0: BigInt, t1: BigInt, n: BitVecLiteral, x: BigInt): BitVecLiteral = + smt_zero_extend(x.toInt - n.size, n) + +def f_SignExtend(st: LiftState, t0: BigInt, t1: BigInt, n: BitVecLiteral, x: BigInt): BitVecLiteral = + smt_sign_extend(x.toInt - n.size, n) + +def f_asr_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: BitVecLiteral, arg1: BitVecLiteral): BitVecLiteral = + smt_bvashr(arg0, arg1) + +def f_lsl_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: BitVecLiteral, arg1: BitVecLiteral): BitVecLiteral = + smt_bvshl(arg0, zero_extend_to(arg0.size, arg1)) + +def f_lsr_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: BitVecLiteral, arg1: BitVecLiteral): BitVecLiteral = + smt_bvlshr(arg0, zero_extend_to(arg0.size, arg1)) + +def f_decl_bool(st: LiftState, arg0: String): RTSym = LocalVar(arg0, BoolType) +def f_decl_bv(st: LiftState, arg0: String, arg1: BigInt): RTSym = LocalVar(arg0, BitVecType(arg1.toInt)) + +def f_gen_BFAdd(st: LiftState, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError("func not implemented") +def f_gen_BFMul(st: LiftState, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError("func not implemented") + +def f_gen_FPAdd(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPCompare(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym, arg3: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPCompareEQ(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPCompareGE(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPCompareGT(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPConvert(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPConvertBF(st: LiftState, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPDiv(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPMax(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPMaxNum(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPMin(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPMinNum(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPMul(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPMulAdd(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym, arg3: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPMulAddH(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym, arg3: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPMulX(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPRSqrtStepFused(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPRecipEstimate(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPRecipStepFused(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPRecpX(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPRoundInt(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym, arg3: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPRoundIntN(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym, arg3: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FPSqrt(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPSub(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) +def f_gen_FPToFixed( + st: LiftState, + targ0: BigInt, + targ1: BigInt, + arg0: RTSym, + arg1: RTSym, + arg2: RTSym, + arg3: RTSym, + arg4: RTSym +): RTSym = throw NotImplementedError("func not implemented") +def f_gen_FPToFixedJS_impl(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + throw NotImplementedError("func not implemented") +def f_gen_FixedToFP( + st: LiftState, + targ0: BigInt, + targ1: BigInt, + arg0: RTSym, + arg1: RTSym, + arg2: RTSym, + arg3: RTSym, + arg4: RTSym +): RTSym = throw NotImplementedError("func not implemented") + +def f_gen_bit_lit(st: LiftState, targ0: BigInt, arg0: BitVecLiteral): RTSym = BitVecLiteral(arg0.value, targ0.toInt) +def f_gen_bool_lit(st: LiftState, arg0: Boolean): RTSym = if arg0 then TrueLiteral else FalseLiteral + +def f_gen_branch(st: LiftState, arg0: RTSym): RTLabel = st.gen_branch(arg0)._1 +def f_true_branch(st: LiftState, arg0: RTLabel): RTLabel = (st.branches(arg0))._1 +def f_false_branch(st: LiftState, arg0: RTLabel): RTLabel = (st.branches(arg0))._2 +def f_merge_branch(st: LiftState, arg0: RTLabel): RTLabel = (st.branches(arg0))._3 + +def f_cvt_bits_uint(st: LiftState, targ0: BigInt, arg0: BitVecLiteral): BigInt = arg0.value +def f_gen_cvt_bits_uint(st: LiftState, targ0: BigInt, arg0: RTSym): RTSym = throw Exception("cvt bitsnot implemented") +def f_gen_cvt_bool_bv(st: LiftState, arg0: RTSym): RTSym = arg0 match { + case b: BinaryExpr if b.op == BVEQ => BinaryExpr(BVCOMP, b.arg1, b.arg2) + case _ => throw Exception(s"unhandled conversion from bool to bitvector: ${arg0}") + } + +def f_gen_eor_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVXOR, arg0, arg1) +def f_gen_eq_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVEQ, arg0, arg1) +/*{ + (arg0.getType, arg1.getType) match { + case (b:BitVecType, v:BitVecType) => BinaryExpr(BVEQ, arg0, arg1) + case (b:BitVecType, BoolType) => BinaryExpr(BVEQ, BinaryExpr(BVEQ, BitVecLiteral(0, targ0.toInt), arg0), arg1) + case (BoolType, b:BitVecType) => BinaryExpr(BVEQ, BinaryExpr(BVEQ, BitVecLiteral(0, targ0.toInt), arg1), arg0) + case (b:BitVecType, IntType) => BinaryExpr(IntEQ, BinaryExpr(BVEQ, BitVecLiteral(0, targ0.toInt), arg0), arg1) + case (IntType, b:BitVecType) => BinaryExpr(IntEQ, BinaryExpr(BVEQ, BitVecLiteral(0, targ0.toInt), arg1), arg0) + case (BoolType, BoolType) => BinaryExpr(BoolEQ, arg0, arg1) + case (IntType, IntType) => BinaryExpr(IntEQ, arg0, arg1) + } +}*/ + + +/*def coerceTo(typ: IRType, v: Expr) = { + (typ, v.getType) match { + case (a, b) if a == b => v + case (BitVecType(a), BitVecType(b)) if a > b => ZeroExtend((a - b).toInt, v) + case (BoolType, BitVecType(s)) => BinaryExpr(BVNEQ, BitVecLiteral(0, s), v) + case (BitVecType(s), BoolType) = + } +}*/ + +def f_gen_eq_enum(st: LiftState, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVEQ, arg0, arg1) +def f_gen_int_lit(st: LiftState, arg0: BigInt): BitVecLiteral = BitVecLiteral(arg0, 1123) + +def f_gen_store(st: LiftState, lval: RTSym, e: RTSym): Unit = lval match + case v: Variable => st.push_stmt(LocalAssign(v, e)) + case m => throw NotImplementedError(s"fail assign $m") + +def f_gen_load(st: LiftState, e: RTSym): RTSym = e match + case m: Memory => throw NotImplementedError() + case _ => e + +def f_gen_SignExtend(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: Expr, arg1: BitVecLiteral): RTSym = { + val oldSize = (targ0) + val newSize = (targ1) + if (arg1.value != newSize) { + throw Exception() + } + SignExtend((newSize - oldSize).toInt, arg0) +} + +def f_gen_ZeroExtend(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: Expr, arg1: BitVecLiteral): RTSym = { + val oldSize = (targ0) + val newSize = (targ1) + if (arg1.value != newSize) { + throw Exception() + } + if ((newSize - oldSize) == 0) then arg0 else ZeroExtend((newSize - oldSize).toInt, arg0) +} + +def f_gen_add_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVADD, arg0, arg1) +def f_gen_and_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVAND, arg0, arg1) +def f_gen_and_bool(st: LiftState, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BoolAND, arg0, arg1) + +def f_gen_asr_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: RTSym): RTSym = + BinaryExpr(BVASHR, arg0, gen_zero_extend_to(targ0, arg1)) +def f_gen_lsl_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: RTSym): RTSym = + BinaryExpr(BVSHL, arg0, gen_zero_extend_to(targ0, arg1)) +def f_gen_lsr_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: RTSym): RTSym = + BinaryExpr(BVLSHR, arg0, gen_zero_extend_to(targ0, arg1)) +def f_gen_mul_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVMUL, arg0, arg1) +def f_gen_ne_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVCOMP, arg0, arg1) +def f_gen_not_bits(st: LiftState, targ0: BigInt, arg0: RTSym): RTSym = arg0.getType match { + case BoolType => UnaryExpr(BoolNOT, arg0) + case BitVecType(_) => UnaryExpr(BVNOT, arg0) + case _: MapType => throw IllegalArgumentException() + case IntType => throw IllegalArgumentException() +} + +def f_gen_not_bool(st: LiftState, arg0: RTSym): RTSym = arg0.getType match { + case BoolType => UnaryExpr(BoolNOT, arg0) + case BitVecType(sz) => BinaryExpr(BVNEQ, BitVecLiteral(0, sz), arg0) + case _: MapType => throw IllegalArgumentException() + case IntType => throw IllegalArgumentException() +} + +def f_gen_or_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVOR, arg0, arg1) +def f_gen_or_bool(st: LiftState, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BoolOR, arg0, arg1) +def f_gen_sdiv_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVSDIV, arg0, arg1) +def f_gen_sle_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVSLE, arg0, arg1) +def f_gen_slt_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVSLT, arg0, arg1) +def f_gen_sub_bits(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym): RTSym = BinaryExpr(BVSUB, arg0, arg1) + +def f_AtomicEnd(st: LiftState): RTSym = Register("ATOMICSTART", BoolType) +def f_AtomicStart(st: LiftState): RTSym = Register("ATOMICSTART", BoolType) + +def f_replicate_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: BitVecLiteral, arg1: BigInt): BitVecLiteral = { + bv_replicate(arg0, arg1.toInt) +} +def f_append_bits(st: LiftState, targ0: BigInt, targ1: BigInt, a: BitVecLiteral, b: BitVecLiteral): BitVecLiteral = + BitVecLiteral((a.value << b.size) + b.value, (a.size + b.size)) + +def f_gen_AArch64_MemTag_set(st: LiftState, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = throw NotImplementedError( + "func not implemented" +) + +def f_gen_Mem_set(st: LiftState, sz: BigInt, ptr: RTSym, width: BitVecLiteral, acctype: RTSym, value: RTSym): Unit = + assert(width.value == sz) + val expr = MemoryStore(st.memory, ptr, value, st.endian, sz.toInt * st.memory.valueSize) + val stmt = MemoryAssign(st.memory, expr) + st.push_stmt(stmt) + +def f_gen_Mem_read(st: LiftState, targ0: BigInt, arg0: RTSym, arg1: RTSym, arg2: RTSym): RTSym = + MemoryLoad(st.memory, arg0, st.endian, targ0.toInt * st.memory.valueSize) + +def f_gen_slice(st: LiftState, e: RTSym, lo: BigInt, wd: BigInt): RTSym = { + Extract((wd + lo).toInt, lo.toInt, e) +} +def f_gen_replicate_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: BitVecLiteral): RTSym = { + Range.Exclusive(1, arg1.value.toInt, 1).map(v => arg0).foldLeft(arg0)((a, b) => (BinaryExpr(BVCONCAT, a, b))) +} +def f_gen_append_bits(st: LiftState, targ0: BigInt, targ1: BigInt, arg0: RTSym, arg1: RTSym): RTSym = + BinaryExpr(BVCONCAT, arg0, arg1) + +def f_gen_array_load(st: LiftState, arg0: RTSym, arg1: BigInt): RTSym = arg0 match + case Register("_R", t) => Register("R" + arg1, BitVecType(64)) + case Register("_Z", t) => Register("V" + arg1, BitVecType(128)) + case _ => { + Logger.warn(s"Unknown array load $arg0") + arg0 + } +def f_gen_array_store(st: LiftState, arg0: RTSym, arg1: BigInt, arg2: RTSym): Unit = arg0 match + case Register(n, t) if n.contains("R") => st.push_stmt(LocalAssign(Register("R" + arg1, BitVecType(64)), arg2)) + case _ => Logger.warn(s"Unknown array store $arg0") + +def f_gen_assert(st: LiftState, arg0: RTSym) = st.push_stmt(Assert(arg0)) +def f_switch_context(st: LiftState, arg0: RTLabel) = st.switch_ctx(arg0) + +/** Global variable definitions * */ + +def v_PSTATE_UAO = Mutable(Register("UAO", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_PAN = Mutable(Register("PAN", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_DIT = Mutable(Register("DIT", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_SSBS = Mutable(Register("SSBS", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_G = Mutable(Register("G", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_A = Mutable(Register("A", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_I = Mutable(Register("I", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_F = Mutable(Register("F", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_D = Mutable(Register("D", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "G") +def v_PSTATE_C = Mutable(Register("CF", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "C") +def v_PSTATE_Z = Mutable(Register("ZF", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "Z") +def v_PSTATE_V = Mutable(Register("VF", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "V") +def v_PSTATE_N = Mutable(Register("NF", BitVecType(1))) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "N") +def v__PC = Mutable(Register("_PC", BitVecType(64))) // Expr_Var(Ident "_PC") +def v__R = Mutable(Register("_R", MapType(BitVecType(64),BitVecType(64)))) +def v__Z = Mutable(Register("_Z", MapType(BitVecType(64),BitVecType(128)))) +def v_SP_EL0 = Mutable(Register("R31", BitVecType(64))) +def v_FPSR = Mutable(Register("FPSR", BoolType)) +def v_FPCR = Mutable(LocalVar("FPCR", BitVecType(32))) + +def v_PSTATE_BTYPE = Mutable(Register("PSTATE.BTYPE", BoolType)) // Expr_Field(Expr_Var(Ident "PSTATE"), Ident "BTYPE") +def v_BTypeCompatible = Mutable(Register("BTypeCompatible", BoolType)) // Expr_Var (Ident "BTypeCompatible") +def v___BranchTaken = Mutable(Register("BranchTaken", BoolType)) +def v_BTypeNext = Mutable(Register("BTypeNext", BoolType)) +def v___ExclusiveLocal = Mutable(Register("__ExclusiveLocal", BoolType))