From f313ed34466fee570d7b163d4c9d6699cb6cd4a5 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Thu, 31 Oct 2024 17:44:18 +1000 Subject: [PATCH] block coalescing --- src/main/scala/cfg_visualiser/DotTools.scala | 2 +- src/main/scala/ir/eval/SimplifyExpr.scala | 6 +- .../transforms/DynamicSingleAssignment.scala | 44 ++- src/main/scala/ir/transforms/Simp.scala | 278 ++++++++++++------ src/main/scala/util/RunUtils.scala | 19 +- 5 files changed, 225 insertions(+), 124 deletions(-) diff --git a/src/main/scala/cfg_visualiser/DotTools.scala b/src/main/scala/cfg_visualiser/DotTools.scala index 75f7b55ee..0cd7614c5 100644 --- a/src/main/scala/cfg_visualiser/DotTools.scala +++ b/src/main/scala/cfg_visualiser/DotTools.scala @@ -57,7 +57,7 @@ class DotNode(val id: String, val label: String) extends DotElement { override def toString: String = toDotString def toDotString: String = - s"\"$id\"" + "[label=\"" + wrap(label, 100) + "\", shape=\"box\", fontname=\"Mono\", fontsize=\"5\"]" + s"\"$id\"" + "[label=\"" + wrap(label, 200) + "\", shape=\"box\", fontname=\"Mono\", fontsize=\"5\"]" } diff --git a/src/main/scala/ir/eval/SimplifyExpr.scala b/src/main/scala/ir/eval/SimplifyExpr.scala index bd837b4e0..ec42a540d 100644 --- a/src/main/scala/ir/eval/SimplifyExpr.scala +++ b/src/main/scala/ir/eval/SimplifyExpr.scala @@ -468,9 +468,9 @@ def simplifyExpr(e: Expr): (Expr, Boolean) = { case BinaryExpr(BVEQ, a, b) if a.loads.isEmpty && b.loads.isEmpty && a == b => TrueLiteral // compose slices - case Extract(ed1, be1, Extract(ed2, be2, body)) => { - Extract(ed1 + be2, be1 + be2, (body)) - } + case Extract(ed1, be1, Extract(ed2, be2, body)) => Extract(ed1 + be2, be1 + be2, (body)) + case SignExtend(sz1, SignExtend(sz2, exp)) => SignExtend(sz1 + sz2, exp) + case ZeroExtend(sz1, ZeroExtend(sz2, exp)) => ZeroExtend(sz1 + sz2, exp) // (comp (comp x y) 1) = (comp x y) case BinaryExpr(BVCOMP, body @ BinaryExpr(BVCOMP, _, _), BitVecLiteral(1, 1)) => (body) diff --git 
a/src/main/scala/ir/transforms/DynamicSingleAssignment.scala b/src/main/scala/ir/transforms/DynamicSingleAssignment.scala index 5a009ad39..53908582e 100644 --- a/src/main/scala/ir/transforms/DynamicSingleAssignment.scala +++ b/src/main/scala/ir/transforms/DynamicSingleAssignment.scala @@ -9,14 +9,14 @@ import analysis._ val phiAssignLabel = Some("phi") -/** This transforms the program by no-op copies and renaming local variable indices to establish the property that +/** This transforms the program by adding no-op copies and renaming local variable indices to establish the property that * * \forall variables v, forall uses of v : u, No subset of definitions of v defines the use u. */ class OnePassDSA( - liveVarsCheck: Option[Map[CFGPosition, Set[Variable]]] = None - /** Check our (faster) live var result against the TIP sovler solution */ + /** Check our (faster) live var result against the TIP solver solution + */ ) { val liveVarsDom = transforms.IntraLiveVarsDomain() @@ -141,7 +141,7 @@ class OnePassDSA( // if there is no renaming such that all the incoming renames agree // then we create a new copy case h :: tl => - tl.foldLeft(Some(h): Option[Int])((acc, rn) => + tl.foldLeft(Some(h): Option[Int])((acc : Option[Int], rn: Int) => acc match { case Some(v) if v == rn => Some(v) case _ => None @@ -238,19 +238,19 @@ class OnePassDSA( block: Block ) = { - /** VisitBlock: - * - * 1. for all complete incoming - * - if there is an incoming rename that is not equal across the predecessors - * - add back phi block to each incoming edge - * - create a fresh copy of each non-uniform renamed variable - * - add copies to each phi block unify the incoming rename with the nominated new rename 2. add local value - * numbering to this block 3. if all predecessors are filled, mark this complete, otherwise mark this - * filled. 4. 
for all successors - * - if marked filled, add phi block to unify our outgoing with its incoming - * - if > 1 total predecessors for each unmarked , add phi block to nominate a new copy 5. for all successors, if - * all predecessors are now filled, mark complete - */ + /** VisitBlock: + * + * 1. for all complete incoming + * - if there is an incoming rename that is not equal across the predecessors + * - add back phi block to each incoming edge + * - create a fresh copy of each non-uniform renamed variable + * - add copies to each phi block unify the incoming rename with the nominated new rename 2. add local value + * numbering to this block 3. if all predecessors are filled, mark this complete, otherwise mark this + * filled. 4. for all successors + * - if marked filled, add phi block to unify our outgoing with its incoming + * - if > 1 total predecessors for each unmarked , add phi block to nominate a new copy 5. for all successors, if + * all predecessors are now filled, mark complete + */ def state(b: Block) = withDefault(_st)(b) @@ -287,16 +287,6 @@ class OnePassDSA( val liveBefore = mutable.Map.from(liveBeforeIn) val liveAfter = mutable.Map.from(liveAfterIn) - for (lvResult <- liveVarsCheck) { - // check live vars has equal result to TIP solver - for ((b, s) <- liveBefore) { - assert( - lvResult(b) == s, - s"LiveVars unequal ${b.label}: ${lvResult(b)} == $s (differing ${lvResult(b).diff(s)})" - ) - } - } - val worklist = mutable.PriorityQueue[Block]()(Ordering.by(b => b.rpoOrder)) worklist.addAll(p.blocks) var seen = Set[Block]() diff --git a/src/main/scala/ir/transforms/Simp.scala b/src/main/scala/ir/transforms/Simp.scala index 1b698c2ec..77974081b 100644 --- a/src/main/scala/ir/transforms/Simp.scala +++ b/src/main/scala/ir/transforms/Simp.scala @@ -28,14 +28,30 @@ trait AbstractDomain[L] { def bot: L } - def getLiveVars(p: Procedure) = { val liveVarsDom = IntraLiveVarsDomain() val liveVarsSolver = worklistSolver(liveVarsDom) - liveVarsSolver.solveProc(p, 
backwards=true) + liveVarsSolver.solveProc(p, backwards = true) +} + +def difftestLiveVars(p: Procedure, compareResult: Map[CFGPosition, Set[Variable]]) = { + val (liveBefore, liveAfter) = getLiveVars(p) + var passed = true + + for ((b, s) <- liveBefore) { + val c = (compareResult(b) == s) + passed = passed && c + if (!c) { + Logger.error( + s"LiveVars unequal ${b.label}: ${compareResult(b)} == $s (differing ${compareResult(b).diff(s)})" + ) + } + } + passed } + def basicReachingDefs(p: Procedure): Map[Command, Map[Variable, Set[Assign | DirectCall]]] = { val (beforeLive, afterLive) = getLiveVars(p) val dom = DefUseDomain(beforeLive) @@ -43,31 +59,33 @@ def basicReachingDefs(p: Procedure): Map[Command, Map[Variable, Set[Assign | Dir // type rtype = Map[Block, Map[Variable, Set[Assign | DirectCall]]] val (beforeRes, afterRes) = solver.solveProc(p) - - val merged: Map[Command, Map[Variable, Set[Assign | DirectCall]]] = - beforeRes.flatMap((block, sts) => { - val b = Seq(IRWalk.firstInBlock(block) -> sts) - val stmts = - if (block.statements.nonEmpty) then - (block.statements.toList: List[Command]).zip(block.statements.toList.tail ++ List(block.jump)) - else List() - val transferred = stmts - .foldLeft((sts, List[(Command, Map[Variable, Set[Assign | DirectCall]])]()))((st, s) => { - // map successor to transferred predecessor - val x = dom.transfer(st._1, s._1) - (x, (s._2 -> x) :: st._2) - }) - ._2.toMap - b ++ transferred - }) - .toMap + val merged: Map[Command, Map[Variable, Set[Assign | DirectCall]]] = + beforeRes + .flatMap((block, sts) => { + val b = Seq(IRWalk.firstInBlock(block) -> sts) + val stmts = + if (block.statements.nonEmpty) then + (block.statements.toList: List[Command]).zip(block.statements.toList.tail ++ List(block.jump)) + else List() + val transferred = stmts + .foldLeft((sts, List[(Command, Map[Variable, Set[Assign | DirectCall]])]()))((st, s) => { + // map successor to transferred predecessor + val x = dom.transfer(st._1, s._1) + (x, (s._2 -> x) 
:: st._2) + }) + ._2 + .toMap + b ++ transferred + }) + .toMap merged } case class DefUse(defined: Map[Variable, Assign | DirectCall]) // map v -> definitions reached here -class DefUseDomain(liveBefore: Map[Block, Set[Variable]]) extends AbstractDomain[Map[Variable, Set[Assign | DirectCall]]] { +class DefUseDomain(liveBefore: Map[Block, Set[Variable]]) + extends AbstractDomain[Map[Variable, Set[Assign | DirectCall]]] { // TODO: cull values using liveness override def transfer(s: Map[Variable, Set[Assign | DirectCall]], b: Command) = { @@ -81,7 +99,8 @@ class DefUseDomain(liveBefore: Map[Block, Set[Variable]]) extends AbstractDomain def bot = Map[Variable, Set[Assign | DirectCall]]() def join(l: Map[Variable, Set[Assign | DirectCall]], r: Map[Variable, Set[Assign | DirectCall]], pos: Block) = { l.keySet - .union(r.keySet).filter(k => liveBefore(pos).contains(k)) + .union(r.keySet) + .filter(k => liveBefore(pos).contains(k)) .map(k => { k -> (l.get(k).getOrElse(Set()) ++ r.get(k).getOrElse(Set())) }) @@ -278,42 +297,52 @@ def removeSlices(p: Procedure): Unit = { })) ) - // try and find a single extension size for all rhs of assignments to all variables in the assigned equality class - // val varHighZeroBits: Map[LocalVar, HighZeroBits] = assignments.map((v, assigns) => - // // note: this overapproximates on x := y when x and y may both be smaller than their declared size - // val allRHSExtended = assigns.foldLeft(HighZeroBits.Bot: HighZeroBits)((e, assign) => - // (e, assign.rhs) match { - // case (HighZeroBits.Bot, ZeroExtend(i, lhs)) => HighZeroBits.Bits(i) - // case (b @ HighZeroBits.Bits(ei), ZeroExtend(i, _)) if i == ei => b - // case (b @ HighZeroBits.Bits(ei), ZeroExtend(i, _)) if i != ei => HighZeroBits.False - // case (HighZeroBits.False, _) => HighZeroBits.False - // case (_, other) => HighZeroBits.False - // } - // ) - // (v, allRHSExtended) - // ) - - //val varsWithExtend: Map[LocalVar, HighZeroBits] = assignments - // .map((lhs, _) => { - // // map all 
lhs to the result for their representative - // val rep = ufsolver.find(LVTerm(lhs)) match { - // case LVTerm(r) => r - // case _ => ??? - // } - // lhs -> varHighZeroBits.get(rep) - // }) - // .collect { case (l, Some(x)) /* remove anything we have no information on */ => - // (l, x) - // } - class CheckUsesHaveExtend() extends CILVisitor { val result: mutable.HashMap[LocalVar, HighZeroBits] = mutable.HashMap[LocalVar, HighZeroBits]() + def extractAccess(v: LocalVar, highestBit: Int): Unit = { + if (!size(v).isDefined) { + return (); + } + if (((!result.contains(v)) || result.get(v).contains(HighZeroBits.Bot))) { + result(v) = HighZeroBits.Bits(highestBit) + } else { + result(v) match { + case HighZeroBits.Bits(n) if highestBit == size(v).get => { + // access full expr + result(v) = HighZeroBits.False + } + case HighZeroBits.Bits(n) if highestBit > n => { + // relax constraint to bits accessed + result(v) = HighZeroBits.Bits(highestBit) + } + case HighZeroBits.Bits(n) if n >= highestBit => { + // access satisfied by upper constraint + } + case _ => () + } + } + } + + override def vstmt(s: Statement) = { + s match { + case d: DirectCall => { + d.outParams.map(_._2).collect { + case l: LocalVar => { + result(l) = HighZeroBits.False + } + } + } + case _ => () + } + DoChildren() + } + override def vrvar(v: Variable) = { v match { - case v: LocalVar => { - result(v) = HighZeroBits.False + case v: LocalVar if size(v).isDefined => { + extractAccess(v, size(v).get) } case _ => () } @@ -322,13 +351,10 @@ def removeSlices(p: Procedure): Unit = { override def vexpr(v: Expr) = { v match { - case Extract(i, 0, v: LocalVar) - if size(v).isDefined && ((!result.contains(v)) || result.get(v).contains(HighZeroBits.Bot)) => { - result(v) = HighZeroBits.Bits(i) + case Extract(i, 0, v: LocalVar) if size(v).isDefined => { + extractAccess(v, i) SkipChildren() } - case Extract(i, 0, v: LocalVar) if size(v).isDefined && result.get(v).contains(HighZeroBits.Bits(i)) => - SkipChildren() case 
_ => DoChildren() } } @@ -351,13 +377,37 @@ def removeSlices(p: Procedure): Unit = { v match { case Extract(i, 0, v: LocalVar) if size(v).isDefined && !(formals.contains(v)) && varHighZeroBits.contains(v) => { - assert(varHighZeroBits(v) == i) - ChangeTo(LocalVar(v.name, BitVecType(varHighZeroBits(v)))) + assert(varHighZeroBits(v) >= i) + if (varHighZeroBits(v) == i) { + ChangeTo(v.copy(irType = BitVecType(varHighZeroBits(v)))) + } else { + ChangeTo(Extract(i, 0, v.copy(irType = BitVecType(varHighZeroBits(v))))) + } } case _ => DoChildren() } } + override def vlvar(v: Variable) = { + v match { + case lhs: LocalVar if (varHighZeroBits.contains(lhs) && !formals.contains(lhs)) => { + val n = lhs.copy(irType = BitVecType(varHighZeroBits(lhs))) + ChangeTo(n) + } + case _ => SkipChildren() + } + } + + override def vrvar(v: Variable) = { + v match { + case lhs: LocalVar if (varHighZeroBits.contains(lhs) && !formals.contains(lhs)) => { + val n = lhs.copy(irType = BitVecType(varHighZeroBits(lhs)), lhs.index) + ChangeTo(n) + } + case _ => SkipChildren() + } + } + override def vproc(p: Procedure) = { formals = p.formalInParam.toSet ++ p.formalOutParam.toSet DoChildren() @@ -369,7 +419,7 @@ def removeSlices(p: Procedure): Unit = { if size(lhs).isDefined && varHighZeroBits.get(lhs).contains(size(rhs).get) && !(formals.contains(lhs)) => { // assert(varHighZeroBits(lhs) == sz) val varsize = varHighZeroBits(lhs) - a.lhs = LocalVar(lhs.name, BitVecType(varsize)) + a.lhs = LocalVar(lhs.varName, BitVecType(varsize), lhs.index) a.rhs = rhs assert(size(a.lhs).get == size(a.rhs).get) DoChildren() @@ -377,7 +427,7 @@ def removeSlices(p: Procedure): Unit = { case a @ Assign(lhs: LocalVar, ZeroExtend(sz, rhs), _) if size(lhs).isDefined && varHighZeroBits.get(lhs).contains(size(rhs).get) && !(formals.contains(lhs)) => { val varsize = varHighZeroBits(lhs) - a.lhs = LocalVar(lhs.name, BitVecType(varsize)) + a.lhs = LocalVar(lhs.varName, BitVecType(varsize), lhs.index) a.rhs = rhs 
assert(size(a.lhs).get == size(a.rhs).get) DoChildren() @@ -385,7 +435,7 @@ def removeSlices(p: Procedure): Unit = { case a @ Assign(lhs: LocalVar, rhs, _) if size(lhs).isDefined && varHighZeroBits.contains(lhs) && !(formals.contains(lhs)) => { // promote extract to the definition - a.lhs = LocalVar(lhs.name, BitVecType(varHighZeroBits(lhs))) + a.lhs = LocalVar(lhs.varName, BitVecType(varHighZeroBits(lhs)), lhs.index) a.rhs = Extract(varHighZeroBits(lhs), 0, rhs) assert(size(a.lhs).get == size(a.rhs).get, s"${size(a.lhs).get} != ${size(a.rhs).get}") DoChildren() @@ -575,10 +625,73 @@ def copypropTransform(p: Procedure) = { } + def removeEmptyBlocks(p: Program) = { + for (proc <- p.procedures) { + val blocks = proc.blocks.toList + for (b <- blocks) { + b match { + case b: Block if b.statements.size == 0 && b.prevBlocks.size == 1 && b.jump.isInstanceOf[GoTo] => { + val prev = b.prevBlocks + val next = b.nextBlocks + for (p <- prev) { + p.jump match { + case g: GoTo => { + for (n <- next) { + g.addTarget(n) + } + g.removeTarget(b) + } + case _ => throw Exception("Must have goto") + } + } + b.replaceJump(Unreachable()) + b.parent.removeBlocks(b) + } + case _ => () + } + } + } + } + +def coalesceBlocks(p: Program) = { + var didAny = false + for (proc <- p.procedures) { + val blocks = proc.blocks.toList + for (b <- blocks.sortBy(_.rpoOrder)) { + if (b.prevBlocks.size == 1 && b.prevBlocks.head.statements.nonEmpty && b.statements.nonEmpty + && b.prevBlocks.head.nextBlocks.size == 1 + && b.prevBlocks.head.statements.lastOption.map(s => !(s.isInstanceOf[Call])).getOrElse(true)) { + didAny = true + // append to predecessor + // we know prevBlock is only jumping to b and has no call at the end + val prevBlock = b.prevBlocks.head + val stmts = b.statements.map(b.statements.remove).toList + prevBlock.statements.appendAll(stmts) + // leave empty block b and cleanup with removeEmptyBlocks + } else if (b.nextBlocks.size == 1 && b.nextBlocks.head.statements.nonEmpty && 
b.statements.nonEmpty + && b.nextBlocks.head.prevBlocks.size == 1 + && b.statements.lastOption.map(s => !(s.isInstanceOf[Call])).getOrElse(true)) { + didAny = true + // append to successor + // we know b is only jumping to nextBlock and does not end in a call + val nextBlock = b.nextBlocks.head + val stmts = b.statements.map(b.statements.remove).toList + nextBlock.statements.prependAll(stmts) + // leave empty block b and cleanup with removeEmptyBlocks + } + } + } + didAny +} + def doCopyPropTransform(p: Program) = { applyRPO(p) + removeEmptyBlocks(p) + coalesceBlocks(p) + removeEmptyBlocks(p) + Logger.info("[!] Simplify :: Expr/Copy-prop Transform") val work = p.procedures .filter(_.blocks.size > 0) @@ -606,30 +719,18 @@ def doCopyPropTransform(p: Program) = { // cleanup visit_prog(CleanupAssignments(), p) - for (proc <- p.procedures) { - val blocks = proc.blocks.toList - for (b <- blocks) { - b match { - case b: Block if b.statements.size == 0 && b.prevBlocks.size == 1 && b.nextBlocks.size == 1 => { - val p = b.prevBlocks.head - val n = b.nextBlocks.head - p.jump match { - case g: GoTo => { - g.addTarget(n) - g.removeTarget(b) - } - case _ => ??? - } - b.replaceJump(Unreachable()) - b.parent.removeBlocks(b) - } - case _ => () - } - } - } - Logger.info("[!] 
Simplify :: Merge empty blocks") + + + removeEmptyBlocks(p) + coalesceBlocks(p) + coalesceBlocks(p) + coalesceBlocks(p) + coalesceBlocks(p) + removeEmptyBlocks(p) + + } def reversePostOrder(startBlock: Block): Unit = { @@ -721,8 +822,10 @@ class worklistSolver[L, A <: AbstractDomain[L]](domain: A) { while (worklist.nonEmpty) { val b = worklist.dequeue - while (worklist.nonEmpty && (if backwards then (worklist.head.rpoOrder <= b.rpoOrder) - else (worklist.head.rpoOrder >= b.rpoOrder))) + while ( + worklist.nonEmpty && (if backwards then (worklist.head.rpoOrder <= b.rpoOrder) + else (worklist.head.rpoOrder >= b.rpoOrder)) + ) do { // drop rest of blocks with same priority val m = worklist.dequeue() @@ -732,7 +835,6 @@ class worklistSolver[L, A <: AbstractDomain[L]](domain: A) { ) } - val prev = savedAfter.get(b) val x = { predecessors(b).toList.flatMap(b => savedAfter.get(b).toList) match { @@ -1053,7 +1155,7 @@ class Simplify( } override def vblock(b: Block) = { - block = b + block = b DoChildren() } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 8a20a5070..9a9c00694 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -527,13 +527,14 @@ object RunUtils { } def doSimplify(ctx: IRContext, config: Option[StaticAnalysisConfig]) : Unit = { + + // writeToFile(dotBlockGraph(ctx.program, ctx.program.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"blockgraph-before-simp.dot") Logger.info("[!] Running Simplify") val timer = PerformanceTimer("Simplify") transforms.applyRPO(ctx.program) Logger.info(s"RPO ${timer.checkPoint("RPO")} ms ") Logger.info("[!] 
Simplify :: DynamicSingleAssignment") - // val liveVars : Map[CFGPosition, Set[Variable]] = analysis.IntraLiveVarsAnalysis(ctx.program).analyze() // writeToFile(serialiseIL(ctx.program), s"il-before-dsa.il") @@ -541,9 +542,15 @@ object RunUtils { transforms.OnePassDSA().applyTransform(ctx.program) Logger.info(s"DSA ${timer.checkPoint("DSA ")} ms ") // writeToFile(dotBlockGraph(ctx.program, ctx.program.filter(_.isInstanceOf[Block]).map(b => b -> b.toString).toMap), s"blockgraph-after-dsa.dot") - Logger.info("DSA Check") - val x = ctx.program.procedures.forall(transforms.rdDSAProperty) - assert(x) + if (ir.eval.SimplifyValidation.validate) { + Logger.info("Live vars difftest") + val tipLiveVars : Map[CFGPosition, Set[Variable]] = analysis.IntraLiveVarsAnalysis(ctx.program).analyze() + assert(ctx.program.procedures.forall(transforms.difftestLiveVars(_, tipLiveVars))) + + Logger.info("DSA Check") + val x = ctx.program.procedures.forall(transforms.rdDSAProperty) + assert(x) + } Logger.info("DSA Check passed") assert(invariant.singleCallBlockEnd(ctx.program)) assert(invariant.cfgCorrect(ctx.program)) @@ -551,7 +558,9 @@ object RunUtils { // writeToFile(serialiseIL(ctx.program), s"il-before-copyprop.il") - + // brute force: run the transform three times because each pass cleans up more stuff + transforms.doCopyPropTransform(ctx.program) + transforms.doCopyPropTransform(ctx.program) transforms.doCopyPropTransform(ctx.program) Logger.info(s"CopyProp ${timer.checkPoint("CopyProp")} ms ") // writeToFile(serialiseIL(ctx.program), s"il-after-copyprop.il")