Skip to content

Commit

Permalink
Add GNN path selector
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Sep 20, 2023
1 parent 1ccd345 commit 919c57e
Show file tree
Hide file tree
Showing 18 changed files with 631 additions and 8 deletions.
Binary file added Game_env/test_model.onnx
Binary file not shown.
2 changes: 2 additions & 0 deletions buildSrc/src/main/kotlin/Versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ object Versions {
const val jcdb = "1.2.0"
const val mockk = "1.13.4"
const val junitParams = "5.9.3"
const val serialization = "1.5.1"
const val onnxruntime = "1.15.1"
const val logback = "1.4.8"

// versions for jvm samples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
implementation(kotlin("stdlib-jdk8"))
implementation(kotlin("reflect"))
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}")
implementation("com.microsoft.onnxruntime", "onnxruntime", Versions.onnxruntime)

testImplementation(kotlin("test"))
}
Expand Down
13 changes: 12 additions & 1 deletion usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ sealed class PathsTrieNode<State : UState<*, *, Statement, *, *, State>, Stateme
*/
abstract val depth: Int

/**
* States that were forked from this node
*/
abstract val accumulatedForks: MutableCollection<State>

/**
* Adds a new label to [labels] collection.
*/
Expand Down Expand Up @@ -77,9 +82,10 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
depth = parentNode.depth + 1,
parent = parentNode,
states = hashSetOf(state),
statement = statement
statement = statement,
) {
parentNode.children[statement] = this
parentNode.accumulatedForks.addAll(this.states)
}

internal constructor(parentNode: PathsTrieNodeImpl<State, Statement>, statement: Statement, state: State) : this(
Expand All @@ -89,11 +95,14 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
statement = statement
) {
parentNode.children[statement] = this
parentNode.accumulatedForks.addAll(this.states)
parentNode.states -= state
}

override val labels: MutableSet<Any> = hashSetOf()

override val accumulatedForks: MutableCollection<State> = mutableSetOf()

override fun addLabel(label: Any) {
labels.add(label)
}
Expand All @@ -115,6 +124,8 @@ class RootNode<State : UState<*, *, Statement, *, *, State>, Statement> : PathsT

override val labels: MutableSet<Any> = hashSetOf()

override val accumulatedForks: MutableCollection<State> = mutableSetOf()

override val depth: Int = 0

override fun addLabel(label: Any) {
Expand Down
149 changes: 149 additions & 0 deletions usvm-core/src/main/kotlin/org/usvm/ps/BlockGraph.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package org.usvm.ps

import org.usvm.statistics.ApplicationGraph

data class Block<Statement>(
val id: Int,
var path: MutableList<Statement> = mutableListOf(),

var parents: MutableSet<Block<Statement>> = mutableSetOf(),
var children: MutableSet<Block<Statement>> = mutableSetOf()
) {
override fun hashCode(): Int = id

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as Block<*>

if (id != other.id) return false

return true
}
}

class BlockGraph<Method, Statement>(
initialStatement: Statement,
private val applicationGraph: ApplicationGraph<Method, Statement>,
) {
val root: Block<Statement>
private var nextBlockId: Int = 0
private val blockStatementMapping = HashMap<Block<Statement>, MutableList<Statement>>()

val blocks: Collection<Block<Statement>>
get() = blockStatementMapping.keys

init {
root = buildGraph(initialStatement)
}

fun getGraphBlock(statement: Statement): Block<Statement>? {
blockStatementMapping.forEach {
if (statement in it.value) {
return it.key
}
}
return null
}

private fun initializeGraphBlockWith(statement: Statement): Block<Statement> {
val currentBlock = Block(nextBlockId++, path = mutableListOf(statement))
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(statement)
return currentBlock
}

private fun createAndLinkWithPreds(statement: Statement): Block<Statement> {
val currentBlock = initializeGraphBlockWith(statement)
for (pred in applicationGraph.predecessors(statement)) {
getGraphBlock(pred)?.children?.add(currentBlock)
getGraphBlock(pred)?.let { currentBlock.parents.add(it) }
}
return currentBlock
}

private fun Statement.inBlock() = getGraphBlock(this) != null

private fun ApplicationGraph<Method, Statement>.filterStmtSuccsNotInBlock(
statement: Statement,
forceNewBlock: Boolean
): Sequence<Pair<Statement, Boolean>> {
return this.successors(statement).filter { !it.inBlock() }.map { Pair(it, forceNewBlock) }
}

fun buildGraph(initial: Statement): Block<Statement> {
val root = initializeGraphBlockWith(initial)
var currentBlock = root
val statementQueue = ArrayDeque<Pair<Statement, Boolean>>()

val initialHasMultipleSuccessors = applicationGraph.successors(initial).count() > 1
statementQueue.addAll(
applicationGraph.filterStmtSuccsNotInBlock(
initial,
forceNewBlock = initialHasMultipleSuccessors
)
)

while (statementQueue.isNotEmpty()) {
val (currentStatement, forceNew) = statementQueue.removeFirst()

if (forceNew) {
// don't need to add `currentStatement` succs, we did it earlier
createAndLinkWithPreds(currentStatement)
continue
}

// if statement is a call or if statement has multiple successors: next statements start new block
if (applicationGraph.callees(currentStatement).any() || applicationGraph.successors(currentStatement).count() > 1) {
currentBlock.path.add(currentStatement)
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true))
continue
}

// if statement has multiple ins: next statements start new block
if (applicationGraph.predecessors(currentStatement).count() > 1) {
currentBlock = createAndLinkWithPreds(currentStatement)
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true))
continue
}

currentBlock.path.add(currentStatement)
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = false))
}

return root
}

fun getEdges(): List<GameMapEdge> {
return blocks.flatMap { block ->
block.children.map { GameMapEdge(it.id, block.id, GameEdgeLabel(0)) }
}
}

fun getVertices(): Collection<Block<Statement>> = blocks

fun getBlockFeatures(
block: Block<Statement>, isCovered: (Statement) -> Boolean,
inCoverageZone: (Statement) -> Boolean,
isVisited: (Statement) -> Boolean,
stateIdsInBlock: List<UInt>
): BlockFeatures {
val firstStatement = block.path.first()
val lastStatement = block.path.last()
val visitedByState = isVisited(lastStatement)
val touchedByState = visitedByState || isVisited(firstStatement)

return BlockFeatures(
id = block.id,
inCoverageZone = inCoverageZone(firstStatement),
basicBlockSize = block.path.size,
coveredByTest = isCovered(firstStatement),
visitedByState = visitedByState,
touchedByState = touchedByState,
states = stateIdsInBlock
)
}
}
Loading

0 comments on commit 919c57e

Please sign in to comment.