diff --git a/Game_env/test_model.onnx b/Game_env/test_model.onnx new file mode 100644 index 0000000000..d67d151a47 Binary files /dev/null and b/Game_env/test_model.onnx differ diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index 8f765786f8..b3b0b91976 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -7,6 +7,8 @@ object Versions { const val jcdb = "1.2.0" const val mockk = "1.13.4" const val junitParams = "5.9.3" + const val serialization = "1.5.1" + const val onnxruntime = "1.15.1" const val logback = "1.4.8" // versions for jvm samples diff --git a/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts b/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts index 2d3c8a49c4..1f8c90aa72 100644 --- a/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts @@ -20,6 +20,7 @@ dependencies { implementation(kotlin("stdlib-jdk8")) implementation(kotlin("reflect")) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}") + implementation("com.microsoft.onnxruntime", "onnxruntime", Versions.onnxruntime) testImplementation(kotlin("test")) } diff --git a/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt b/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt index 13ae4d9b88..f8425720b9 100644 --- a/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt +++ b/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt @@ -34,6 +34,11 @@ sealed class PathsTrieNode, Stateme */ abstract val depth: Int + /** + * States that were forked from this node + */ + abstract val accumulatedForks: MutableCollection + /** * Adds a new label to [labels] collection. */ @@ -77,9 +82,10 @@ class PathsTrieNodeImpl, Statement> depth = parentNode.depth + 1, parent = parentNode, states = hashSetOf(state), - statement = statement + statement = statement, ) { parentNode.children[statement] = this + parentNode.accumulatedForks.addAll(this.states) } internal constructor(parentNode: PathsTrieNodeImpl, statement: Statement, state: State) : this( @@ -89,11 +95,14 @@ class PathsTrieNodeImpl, Statement> statement = statement ) { parentNode.children[statement] = this + parentNode.accumulatedForks.addAll(this.states) parentNode.states -= state } override val labels: MutableSet = hashSetOf() + override val accumulatedForks: MutableCollection = mutableSetOf() + override fun addLabel(label: Any) { labels.add(label) } @@ -115,6 +124,8 @@ class RootNode, Statement> : PathsT override val labels: MutableSet = hashSetOf() + override val accumulatedForks: MutableCollection = mutableSetOf() + override val depth: Int = 0 override fun addLabel(label: Any) { diff --git a/usvm-core/src/main/kotlin/org/usvm/ps/BlockGraph.kt b/usvm-core/src/main/kotlin/org/usvm/ps/BlockGraph.kt new file mode 100644 index 0000000000..612cd8f36b --- /dev/null +++ b/usvm-core/src/main/kotlin/org/usvm/ps/BlockGraph.kt @@ -0,0 +1,149 @@ +package org.usvm.ps + +import org.usvm.statistics.ApplicationGraph + +data class Block( + val id: Int, + var path: MutableList = mutableListOf(), + + var parents: MutableSet> = mutableSetOf(), + var children: MutableSet> = mutableSetOf() +) { + override fun hashCode(): Int = id + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as Block<*> + + if (id != other.id) return false + + return true + } +} + +class BlockGraph( + initialStatement: Statement, + private val applicationGraph: ApplicationGraph, +) { + val root: Block + private var nextBlockId: Int = 0 + private val blockStatementMapping = HashMap, MutableList>() + + val blocks: Collection> + get() = blockStatementMapping.keys + + init { + root = buildGraph(initialStatement) + } + + fun getGraphBlock(statement: Statement): Block? { + blockStatementMapping.forEach { + if (statement in it.value) { + return it.key + } + } + return null + } + + private fun initializeGraphBlockWith(statement: Statement): Block { + val currentBlock = Block(nextBlockId++, path = mutableListOf(statement)) + blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(statement) + return currentBlock + } + + private fun createAndLinkWithPreds(statement: Statement): Block { + val currentBlock = initializeGraphBlockWith(statement) + for (pred in applicationGraph.predecessors(statement)) { + getGraphBlock(pred)?.children?.add(currentBlock) + getGraphBlock(pred)?.let { currentBlock.parents.add(it) } + } + return currentBlock + } + + private fun Statement.inBlock() = getGraphBlock(this) != null + + private fun ApplicationGraph.filterStmtSuccsNotInBlock( + statement: Statement, + forceNewBlock: Boolean + ): Sequence> { + return this.successors(statement).filter { !it.inBlock() }.map { Pair(it, forceNewBlock) } + } + + fun buildGraph(initial: Statement): Block { + val root = initializeGraphBlockWith(initial) + var currentBlock = root + val statementQueue = ArrayDeque>() + + val initialHasMultipleSuccessors = applicationGraph.successors(initial).count() > 1 + statementQueue.addAll( + applicationGraph.filterStmtSuccsNotInBlock( + initial, + forceNewBlock = initialHasMultipleSuccessors + ) + ) + + while (statementQueue.isNotEmpty()) { + val (currentStatement, forceNew) = statementQueue.removeFirst() + + if (forceNew) { + // don't need to add `currentStatement` succs, we did it earlier + createAndLinkWithPreds(currentStatement) + continue + } + + // if statement is a call or if statement has multiple successors: next statements start new block + if (applicationGraph.callees(currentStatement).any() || applicationGraph.successors(currentStatement).count() > 1) { + currentBlock.path.add(currentStatement) + blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement) + statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true)) + continue + } + + // if statement has multiple ins: next statements start new block + if (applicationGraph.predecessors(currentStatement).count() > 1) { + currentBlock = createAndLinkWithPreds(currentStatement) + blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement) + statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = true)) + continue + } + + currentBlock.path.add(currentStatement) + blockStatementMapping.computeIfAbsent(currentBlock) { mutableListOf() }.add(currentStatement) + statementQueue.addAll(applicationGraph.filterStmtSuccsNotInBlock(currentStatement, forceNewBlock = false)) + } + + return root + } + + fun getEdges(): List { + return blocks.flatMap { block -> + block.children.map { GameMapEdge(it.id, block.id, GameEdgeLabel(0)) } + } + } + + fun getVertices(): Collection> = blocks + + fun getBlockFeatures( + block: Block, isCovered: (Statement) -> Boolean, + inCoverageZone: (Statement) -> Boolean, + isVisited: (Statement) -> Boolean, + stateIdsInBlock: List + ): BlockFeatures { + val firstStatement = block.path.first() + val lastStatement = block.path.last() + val visitedByState = isVisited(lastStatement) + val touchedByState = visitedByState || isVisited(firstStatement) + + return BlockFeatures( + id = block.id, + inCoverageZone = inCoverageZone(firstStatement), + basicBlockSize = block.path.size, + coveredByTest = isCovered(firstStatement), + visitedByState = visitedByState, + touchedByState = touchedByState, + states = stateIdsInBlock + ) + } +} diff --git a/usvm-core/src/main/kotlin/org/usvm/ps/BlockGraphPathSelector.kt b/usvm-core/src/main/kotlin/org/usvm/ps/BlockGraphPathSelector.kt new file mode 100644 index 0000000000..f45cab5941 --- /dev/null +++ b/usvm-core/src/main/kotlin/org/usvm/ps/BlockGraphPathSelector.kt @@ -0,0 +1,196 @@ +package org.usvm.ps + +import org.usvm.* +import org.usvm.constraints.UPathConstraints +import org.usvm.statistics.* +import java.io.File +import kotlin.io.path.Path + +data class GameEdgeLabel( + val token: Int +) + +data class GameMapEdge( + val vertexFrom: Int, + val vertexTo: Int, + val label: GameEdgeLabel, +) + +data class BlockFeatures( + val uid: Int = 0, + val id: Int, + val basicBlockSize: Int, + val inCoverageZone: Boolean, + val coveredByTest: Boolean, + val visitedByState: Boolean, + val touchedByState: Boolean, + val states: List, +) + +data class StateHistoryElem( + val graphVertexId: Int, + val numOfVisits: Int, +) + +data class StateFeatures( + val id: StateId, + val position: Int = 0, + val predictedUsefulness: Int = 42, + val pathConditionSize: Int, + val visitedAgainVertices: Int, + val visitedNotCoveredVerticesInZone: Int, + val visitedNotCoveredVerticesOutOfZone: Int, + val history: List, + val children: List, +) + +open class BlockGraphPathSelector, Statement, Method>( + private val coverageStatistics: CoverageStatistics, + val applicationGraph: ApplicationGraph +) : UPathSelector { + protected val states: MutableList = mutableListOf() + private val visitedStatements = HashSet() + + private val filename: String + + protected val blockGraph: BlockGraph + + init { + val method = applicationGraph.methodOf(coverageStatistics.getUncoveredStatements().first()) + filename = method.toString().dropWhile { it != ')' }.drop(1) + blockGraph = BlockGraph(applicationGraph.entryPoints(method).first(), applicationGraph) + } + + private fun getNonThrowingLeaves(root: Block): Collection> { + val queue = ArrayDeque>() + val visited = HashSet>() + val leaves = mutableListOf>() + queue.addAll(root.children) + + while (queue.isNotEmpty()) { + val next = queue.removeFirst() + if (next.children.isEmpty() && next.path.none { applicationGraph.isThrowing(it) }) { + leaves.add(next) + } + visited.add(next) + queue.addAll(next.children) + } + + return leaves + } + + fun getStateFeatures(state: State): StateFeatures { + val blockHistory = mutableListOf>() + + var lastBlock: Block = blockGraph.root + blockHistory.add(lastBlock) + for (statement in state.listPath()) { + if (statement !in lastBlock.path) { + val someBlockOpt = blockGraph.getGraphBlock(statement) + // if `statement` already has a block + if (someBlockOpt != null) { + blockHistory.add(someBlockOpt) + lastBlock = someBlockOpt + } + else { // encountered non-explored statements, extend block graph + val callRoot = blockGraph.buildGraph(statement) + val callExitBlocksToConnect = getNonThrowingLeaves(callRoot) + + // if `statement` is last in prev block -> connect to `lastBlock` children + if (statement == lastBlock.path.last() && statement != state.listPath().last()) { + lastBlock.children.forEach { lastBlockChild -> + callExitBlocksToConnect.forEach { callExitBlock -> + callExitBlock.children.add(lastBlockChild) + } + } + } else { // connect to last block itself + callExitBlocksToConnect.forEach { externalExitBlock -> + externalExitBlock.children.add(lastBlock) + } + } + } + } + } + + var visitedNotCoveredVerticesInZone = 0 + var visitedNotCoveredVerticesOutOfZone = 0 + + for (block in blockHistory.map { block -> + blockGraph.getBlockFeatures( + block = block, + isCovered = ::isCovered, + inCoverageZone = ::inCoverageZone, + isVisited = ::isVisited, + stateIdsInBlock = states.filter { it.currentStatement in block.path }.map { it.id } + ) + }) { + if (block.visitedByState && !block.coveredByTest) { + if (block.inCoverageZone) { + visitedNotCoveredVerticesInZone += 1 + } else visitedNotCoveredVerticesOutOfZone += 1 + } + } + + return StateFeatures( + id = state.id, + pathConditionSize = state.pathConstraints.size(), + visitedAgainVertices = state.listPath().count() - state.listPath().distinct().count(), + visitedNotCoveredVerticesInZone = visitedNotCoveredVerticesInZone, + visitedNotCoveredVerticesOutOfZone = visitedNotCoveredVerticesOutOfZone, + history = blockHistory.map { block -> + StateHistoryElem( + block.id, + blockHistory.count { block.id == it.id }) + }, + children = state.pathLocation.accumulatedForks.map { it.id } + ) + } + + protected fun isCovered(statement: Statement): Boolean { + return statement in coverageStatistics.getUncoveredStatements() + } + + protected fun inCoverageZone(statement: Statement): Boolean { + return coverageStatistics.inCoverageZone(applicationGraph.methodOf(statement)) + } + + protected fun isVisited(statement: Statement) = statement in visitedStatements + + override fun isEmpty(): Boolean { + return states.isEmpty() + } + + override fun peek(): State { + return states.first() + } + + override fun update(state: State) {} + + override fun add(states: Collection) { + this.states += states + } + + override fun remove(state: State) { + states.remove(state) + } +} + +fun UPathConstraints.size(): Int { + return numericConstraints.constraints().count() + + this.equalityConstraints.distinctReferences.count() + + this.equalityConstraints.equalReferences.count() + + this.equalityConstraints.referenceDisequalities.count() + + this.equalityConstraints.nullableDisequalities.count() + + this.logicalConstraints.count() + + this.typeConstraints.symbolicRefToTypeRegion.count() // TODO: maybe throw out? +} + +fun, Statement> UState<*, *, Statement, *, *, State>.listPath(): List { + val statements = mutableListOf() + var current: PathsTrieNode? = this.pathLocation + while (current !is RootNode && current != null) { + statements.add(current.statement) + current = current.parent + } + return statements +} diff --git a/usvm-core/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt b/usvm-core/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt new file mode 100644 index 0000000000..94bfbf6187 --- /dev/null +++ b/usvm-core/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt @@ -0,0 +1,221 @@ +package org.usvm.ps + +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtSession +import org.usvm.UState +import org.usvm.statistics.ApplicationGraph +import org.usvm.statistics.CoverageStatistics +import java.nio.FloatBuffer +import java.nio.LongBuffer +import kotlin.collections.set +import kotlin.io.path.Path + +data class GraphNative( + val gameVertex: List>, + val stateVertex: List>, + val gameVertexToGameVertex: List>, + val gameVertexHistoryStateVertexIndex: List>, + val gameVertexHistoryStateVertexAttrs: List, + val gameVertexInStateVertex: List>, + val stateVertexParentOfStateVertex: List>, + val stateMap: Map +) + +open class GNNPathSelector, Statement, Method>( + applicationGraph: ApplicationGraph, + private val coverageStatistics: CoverageStatistics, +) : BlockGraphPathSelector( + coverageStatistics, + applicationGraph +) { + companion object { + private val gnnModelPath = Path("/Users/emax/Data/usvm/Game_env/test_model.onnx").toString() + private var env: OrtEnvironment = OrtEnvironment.getEnvironment() + private var gnnSession: OrtSession = env.createSession(gnnModelPath) + } + + fun coverage(): Float { + return coverageStatistics.getTotalCoverage() + } + + override fun peek(): State { + val nativeInput = createNativeInput() + + val gameVertexTensor = onnxFloatTensor(nativeInput.gameVertex) + val stateVertexTensor = onnxFloatTensor(nativeInput.stateVertex) + val gameVertexToGameVertexTensor = onnxLongTensor(nativeInput.gameVertexToGameVertex.transpose()) + val gameVertexHistoryStateVertexIndexTensor = + onnxLongTensor(nativeInput.gameVertexHistoryStateVertexIndex.transpose()) + val gameVertexHistoryStateVertexAttrsTensor = + onnxLongTensor(nativeInput.gameVertexHistoryStateVertexAttrs.map { listOf(it) }) + val gameVertexInStateVertexTensor = onnxLongTensor(nativeInput.gameVertexInStateVertex.transpose()) + val stateVertexParentOfStateVertexTensor = + onnxLongTensor(nativeInput.stateVertexParentOfStateVertex.transpose()) + + val res = gnnSession.run( + mapOf( + "game_vertex" to gameVertexTensor, + "state_vertex" to stateVertexTensor, + "game_vertex to game_vertex" to gameVertexToGameVertexTensor, + "game_vertex history state_vertex index" to gameVertexHistoryStateVertexIndexTensor, + "game_vertex history state_vertex attrs" to gameVertexHistoryStateVertexAttrsTensor, + "game_vertex in state_vertex" to gameVertexInStateVertexTensor, + "state_vertex parent_of state_vertex" to stateVertexParentOfStateVertexTensor + ) + ) + + val predictedStatesRanks = + (res["out"].get().value as Array<*>).map { it as FloatArray }.map { it.toList() }.toList() + val chosenStateId = predictState(predictedStatesRanks, nativeInput.stateMap) + + return states.single { state -> state.id.toInt() == chosenStateId } + } + + private fun createNativeInput(): GraphNative { + val nodesState = mutableListOf>() + val nodesVertex = mutableListOf>() + val edgesIndexVSHistory = mutableListOf>() + val edgesAttrVS = mutableListOf() + val edgesIndexSS = mutableListOf>() + val edgesIndexVSIn = mutableListOf>() + + val stateMap = mutableMapOf() + val vertexMap = mutableMapOf() + + val statesFeatures = states.map { getStateFeatures(it) } + + for ((stateIndexOrder, stateFeatures) in statesFeatures.withIndex()) { + stateMap[stateFeatures.id.toInt()] = stateIndexOrder + nodesState.add(stateFeatures.toList()) + } + + for ((vertexIndexOrder, vertex) in blockGraph.getVertices().withIndex()) { + vertexMap[vertex.id] = vertexIndexOrder + val blockFeatures = blockGraph.getBlockFeatures( + block = vertex, + isCovered = ::isCovered, + inCoverageZone = ::inCoverageZone, + isVisited = ::isVisited, + stateIdsInBlock = states.filter { it.currentStatement in vertex.path }.map { it.id } + ) + nodesVertex.add(blockFeatures.toList()) + } + + val edgesIndexVV = blockGraph.getEdges().map { listOf(vertexMap[it.vertexFrom]!!, vertexMap[it.vertexTo]!!) } + + for ((stateIndexOrder, stateFeatures) in statesFeatures.withIndex()) { + for (historyEdge in stateFeatures.history) { + val vertexTo = vertexMap[historyEdge.graphVertexId]!! + edgesIndexVSHistory.add(listOf(vertexTo, stateIndexOrder)) + edgesAttrVS.add(historyEdge.numOfVisits) + } + } + + for (stateFeatures in statesFeatures) { + for (childId in stateFeatures.children) { + if (childId.toInt() in stateMap.keys) + edgesIndexSS.add(listOf(stateMap[stateFeatures.id.toInt()]!!, stateMap[childId.toInt()]!!)) + } + } + + for (vertex in blockGraph.getVertices()) { + for (state in states) { + edgesIndexVSIn.add(listOf(vertexMap[vertex.id]!!, stateMap[state.id.toInt()]!!)) + } + } + + return GraphNative( + gameVertex = nodesVertex, + stateVertex = nodesState, + gameVertexToGameVertex = edgesIndexVV, + gameVertexHistoryStateVertexIndex = edgesIndexVSHistory, + gameVertexHistoryStateVertexAttrs = edgesAttrVS, + gameVertexInStateVertex = edgesIndexVSIn, + stateVertexParentOfStateVertex = edgesIndexSS, + stateMap = stateMap + ) + } + + private fun List.create2DFloatBuffer(shape: Pair): OnnxTensor { + val longArrayOfShape = longArrayOf(shape.first.toLong(), shape.second.toLong()) + return OnnxTensor.createTensor( + env, + FloatBuffer.wrap(this.toFloatArray()), + longArrayOfShape + ) + } + + private fun List.create2DLongBuffer(shape: Pair): OnnxTensor { + val longArrayOfShape = longArrayOf(shape.first.toLong(), shape.second.toLong()) + return OnnxTensor.createTensor( + env, + LongBuffer.wrap(this.toLongArray()), + longArrayOfShape + ) + } + + private fun onnxFloatTensor(data: List>): OnnxTensor { + return data.flatten().map { it.toFloat() }.create2DFloatBuffer(get2DShape(data)) + } + + private fun onnxLongTensor(data: List>): OnnxTensor { + return data.flatten().map { it.toLong() }.create2DLongBuffer(get2DShape(data)) + } +} + +private fun predictState(stateRank: List>, stateMap: Map): Int { + val reverseStateMap = stateMap.entries.associate { (k, v) -> v to k } + + val stateRankMapping = stateRank.mapIndexed { orderIndex, rank -> + reverseStateMap[orderIndex]!! to rank + } + + return stateRankMapping.maxBy { it.second.sum() }.first +} + +private fun get2DShape(data: List>): Pair { + if (data.isEmpty()) { + return Pair(0, 0) + } + if (data[0].isEmpty()) { + return Pair(data.size, 0) + } + return Pair(data.size, data[0].size) +} + +private fun List>.transpose(): List> { + if (this.isEmpty()) { + return listOf(listOf(), listOf()) + } + + val (rows, cols) = get2DShape(this) + return List(cols) { j -> + List(rows) { i -> + this[i][j] + } + } +} + +private fun Boolean.toInt(): Int = if (this) 1 else 0 + +private fun BlockFeatures.toList(): List { + return listOf( + this.inCoverageZone.toInt(), + this.basicBlockSize, + this.coveredByTest.toInt(), + this.visitedByState.toInt(), + this.touchedByState.toInt() + ) +} + +private fun StateFeatures.toList(): List { + return listOf( + this.position, + this.predictedUsefulness, + this.pathConditionSize, + this.visitedAgainVertices, + this.visitedNotCoveredVerticesInZone, + this.visitedNotCoveredVerticesOutOfZone + ) +} diff --git a/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt b/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt index 321c04ff56..9b1a933fcc 100644 --- a/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt +++ b/usvm-core/src/main/kotlin/org/usvm/ps/PathSelectorFactory.kt @@ -84,6 +84,11 @@ fun , State : USta applicationGraph, random ) + + PathSelectionStrategy.GNN -> GNNPathSelector( + applicationGraph, + requireNotNull(coverageStatistics()) { "Coverage statistics is required for Hetero GNN path selector" }, + ) } } diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/ApplicationGraph.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/ApplicationGraph.kt index 48fab1befe..aa02c7fdd2 100644 --- a/usvm-core/src/main/kotlin/org/usvm/statistics/ApplicationGraph.kt +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/ApplicationGraph.kt @@ -13,4 +13,6 @@ interface ApplicationGraph { fun methodOf(node: Statement): Method fun statementsOf(method: Method): Sequence + + fun isThrowing(node: Statement): Boolean } diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt index d2751472f8..2d345ec347 100644 --- a/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt @@ -89,6 +89,10 @@ class CoverageStatistics> : UMachineObserver { + + private var steps = 0 + override fun onState(parent: State, forks: Sequence) { + steps += 1 + } + + fun getStepsCount() = steps +} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApplicationGraph.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApplicationGraph.kt index f0a0e2bb4a..405159e5cb 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApplicationGraph.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApplicationGraph.kt @@ -5,6 +5,7 @@ import org.jacodb.api.JcClasspath import org.jacodb.api.JcMethod import org.jacodb.api.JcTypedMethod import org.jacodb.api.cfg.JcInst +import org.jacodb.api.cfg.JcThrowInst import org.jacodb.api.ext.toType import org.jacodb.impl.features.HierarchyExtensionImpl import org.jacodb.impl.features.SyncUsagesExtension @@ -73,4 +74,8 @@ class JcApplicationGraph( return statements } + + override fun isThrowing(node: JcInst): Boolean { + return node is JcThrowInst + } } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt index 7fd0043522..c7d6fb55e4 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt @@ -805,7 +805,7 @@ class JcExprResolver( fun allocateException(type: JcRefType): (JcState) -> Unit = { state -> // TODO should we consider exceptions with negative addresses? val address = state.memory.allocConcrete(type) - state.throwExceptionWithoutStackFrameDrop(address, type) + state.throwExceptionWithoutStackFrameDrop(address, type, false) } fun checkArrayIndex(idx: USizeExpr, length: USizeExpr) = with(ctx) { diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt index 5e6f15c3b9..372d81288b 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt @@ -334,7 +334,7 @@ class JcInterpreter( val address = resolver.resolveJcExpr(stmt.throwable)?.asExpr(ctx.addressSort) ?: return scope.calcOnState { - throwExceptionWithoutStackFrameDrop(address, stmt.throwable.type) + throwExceptionWithoutStackFrameDrop(address, stmt.throwable.type, expected = true) } } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcMethodResult.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcMethodResult.kt index f0f05b2663..29c5688d68 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcMethodResult.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcMethodResult.kt @@ -31,8 +31,9 @@ sealed interface JcMethodResult { open class JcException( val address: UHeapRef, val type: JcType, - val symbolicStackTrace: List> + val symbolicStackTrace: List>, + val expected: Boolean, ) : JcMethodResult { - override fun toString(): String = "${this::class.simpleName}: Address: $address, type: ${type.typeName}" + override fun toString(): String = "${this::class.simpleName}: Address: $address, type: ${type.typeName}, is expected: $expected" } } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt index 066c9ded34..ae05f53271 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt @@ -47,6 +47,9 @@ class JcState( override val isExceptional: Boolean get() = methodResult is JcMethodResult.JcException + val isExceptionalAndNotExpected: Boolean + get() = isExceptional && !(methodResult as JcMethodResult.JcException).expected + override fun toString(): String = buildString { appendLine("Instruction: $lastStmt") if (isExceptional) appendLine("Exception: $methodResult") diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt index 4dafd26f58..9b7a8e112e 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt @@ -37,8 +37,8 @@ fun JcState.returnValue(valueToReturn: UExpr) { /** * Create an unprocessed exception with the [address] and the [type] and assign it to the [JcState.methodResult]. */ -fun JcState.throwExceptionWithoutStackFrameDrop(address: UHeapRef, type: JcType) { - methodResult = JcMethodResult.JcException(address, type, callStack.stackTrace(lastStmt)) +fun JcState.throwExceptionWithoutStackFrameDrop(address: UHeapRef, type: JcType, expected: Boolean) { + methodResult = JcMethodResult.JcException(address, type, callStack.stackTrace(lastStmt), expected) } fun JcState.throwExceptionAndDropStackFrame() { diff --git a/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt b/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt index 797a631e09..b2e52287af 100644 --- a/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt +++ b/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt @@ -77,7 +77,12 @@ enum class PathSelectionStrategy { * reachability. * States are selected randomly with distribution based on distance to targets. */ - TARGETED_CALL_STACK_LOCAL_RANDOM + TARGETED_CALL_STACK_LOCAL_RANDOM, + + /** + * Selects state with the best score according to GNN + */ + GNN } enum class PathSelectorCombinationStrategy {