Skip to content

Commit

Permalink
Eliminate comparisons during IntToBits
Browse files Browse the repository at this point in the history
Leverage interval information to reduce trivial comparisons
during post-passes, then cleanup in RemoveUnused.
  • Loading branch information
ncough committed May 19, 2024
1 parent 256213d commit 3e0f2d0
Showing 1 changed file with 75 additions and 20 deletions.
95 changes: 75 additions & 20 deletions libASL/transforms.ml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ let infer_type (e: expr): ty option =

(** Remove variables which are unused at the end of the statement list. *)
module RemoveUnused = struct
let rec is_false = function
| Expr_Var (Ident "FALSE") -> true
| Expr_TApply (FIdent ("or_bool", 0), [], [a;b]) -> is_false a && is_false b
| Expr_TApply (FIdent ("and_bool", 0), [], [a;b]) -> is_false a || is_false b
| _ -> false

let rec is_true = function
| Expr_Var (Ident "TRUE") -> true
| Expr_TApply (FIdent ("and_bool", 0), [], [a;b]) -> is_true a && is_true b
| Expr_TApply (FIdent ("or_bool", 0), [], [a;b]) -> is_true a || is_true b
| _ -> false

let rec remove_unused (globals: IdentSet.t) xs = fst (remove_unused' globals IdentSet.empty xs)

and remove_unused' globals (used: IdentSet.t) (xs: stmt list): (stmt list * IdentSet.t) =
Expand Down Expand Up @@ -152,10 +164,14 @@ module RemoveUnused = struct
else pass

(* Skip if structure if possible - often seen in decode tests *)
| Stmt_If(Expr_Var (Ident "TRUE"), tstmts, elsif, fstmts, loc) ->
| Stmt_If(c, tstmts, elsif, fstmts, loc) when is_true c ->
let (tstmts',tused) = remove_unused' globals used tstmts in
(tstmts'@acc,tused)

| Stmt_If(c, tstmts, [], fstmts, loc) when is_false c ->
let (fstmts',tused) = remove_unused' globals used fstmts in
(fstmts'@acc,tused)

| Stmt_If(c, tstmts, elsif, fstmts, loc) ->
let (tstmts',tused) = remove_unused' globals used tstmts in
let (fstmts',fused) = remove_unused' globals used fstmts in
Expand All @@ -171,8 +187,10 @@ module RemoveUnused = struct
| [], [], [] -> pass
| _, _, _ -> (Stmt_If(c, tstmts', List.map fst elsif', fstmts', loc)::acc,used))

| Stmt_Assert (c, _) when is_true c -> pass

(* Unreachable points *)
| Stmt_Assert (Expr_Var (Ident "FALSE"), _)
| Stmt_Assert (c, _) when is_false c -> halt stmt
| Stmt_Throw _ -> halt stmt

| x -> emit x
Expand Down Expand Up @@ -387,6 +405,11 @@ module StatefulIntToBits = struct
let l = Z.zero in
(w, false, (u,l))

let abs_of_interval (u: int) (l: int): abs =
let i = (Z.of_int u, Z.of_int l) in
let (w,s) = width_of_interval i in
(w, s, i)

(* Basic merge of abstract points *)
let merge_abs ((lw,ls,(l1,l2)): abs) ((rw,rs,(r1,r2)): abs): abs =
let s = ls || rs in
Expand Down Expand Up @@ -442,6 +465,8 @@ module StatefulIntToBits = struct
let width (n,_,_) = n
let signed (_,s,_) = s
let interval (_,_,i) = i
let lower (_,_,(_,l)) = l
let upper (_,_,(u,_)) = u

(** Convert abstract point width into exprs & symbols *)
let expr_of_abs a =
Expand Down Expand Up @@ -649,49 +674,79 @@ module StatefulIntToBits = struct
| Expr_TApply (FIdent ("eq_int", 0), [], [x;y]) ->
let x = bv_of_int_expr vars x in
let y = bv_of_int_expr vars y in
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
sym_expr @@ sym_prim (FIdent ("eq_bits", 0)) [sym_of_abs w] [ex x; ex y]
(* If y is strictly greater, must be false *)
if Z.gt (lower (snd y)) (upper (snd x)) then expr_false
(* If x is strictly greater, must be false *)
else if Z.gt (lower (snd x)) (upper (snd y)) then expr_false
else
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
sym_expr @@ sym_prim (FIdent ("eq_bits", 0)) [sym_of_abs w] [ex x; ex y]

| Expr_TApply (FIdent ("ne_int", 0), [], [x;y]) ->
let x = bv_of_int_expr vars x in
let y = bv_of_int_expr vars y in
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
sym_expr @@ sym_prim (FIdent ("ne_bits", 0)) [sym_of_abs w] [ex x; ex y]
(* If y is strictly greater, must be true *)
if Z.gt (lower (snd y)) (upper (snd x)) then expr_true
(* If x is strictly greater, must be true *)
else if Z.gt (lower (snd x)) (upper (snd y)) then expr_true
else
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
sym_expr @@ sym_prim (FIdent ("ne_bits", 0)) [sym_of_abs w] [ex x; ex y]

(* x >= y iff y <= x iff x - y >= 0*)
| Expr_TApply (FIdent ("ge_int", 0), [], [x;y])
| Expr_TApply (FIdent ("le_int", 0), [], [y;x]) ->
let x = force_signed (bv_of_int_expr vars x) in
let y = force_signed (bv_of_int_expr vars y) in
let w = merge_abs (snd x) (snd y) in
let ex x = sym_expr (extend w x) in
expr_prim' "sle_bits" [expr_of_abs w] [ex y;ex x]
(* if largest y is smaller or equal than smallest x, must be true *)
if Z.leq (upper (snd y)) (lower (snd x)) then expr_true
(* if smallest y is greater than largest x, must be false *)
else if Z.gt (lower (snd y)) (upper (snd x)) then expr_false
else
let w = merge_abs (snd x) (snd y) in
let ex x = sym_expr (extend w x) in
expr_prim' "sle_bits" [expr_of_abs w] [ex y;ex x]

(* x < y iff y > x iff x - y < 0 *)
| Expr_TApply (FIdent ("lt_int", 0), [], [x;y])
| Expr_TApply (FIdent ("gt_int", 0), [], [y;x]) ->
let x = force_signed (bv_of_int_expr vars x) in
let y = force_signed (bv_of_int_expr vars y) in
let w = merge_abs (snd x) (snd y) in
let ex x = sym_expr (extend w x) in
expr_prim' "slt_bits" [expr_of_abs w] [ex x;ex y]
(* if largest y is smaller or equal than smallest x, must be true *)
if Z.lt (upper (snd x)) (lower (snd y)) then expr_true
(* if smallest y is greater than largest x, must be false *)
else if Z.geq (lower (snd x)) (upper (snd y)) then expr_false
else
let w = merge_abs (snd x) (snd y) in
let ex x = sym_expr (extend w x) in
expr_prim' "slt_bits" [expr_of_abs w] [ex x;ex y]

(* Translation from enum to bit *)
| Expr_TApply (FIdent ("eq_enum", n), [], [x;y]) when n > 0 ->
let x = bv_of_int_expr vars x in
let y = bv_of_int_expr vars y in
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
(sym_expr @@ sym_prim (FIdent ("eq_bits", 0)) [sym_of_abs w] [ex x; ex y])
(* If y is strictly greater, must be false *)
if Z.gt (lower (snd y)) (upper (snd x)) then expr_false
(* If x is strictly greater, must be false *)
else if Z.gt (lower (snd x)) (upper (snd y)) then expr_false
else
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
(sym_expr @@ sym_prim (FIdent ("eq_bits", 0)) [sym_of_abs w] [ex x; ex y])

| Expr_TApply (FIdent ("ne_enum", n), [], [x;y]) when n > 0 ->
let x = bv_of_int_expr vars x in
let y = bv_of_int_expr vars y in
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
(sym_expr @@ sym_prim (FIdent ("ne_bits", 0)) [sym_of_abs w] [ex x; ex y])
(* If y is strictly greater, must be true *)
if Z.gt (lower (snd y)) (upper (snd x)) then expr_true
(* If x is strictly greater, must be true *)
else if Z.gt (lower (snd x)) (upper (snd y)) then expr_true
else
let w = merge_abs (snd x) (snd y) in
let ex = extend w in
(sym_expr @@ sym_prim (FIdent ("ne_bits", 0)) [sym_of_abs w] [ex x; ex y])

(* these functions take bits as first argument and integer as second. just coerce second to bits. *)
(* TODO: primitive implementations of these expressions expect the shift amount to be signed,
Expand Down

0 comments on commit 3e0f2d0

Please sign in to comment.