Skip to content

Commit

Permalink
Merge pull request #208 from UQ-PAC/intrusive-list-parent-assertion-fix
Browse files Browse the repository at this point in the history
IntrusiveList/Parent/Block assertion fixes
  • Loading branch information
ailrst authored Jun 5, 2024
2 parents 104f4aa + df44840 commit 4b3a278
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 87 deletions.
6 changes: 3 additions & 3 deletions src/main/scala/ir/IRCursor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ object IRWalk:

def commandBegin(pos: CFGPosition) : Option[Command] = {
pos match {
case p: Procedure => p.entryBlock.map(b => b.statements.headOption().getOrElse(b.jump))
case b: Block => Some(b.statements.headOption().getOrElse(b.jump))
case p: Procedure => p.entryBlock.map(b => b.statements.headOption.getOrElse(b.jump))
case b: Block => Some(b.statements.headOption.getOrElse(b.jump))
case c: Command => Some(c)
}
}
Expand Down Expand Up @@ -81,7 +81,7 @@ trait IntraProcIRCursor extends IRWalk[CFGPosition, CFGPosition] {
def succ(pos: CFGPosition): Set[CFGPosition] = {
pos match {
case proc: Procedure => proc.entryBlock.toSet
case b: Block => Set(b.statements.headOption().getOrElse(b.jump))
case b: Block => Set(b.statements.headOption.getOrElse(b.jump))
case s: Statement => Set(s.succ().getOrElse(s.parent.jump))
case n: GoTo => n.targets.asInstanceOf[Set[CFGPosition]]
case c: Call => c.parent.fallthrough.toSet
Expand Down
11 changes: 8 additions & 3 deletions src/main/scala/ir/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,8 @@ class Block private (
statements.onInsert = x => x.setParent(this)
statements.onRemove = x => x.deParent()


def this(label: String, address: Option[Int] = None, statements: IterableOnce[Statement] = Set.empty, jump: Jump = GoTo(Set.empty)) = {
this(label, address, IntrusiveList.from(statements), jump, mutable.HashSet.empty, None)
this(label, address, IntrusiveList().addAll(statements), jump, mutable.HashSet.empty, None)
}

def jump: Jump = _jump
Expand All @@ -377,7 +376,8 @@ class Block private (
_fallthrough = g
}

def jump_=(j: Jump): Unit = {
private def jump_=(j: Jump): Unit = {
require(!j.hasParent)
if (j ne _jump) {
_jump.deParent()
_jump = j
Expand All @@ -386,6 +386,11 @@ class Block private (
}

def replaceJump(j: Jump): Block = {
if (j.hasParent) {
val parent = j.parent
j.deParent()
parent.jump = GoTo(Set.empty)
}
jump = j
this
}
Expand Down
22 changes: 3 additions & 19 deletions src/main/scala/ir/Statement.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,21 +127,7 @@ object GoTo:


sealed trait Call extends Jump {
private var _returnTarget: Option[Block] = None

// replacing the return target of a call
def returnTarget_=(b: Block): Unit = {
require(b.hasParent)

if (hasParent) {
// if we don't have a parent now, delay adding the fallthrough block until linking
parent.fallthrough = Some(GoTo(Set(b)))
}

_returnTarget = Some(b)
}

def returnTarget: Option[Block] = _returnTarget
val returnTarget: Option[Block]

// moving a call between blocks
override def linkParent(p: Block): Unit = {
Expand All @@ -154,10 +140,9 @@ sealed trait Call extends Jump {
}

class DirectCall(val target: Procedure,
private val _returnTarget: Option[Block] = None,
override val returnTarget: Option[Block] = None,
override val label: Option[String] = None
) extends Call {
_returnTarget.foreach(x => returnTarget = x)
/* override def locals: Set[Variable] = condition match {
case Some(c) => c.locals
case None => Set()
Expand All @@ -182,10 +167,9 @@ object DirectCall:
def unapply(i: DirectCall): Option[(Procedure, Option[Block], Option[String])] = Some(i.target, i.returnTarget, i.label)

class IndirectCall(var target: Variable,
private val _returnTarget: Option[Block] = None,
override val returnTarget: Option[Block] = None,
override val label: Option[String] = None
) extends Call {
_returnTarget.foreach(x => returnTarget = x)
/* override def locals: Set[Variable] = condition match {
case Some(c) => c.locals + target
case None => Set(target)
Expand Down
11 changes: 6 additions & 5 deletions src/main/scala/ir/dsl/DSL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ case class EventuallyIndirectCall(target: Variable, fallthrough: Option[DelayNam

case class EventuallyCall(target: DelayNameResolve, fallthrough: Option[DelayNameResolve]) extends EventuallyJump {
override def resolve(p: Program): DirectCall = {
val t = target.resolveProc(p).get
val t = target.resolveProc(p) match {
case Some(x) => x
case None => throw Exception("can't resolve proc " + p)
}
val ft = fallthrough.flatMap(_.resolveBlock(p))
DirectCall(t, ft)
}
Expand All @@ -70,11 +73,9 @@ def goto(targets: List[String]): EventuallyGoto = {
EventuallyGoto(targets.map(p => DelayNameResolve(p)))
}

def indirectCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x)))
def directCall(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x)))

def call(tgt: String, fallthrough: Option[String]): EventuallyCall = EventuallyCall(DelayNameResolve(tgt), fallthrough.map(x => DelayNameResolve(x)))

def call(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x)))
def indirectCall(tgt: Variable, fallthrough: Option[String]): EventuallyIndirectCall = EventuallyIndirectCall(tgt, fallthrough.map(x => DelayNameResolve(x)))
// def directcall(tgt: String) = EventuallyCall(DelayNameResolve(tgt), None)


Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/translating/GTIRBToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[
private def cleanUpIfPCAssign(block: Block, procedure: Procedure): Unit = {
var newBlockCount = 0
var currentBlock = block
var currentStatement = currentBlock.statements.head()
var currentStatement = currentBlock.statements.head
var breakLoop = false
val queue = mutable.Queue[Block]()
while (!breakLoop) {
Expand All @@ -306,7 +306,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[

if (queue.nonEmpty) {
currentBlock = queue.dequeue()
currentStatement = currentBlock.statements.head()
currentStatement = currentBlock.statements.head
} else {
breakLoop = true
}
Expand All @@ -326,7 +326,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[

if (queue.nonEmpty) {
currentBlock = queue.dequeue()
currentStatement = currentBlock.statements.head()
currentStatement = currentBlock.statements.head
} else {
breakLoop = true
}
Expand All @@ -335,7 +335,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[
currentStatement = currentBlock.statements.getNext(currentStatement)
} else if (queue.nonEmpty) {
currentBlock = queue.dequeue()
currentStatement = currentBlock.statements.head()
currentStatement = currentBlock.statements.head
} else {
breakLoop = true
}
Expand Down Expand Up @@ -374,7 +374,7 @@ class GTIRBToIR(mods: Seq[Module], parserMap: immutable.Map[String, Array[Array[
g.targets.head
} else {
// case where goto has multiple targets: create an extra block and point to that
val afterBlock = Block(parentLabel + "$__" + newBlockCount, None)
val afterBlock = Block(parentLabel + "$__" + newBlockCount)
newBlockCount += 1
newBlocks.append(afterBlock)
afterBlock.replaceJump(currentBlock.jump)
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/translating/IRToBoogie.scala
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ class IRToBoogie(var program: Program, var spec: Specification) {
case g: GoTo =>
// collects all targets of the goto with a branch condition that we need to check the security level for
// and collects the variables for that
val conditions = g.targets.flatMap(_.statements.headOption()).collect { case a: Assume if a.checkSecurity => a }
val conditions = g.targets.flatMap(_.statements.headOption.collect { case a: Assume if a.checkSecurity => a })
val conditionVariables = conditions.flatMap(_.body.variables)
val gammas = conditionVariables.map(_.toGamma).toList.sorted
val conditionAssert: List[BCmd] = if (gammas.size > 1) {
Expand Down
30 changes: 12 additions & 18 deletions src/main/scala/util/intrusive_list/IntrusiveList.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package util.intrusive_list
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

// TODO: implement IterableOps
// So need iterablefactory https://docs.scala-lang.org/overviews/core/custom-collections.html
Expand Down Expand Up @@ -90,7 +91,7 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private (
elem
}

class IntrusiveListIterator(var elem: Option[T], forward: Boolean) extends Iterator[T] {
private class IntrusiveListIterator(var elem: Option[T], forward: Boolean) extends Iterator[T] {
override def hasNext: Boolean = elem.isDefined
override def next: T = {
val t = elem.get
Expand Down Expand Up @@ -133,14 +134,14 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private (
/**
* Unsafely return the first element of the list.
*/
override def head(): T = firstElem.get
override def head: T = firstElem.get

override def headOption(): Option[T] = firstElem
override def headOption: Option[T] = firstElem

/**
* Unsafely return the first element of the list.
*/
def begin(): T = firstElem.get
def begin: T = firstElem.get

/**
* Check whether the list contains the given element (by reference) by linear scan.
Expand Down Expand Up @@ -169,7 +170,7 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private (
/**
* Unsafely return the last element of the list.
*/
def back(): T = lastElem.get
def back: T = lastElem.get

/**
* Add an element to the beginning of the list.
Expand Down Expand Up @@ -232,19 +233,19 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private (
}

/**
* Split the list into two lists, the first retains all elements up to to and including the provided element,
* and and returns the second list from the element until the end.
* Removes all elements after the provided element n and returns an ArrayBuffer containing the removed elements,
* maintaining the ordering.
*
* @param n The element to split on, remains in the first list.
* @return A list containing all elements after n.
* @return An ArrayBuffer containing all elements after n.
*/
def splitOn(n: T): IntrusiveList[T] = {
def splitOn(n: T): ArrayBuffer[T] = {
require(!lastElem.contains(n))
require(containsRef(n))

val ne = n.next

val newlist = new IntrusiveList[T]()
val newlist = ArrayBuffer[T]()
var next = n.next
while (next.isDefined) {
remove(next.get)
Expand Down Expand Up @@ -355,12 +356,7 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private (
}

object IntrusiveList {

def from[T <: IntrusiveListElement[T]](it: IntrusiveList[T]): IntrusiveList[T] = it

def from[T <: IntrusiveListElement[T]](it: IterableOnce[T]): IntrusiveList[T] = IntrusiveList[T]().addAll(it)

def empty[T <: IntrusiveListElement[T]]: IntrusiveList[T] = new IntrusiveList[T]()
def empty[T <: IntrusiveListElement[T]]: IntrusiveList[T] = IntrusiveList[T]()
}

/**
Expand All @@ -382,8 +378,6 @@ trait IntrusiveListElement[T <: IntrusiveListElement[T]]:
elem
}



private[intrusive_list] final def unitary: Boolean = next.isEmpty && prev.isEmpty

private[intrusive_list] final def insertAfter(elem: T): T = {
Expand Down
26 changes: 13 additions & 13 deletions src/test/scala/LiveVarsAnalysisTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
block("first_call",
r0ConstantAssign,
r1ConstantAssign,
call("callee1", Some("second_call"))
directCall("callee1", Some("second_call"))
),
block("second_call",
call("callee2", Some("returnBlock"))
directCall("callee2", Some("returnBlock"))
),
block("returnBlock",
ret
Expand Down Expand Up @@ -69,10 +69,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
block("first_call",
r0ConstantAssign,
r1ConstantAssign,
call("callee1", Some("second_call"))
directCall("callee1", Some("second_call"))
),
block("second_call",
call("callee2", Some("returnBlock"))
directCall("callee2", Some("returnBlock"))
),
block("returnBlock",
ret
Expand Down Expand Up @@ -104,10 +104,10 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
var program = prog(
proc("main",
block("main_first_call",
call("wrapper1", Some("main_second_call"))
directCall("wrapper1", Some("main_second_call"))
),
block("main_second_call",
call("wrapper2", Some("main_return"))
directCall("wrapper2", Some("main_return"))
),
block("main_return", ret)
),
Expand All @@ -117,19 +117,19 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
proc("wrapper1",
block("wrapper1_first_call",
LocalAssign(R1, constant1),
call("callee", Some("wrapper1_second_call"))
directCall("callee", Some("wrapper1_second_call"))
),
block("wrapper1_second_call",
call("callee2", Some("wrapper1_return"))),
directCall("callee2", Some("wrapper1_return"))),
block("wrapper1_return", ret)
),
proc("wrapper2",
block("wrapper2_first_call",
LocalAssign(R2, constant1),
call("callee", Some("wrapper2_second_call"))
directCall("callee", Some("wrapper2_second_call"))
),
block("wrapper2_second_call",
call("callee3", Some("wrapper2_return"))),
directCall("callee3", Some("wrapper2_return"))),
block("wrapper2_return", ret)
)
)
Expand All @@ -148,7 +148,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
var program = prog(
proc("main",
block("lmain",
call("killer", Some("aftercall"))
directCall("killer", Some("aftercall"))
),
block("aftercall",
LocalAssign(R0, R1),
Expand Down Expand Up @@ -212,7 +212,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
block(
"lmain",
LocalAssign(R0, R1),
call("main", Some("return"))
directCall("main", Some("return"))
),
block("return",
LocalAssign(R0, R2),
Expand Down Expand Up @@ -240,7 +240,7 @@ class LiveVarsAnalysisTests extends AnyFunSuite, TestUtil {
),
block(
"recursion",
call("main", Some("assign"))
directCall("main", Some("assign"))
),
block("assign",
LocalAssign(R0, R2),
Expand Down
6 changes: 3 additions & 3 deletions src/test/scala/PointsToTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft
goto("0x1")
),
block("0x1",
call("p2", Some("returntarget"))
directCall("p2", Some("returntarget"))
),
block("returntarget",
ret
Expand Down Expand Up @@ -217,7 +217,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft
goto("0x1")
),
block("0x1",
call("p2", Some("returntarget"))
directCall("p2", Some("returntarget"))
),
block("returntarget",
ret
Expand All @@ -227,7 +227,7 @@ class PointsToTest extends AnyFunSuite with OneInstancePerTest with BeforeAndAft
block("l_foo",
LocalAssign(getRegister("R0"), MemoryLoad(mem, BinaryExpr(BVADD, getRegister("R31"), bv64(6)), LittleEndian, 64)),
LocalAssign(getRegister("R1"), BinaryExpr(BVADD, getRegister("R31"), bv64(10))),
call("p2", Some("l_foo_1"))
directCall("p2", Some("l_foo_1"))
),
block("l_foo_1",
ret,
Expand Down
Loading

0 comments on commit 4b3a278

Please sign in to comment.