Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GNN path selector #67

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added Game_env/test_model.onnx
Binary file not shown.
2 changes: 2 additions & 0 deletions buildSrc/src/main/kotlin/Versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ object Versions {
const val jcdb = "1.2.0"
const val mockk = "1.13.4"
const val junitParams = "5.9.3"
const val serialization = "1.5.1"
const val onnxruntime = "1.15.1"
const val logback = "1.4.8"

// versions for jvm samples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies {
implementation(kotlin("stdlib-jdk8"))
implementation(kotlin("reflect"))
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}")
implementation("com.microsoft.onnxruntime", "onnxruntime", Versions.onnxruntime)

testImplementation(kotlin("test"))
}
Expand Down
13 changes: 12 additions & 1 deletion usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ sealed class PathsTrieNode<State : UState<*, *, Statement, *, *, State>, Stateme
*/
abstract val depth: Int

/**
* States that were forked from this node
*/
abstract val accumulatedForks: MutableCollection<State>

/**
* Adds a new label to [labels] collection.
*/
Expand Down Expand Up @@ -77,9 +82,10 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
depth = parentNode.depth + 1,
parent = parentNode,
states = hashSetOf(state),
statement = statement
statement = statement,
) {
parentNode.children[statement] = this
parentNode.accumulatedForks.addAll(this.states)
}

internal constructor(parentNode: PathsTrieNodeImpl<State, Statement>, statement: Statement, state: State) : this(
Expand All @@ -89,11 +95,14 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
statement = statement
) {
parentNode.children[statement] = this
parentNode.accumulatedForks.addAll(this.states)
parentNode.states -= state
}

override val labels: MutableSet<Any> = hashSetOf()

override val accumulatedForks: MutableCollection<State> = mutableSetOf()

override fun addLabel(label: Any) {
labels.add(label)
}
Expand All @@ -115,6 +124,8 @@ class RootNode<State : UState<*, *, Statement, *, *, State>, Statement> : PathsT

override val labels: MutableSet<Any> = hashSetOf()

override val accumulatedForks: MutableCollection<State> = mutableSetOf()

override val depth: Int = 0

override fun addLabel(label: Any) {
Expand Down
149 changes: 149 additions & 0 deletions usvm-core/src/main/kotlin/org/usvm/ps/BlockGraph.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package org.usvm.ps

import org.usvm.statistics.ApplicationGraph

data class Block<Statement>(
val id: Int,
var path: MutableList<Statement> = mutableListOf(),

var parents: MutableSet<Block<Statement>> = mutableSetOf(),
var children: MutableSet<Block<Statement>> = mutableSetOf()
) {
override fun hashCode(): Int = id

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as Block<*>

if (id != other.id) return false

return true
}
}

class BlockGraph<Method, Statement>(
initialStatement: Statement,
private val applicationGraph: ApplicationGraph<Method, Statement>,
) {
val root: Block<Statement>
private var nextBlockId: Int = 0
private val blockStatementMapping = HashMap<Block<Statement>, MutableList<Statement>>()

val blocks: Collection<Block<Statement>>
get() = blockStatementMapping.keys

init {
root = buildGraph(initialStatement)
}

fun getGraphBlock(statement: Statement): Block<Statement>? {
blockStatementMapping.forEach {
if (statement in it.value) {
return it.key
}
}
return null
}

private fun initializeGraphBlockWith(statement: Statement): Block<Statement> {
val currentBlock = Block(nextBlockId++, path = mutableListOf(statement))
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(statement)
return currentBlock
}

private fun createAndLinkWithPreds(statement: Statement): Block<Statement> {
val currentBlock = initializeGraphBlockWith(statement)
for (pred in applicationGraph.predecessors(statement)) {
getGraphBlock(pred)?.children?.add(currentBlock)
getGraphBlock(pred)?.let { currentBlock.parents.add(it) }
}
return currentBlock
}

private fun Statement.inBlock() = getGraphBlock(this) != null

private fun ApplicationGraph<Method, Statement>.filterStmtSuccsNotInBlock(
statement: Statement,
forceNewBlock: Boolean
): Sequence<Pair<Statement, Boolean>> {
return this.successors(statement).filter { !it.inBlock() }.map { Pair(it, forceNewBlock) }
}

fun buildGraph(initial: Statement): Block<Statement> {
val root = initializeGraphBlockWith(initial)
var currentBlock = root
val statementQueue = ArrayDeque<Pair<Statement, Boolean>>()

val initialHasMultipleSuccessors = applicationGraph.successors(initial).count() > 1
statementQueue.addAll(
applicationGraph.filterStmtSuccsNotInBlock(
initial,
forceNewBlock = initialHasMultipleSuccessors
)
)

while (statementQueue.isNotEmpty()) {
val (currentStatement, forceNew) = statementQueue.removeFirst()

if (forceNew) {
// don't need to add `currentStatement` succs, we did it earlier
createAndLinkWithPreds(currentStatement)
continue
}

// if statement is a call or if statement has multiple successors: next statements start new block
if (applicationGraph.callees(currentStatement).any() || applicationGraph.successors(currentStatement).count() > 1) {
currentBlock.path.add(currentStatement)
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true))
continue
}

// if statement has multiple ins: next statements start new block
if (applicationGraph.predecessors(currentStatement).count() > 1) {
currentBlock = createAndLinkWithPreds(currentStatement)
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true))
continue
}

currentBlock.path.add(currentStatement)
blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement)
statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = false))
}

return root
}

fun getEdges(): List<GameMapEdge> {
return blocks.flatMap { block ->
block.children.map { GameMapEdge(it.id, block.id, GameEdgeLabel(0)) }
}
}

fun getVertices(): Collection<Block<Statement>> = blocks

fun getBlockFeatures(
block: Block<Statement>, isCovered: (Statement) -> Boolean,
inCoverageZone: (Statement) -> Boolean,
isVisited: (Statement) -> Boolean,
stateIdsInBlock: List<UInt>
): BlockFeatures {
val firstStatement = block.path.first()
val lastStatement = block.path.last()
val visitedByState = isVisited(lastStatement)
val touchedByState = visitedByState || isVisited(firstStatement)

return BlockFeatures(
id = block.id,
inCoverageZone = inCoverageZone(firstStatement),
basicBlockSize = block.path.size,
coveredByTest = isCovered(firstStatement),
visitedByState = visitedByState,
touchedByState = touchedByState,
states = stateIdsInBlock
)
}
}
Loading