From 29e0d77552bb7bf0ac13fd7a87d11c7321466fd3 Mon Sep 17 00:00:00 2001 From: Max Nigmatulin Date: Wed, 4 Oct 2023 15:37:41 +0300 Subject: [PATCH] Create ApplicationBlockGraph interface, jvm implementation --- .gitignore | 3 + buildSrc/src/main/kotlin/Versions.kt | 1 + .../kotlin/usvm.kotlin-conventions.gradle.kts | 1 + settings.gradle.kts | 1 + .../src/main/kotlin/org/usvm/PathTrieNode.kt | 11 + .../org/usvm/statistics/CoverageStatistics.kt | 4 + usvm-ml-path-selection/build.gradle | 11 + .../kotlin/org/usvm/ApplicationBlockGraph.kt | 9 + .../main/kotlin/org/usvm/MLMachineOptions.kt | 25 + .../kotlin/org/usvm/MLPathSelectorStrategy.kt | 53 ++ .../org/usvm/jvm/JcApplicationBlockGraph.kt | 100 ++++ .../jvm/interpreter/JcBlockInterpreter.kt | 493 ++++++++++++++++++ .../org/usvm/jvm/machine/MLJcMachine.kt | 110 ++++ .../kotlin/org/usvm/ps/GNNPathSelector.kt | 343 ++++++++++++ .../src/main/kotlin/org/usvm/ps/Utils.kt | 61 +++ .../usvm/statistics/StateVisitsStatistics.kt | 14 + .../kotlin/org/usvm/jvm/JacoDBContainer.kt | 53 ++ .../org/usvm/jvm/JavaMethodTestRunner.kt | 95 ++++ .../src/test/kotlin/org/usvm/jvm/util/Util.kt | 15 + 19 files changed, 1403 insertions(+) create mode 100644 usvm-ml-path-selection/build.gradle create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/ApplicationBlockGraph.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/MLMachineOptions.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/MLPathSelectorStrategy.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/JcApplicationBlockGraph.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/interpreter/JcBlockInterpreter.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/machine/MLJcMachine.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/Utils.kt create mode 100644 usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/StateVisitsStatistics.kt create mode 100644 usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JacoDBContainer.kt create mode 100644 usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JavaMethodTestRunner.kt create mode 100644 usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/util/Util.kt diff --git a/.gitignore b/.gitignore index 649ad2b413..0e98610ea8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ buildSrc/.gradle # Ignore Idea directory .idea + +# Ignore MacOS-specific files +/**/.DS_Store diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index 0a2f551ebd..b054302768 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -8,6 +8,7 @@ object Versions { const val mockk = "1.13.4" const val junitParams = "5.9.3" const val logback = "1.4.8" + const val onnxruntime = "1.15.1" // versions for jvm samples const val samplesLombok = "1.18.20" diff --git a/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts b/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts index 2d3c8a49c4..f9daf7586f 100644 --- a/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/usvm.kotlin-conventions.gradle.kts @@ -19,6 +19,7 @@ dependencies { implementation(kotlin("stdlib-jdk8")) implementation(kotlin("reflect")) + implementation("com.microsoft.onnxruntime", "onnxruntime", Versions.onnxruntime) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}") testImplementation(kotlin("test")) diff --git a/settings.gradle.kts b/settings.gradle.kts index 75ba8e1ab5..04e91e25df 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -4,3 +4,4 @@ include("usvm-core") include("usvm-jvm") include("usvm-util") include("usvm-sample-language") +include("usvm-ml-path-selection") diff --git a/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt b/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt index 13ae4d9b88..d0d3a19277 100644 --- a/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt +++ b/usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt @@ -34,6 +34,11 @@ sealed class PathsTrieNode, Stateme */ abstract val depth: Int + /** + * States that were forked from this node + */ + abstract val accumulatedForks: MutableCollection + /** * Adds a new label to [labels] collection. */ @@ -80,6 +85,7 @@ class PathsTrieNodeImpl, Statement> statement = statement ) { parentNode.children[statement] = this + parentNode.accumulatedForks.addAll(this.states) } internal constructor(parentNode: PathsTrieNodeImpl, statement: Statement, state: State) : this( @@ -89,11 +95,14 @@ class PathsTrieNodeImpl, Statement> statement = statement ) { parentNode.children[statement] = this + parentNode.accumulatedForks.addAll(this.states) parentNode.states -= state } override val labels: MutableSet = hashSetOf() + override val accumulatedForks: MutableCollection = mutableSetOf() + override fun addLabel(label: Any) { labels.add(label) } @@ -115,6 +124,8 @@ class RootNode, Statement> : PathsT override val labels: MutableSet = hashSetOf() + override val accumulatedForks: MutableCollection = mutableSetOf() + override val depth: Int = 0 override fun addLabel(label: Any) { diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt index d2751472f8..3a2c49cc96 100644 --- a/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/CoverageStatistics.kt @@ -89,6 +89,10 @@ class CoverageStatistics : ApplicationGraph { + fun blockOf(stmt: Statement): BasicBlock + fun instructions(block: BasicBlock): Sequence + fun blocks(): Sequence +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLMachineOptions.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLMachineOptions.kt new file mode 100644 index 0000000000..76bf0b14ab --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLMachineOptions.kt @@ -0,0 +1,25 @@ +package org.usvm + +enum class MLPathSelectionStrategy { + /** + * Collects features according to states selected by any other path selector. + */ + FEATURES_LOGGING, + + /** + * Collects features and feeds them to the ML model to select states. + * Extends FEATURE_LOGGING path selector. + */ + MACHINE_LEARNING, + + /** + * Selects states with best Graph Neural Network state score + */ + GNN, +} + +data class MLMachineOptions( + val basicOptions: UMachineOptions, + val pathSelectionStrategy: MLPathSelectionStrategy, + val heteroGNNModelPath: String = "" +) diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLPathSelectorStrategy.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLPathSelectorStrategy.kt new file mode 100644 index 0000000000..2dedcfb784 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLPathSelectorStrategy.kt @@ -0,0 +1,53 @@ +package org.usvm + +import org.usvm.ps.ExceptionPropagationPathSelector +import org.usvm.ps.GNNPathSelector +import org.usvm.statistics.CoverageStatistics +import org.usvm.statistics.StateVisitsStatistics + +fun > createPathSelector( + initialState: State, + options: MLMachineOptions, + applicationGraph: ApplicationBlockGraph, + stateVisitsStatistics: StateVisitsStatistics, + coverageStatistics: CoverageStatistics, +): UPathSelector { + val selector = when (options.pathSelectionStrategy) { + MLPathSelectionStrategy.GNN -> createGNNPathSelector( + stateVisitsStatistics, + coverageStatistics, applicationGraph, options.heteroGNNModelPath + ) + + else -> { + throw NotImplementedError() + } + } + + val propagateExceptions = options.basicOptions.exceptionsPropagation + + val resultSelector = selector.wrapIfRequired(propagateExceptions) + resultSelector.add(listOf(initialState)) + + return selector +} + +private fun > UPathSelector.wrapIfRequired(propagateExceptions: Boolean) = + if (propagateExceptions && this !is ExceptionPropagationPathSelector) { + ExceptionPropagationPathSelector(this) + } else { + this + } + +private fun > createGNNPathSelector( + stateVisitsStatistics: StateVisitsStatistics, + coverageStatistics: CoverageStatistics, + applicationGraph: ApplicationBlockGraph, + heteroGNNModelPath: String, +): UPathSelector { + return GNNPathSelector( + coverageStatistics, + stateVisitsStatistics, + applicationGraph, + heteroGNNModelPath + ) +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/JcApplicationBlockGraph.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/JcApplicationBlockGraph.kt new file mode 100644 index 0000000000..180f388a4d --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/JcApplicationBlockGraph.kt @@ -0,0 +1,100 @@ +package org.usvm.jvm + +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcMethod +import org.jacodb.api.JcTypedMethod +import org.jacodb.api.cfg.JcBasicBlock +import org.jacodb.api.cfg.JcBlockGraph +import org.jacodb.api.cfg.JcInst +import org.jacodb.api.ext.toType +import org.usvm.ApplicationBlockGraph +import org.usvm.machine.JcApplicationGraph +import java.util.concurrent.ConcurrentHashMap + +class JcApplicationBlockGraph(cp: JcClasspath) : + ApplicationBlockGraph { + val jcApplicationGraph: JcApplicationGraph = JcApplicationGraph(cp) + var initialStatement: JcInst? = null + + private fun initialStatement(): JcInst { + if (initialStatement == null) { + throw RuntimeException("initial statement not set") + } + return initialStatement!! + } + + private fun getBlockGraph() = initialStatement().location.method.flowGraph().blockGraph() + + override fun predecessors(node: JcBasicBlock): Sequence { + val jcBlockGraphImpl: JcBlockGraph = getBlockGraph() + return jcBlockGraphImpl.predecessors(node).asSequence() + } + + override fun successors(node: JcBasicBlock): Sequence { + val jcBlockGraphImpl: JcBlockGraph = getBlockGraph() + return jcBlockGraphImpl.successors(node).asSequence() + } + + override fun callees(node: JcBasicBlock): Sequence { + val jcBlockGraphImpl: JcBlockGraph = getBlockGraph() + + return jcBlockGraphImpl.instructions(node) + .map { jcApplicationGraph.callees(it) } + .reduce { acc, sequence -> acc + sequence } + .toSet() + .asSequence() + } + + override fun callers(method: JcMethod): Sequence { + return jcApplicationGraph + .callers(method) + .map { stmt -> blockOf(stmt) } + .toSet() + .asSequence() + } + + override fun entryPoints(method: JcMethod): Sequence = + method.flowGraph().blockGraph().entries.asSequence() + + override fun exitPoints(method: JcMethod): Sequence = + method.flowGraph().blockGraph().exits.asSequence() + + override fun methodOf(node: JcBasicBlock): JcMethod { + val firstInstruction = getBlockGraph().instructions(node).first() + return jcApplicationGraph.methodOf(firstInstruction) + } + + override fun instructions(block: JcBasicBlock): Sequence { + return getBlockGraph().instructions(block).asSequence() + } + + override fun statementsOf(method: JcMethod): Sequence { + return jcApplicationGraph + .statementsOf(method) + .map { stmt -> blockOf(stmt) } + .toSet() + .asSequence() + } + + override fun blockOf(stmt: JcInst): JcBasicBlock { + val jcBlockGraphImpl: JcBlockGraph = stmt.location.method.flowGraph().blockGraph() + val blocks = blocks() + for (block in blocks) { + if (stmt in jcBlockGraphImpl.instructions(block)) { + return block + } + } + throw IllegalStateException("block not found for $stmt in ${jcBlockGraphImpl.jcGraph.method}") + } + + override fun blocks(): Sequence { + return initialStatement().location.method.flowGraph().blockGraph().asSequence() + } + + private val typedMethodsCache = ConcurrentHashMap() + + val JcMethod.typed + get() = typedMethodsCache.getOrPut(this) { + enclosingClass.toType().declaredMethods.first { it.method == this } + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/interpreter/JcBlockInterpreter.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/interpreter/JcBlockInterpreter.kt new file mode 100644 index 0000000000..29da6006d8 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/interpreter/JcBlockInterpreter.kt @@ -0,0 +1,493 @@ +package org.usvm.jvm.interpreter + +import io.ksmt.utils.asExpr +import mu.KLogging +import org.jacodb.api.* +import org.jacodb.api.cfg.* +import org.jacodb.api.ext.boolean +import org.jacodb.api.ext.isEnum +import org.jacodb.api.ext.void +import org.usvm.* +import org.usvm.api.allocateStaticRef +import org.usvm.api.evalTypeEquals +import org.usvm.api.targets.JcTarget +import org.usvm.api.typeStreamOf +import org.usvm.jvm.JcApplicationBlockGraph +import org.usvm.machine.* +import org.usvm.machine.interpreter.JcExprResolver +import org.usvm.machine.interpreter.JcFixedInheritorsNumberTypeSelector +import org.usvm.machine.interpreter.JcStepScope +import org.usvm.machine.interpreter.JcTypeSelector +import org.usvm.machine.state.* +import org.usvm.memory.URegisterStackLValue +import org.usvm.solver.USatResult +import org.usvm.types.first +import org.usvm.util.findMethod +import org.usvm.util.write + +/** + * A JacoDB interpreter. + */ +class JcBlockInterpreter( + private val ctx: JcContext, + private val applicationGraph: JcApplicationBlockGraph, +) : UInterpreter() { + + companion object { + val logger = object : KLogging() {}.logger + } + + fun getInitialState(method: JcMethod, targets: List = emptyList()): JcState { + val state = JcState(ctx, targets = targets) + val typedMethod = with(applicationGraph) { method.typed } + + val entrypointArguments = mutableListOf>() + if (!method.isStatic) { + with(ctx) { + val thisLValue = URegisterStackLValue(addressSort, 0) + val ref = state.memory.read(thisLValue).asExpr(addressSort) + state.pathConstraints += mkEq(ref, nullRef).not() + val thisType = typedMethod.enclosingType + state.pathConstraints += mkIsSubtypeExpr(ref, thisType) + + entrypointArguments += thisType to ref + } + } + + typedMethod.parameters.forEachIndexed { idx, typedParameter -> + with(ctx) { + val type = typedParameter.type + if (type is JcRefType) { + val argumentLValue = URegisterStackLValue(typeToSort(type), method.localIdx(idx)) + val ref = state.memory.read(argumentLValue).asExpr(addressSort) + state.pathConstraints += mkIsSubtypeExpr(ref, type) + + entrypointArguments += type to ref + } + } + } + + val solver = ctx.solver() + + val model = (solver.checkWithSoftConstraints(state.pathConstraints) as USatResult).model + state.models = listOf(model) + + val entrypointInst = JcMethodEntrypointInst(method, entrypointArguments) + state.newStmt(entrypointInst) + return state + } + + override fun step(state: JcState): StepResult { + val stmt = state.lastStmt + + logger.debug("Step: {}", stmt) + + val scope = StepScope(state) + + // handle exception firstly + val result = state.methodResult + if (result is JcMethodResult.JcException) { + handleException(scope, result, stmt) + return scope.stepResult() + } + + when (stmt) { + is JcMethodCallBaseInst -> visitMethodCall(scope, stmt) + is JcAssignInst -> visitAssignInst(scope, stmt) + is JcIfInst -> visitIfStmt(scope, stmt) + is JcReturnInst -> visitReturnStmt(scope, stmt) + is JcGotoInst -> visitGotoStmt(scope, stmt) + is JcCatchInst -> visitCatchStmt(scope, stmt) + is JcSwitchInst -> visitSwitchStmt(scope, stmt) + is JcThrowInst -> visitThrowStmt(scope, stmt) + is JcCallInst -> visitCallStmt(scope, stmt) + is JcEnterMonitorInst -> visitMonitorEnterStmt(scope, stmt) + is JcExitMonitorInst -> visitMonitorExitStmt(scope, stmt) + else -> error("Unknown stmt: $stmt") + } + return scope.stepResult() + } + + private fun handleException( + scope: JcStepScope, + exception: JcMethodResult.JcException, + lastStmt: JcInst, + ) { + val catchStatements = applicationGraph.jcApplicationGraph.successors(lastStmt).filterIsInstance().toList() + + val typeConstraintsNegations = mutableListOf() + val catchForks = mutableListOf Unit>>() + + val blockToFork: (JcCatchInst) -> (JcState) -> Unit = { catchInst: JcCatchInst -> + block@{ state: JcState -> + val lValue = exprResolverWithScope(scope).resolveLValue(catchInst.throwable) ?: return@block + val exceptionResult = state.methodResult as JcMethodResult.JcException + + state.memory.write(lValue, exceptionResult.address) + + state.methodResult = JcMethodResult.NoCall + state.newStmt(catchInst.nextStmt) + } + } + + catchStatements.forEach { catchInst -> + val throwableTypes = catchInst.throwableTypes + + val typeConstraint = scope.calcOnState { + val currentTypeConstraints = throwableTypes.map { memory.types.evalIsSubtype(exception.address, it) } + val result = ctx.mkAnd(typeConstraintsNegations + ctx.mkOr(currentTypeConstraints)) + + typeConstraintsNegations += currentTypeConstraints.map { ctx.mkNot(it) } + + result + } + + catchForks += typeConstraint to blockToFork(catchInst) + } + + val typeConditionToMiss = ctx.mkAnd(typeConstraintsNegations) + val functionBlockOnMiss = block@{ _: JcState -> + scope.calcOnState { throwExceptionAndDropStackFrame() } + } + + val catchSectionMiss = typeConditionToMiss to functionBlockOnMiss + + scope.forkMulti(catchForks + catchSectionMiss) + } + + private val typeSelector = JcFixedInheritorsNumberTypeSelector() + + private fun visitMethodCall(scope: JcStepScope, stmt: JcMethodCallBaseInst) { + when (stmt) { + is JcMethodEntrypointInst -> { + scope.doWithState { + if (callStack.isEmpty()) { + val method = stmt.method + callStack.push(method, returnSite = null) + memory.stack.push(method.parametersWithThisCount, method.localsCount) + } + } + + val exprResolver = exprResolverWithScope(scope) + // Run static initializer for all enum arguments of the entrypoint + for ((type, ref) in stmt.entrypointArguments) { + exprResolver.ensureExprCorrectness(ref, type) ?: return + } + + val method = stmt.method + val entryPoint = applicationGraph.jcApplicationGraph.entryPoints(method).single() + scope.doWithState { + newStmt(entryPoint) + } + } + + is JcConcreteMethodCallInst -> { + if (approximateMethod(scope, stmt)) { + return + } + + if (stmt.method.isNative) { + mockNativeMethod(scope, stmt) + return + } + + scope.doWithState { + addNewMethodCall(applicationGraph.jcApplicationGraph, stmt) + } + } + + is JcVirtualMethodCallInst -> { + if (approximateMethod(scope, stmt)) { + return + } + + resolveVirtualInvoke(stmt, scope, typeSelector, forkOnRemainingTypes = false) + } + } + } + + private fun visitAssignInst(scope: JcStepScope, stmt: JcAssignInst) { + val exprResolver = exprResolverWithScope(scope) + val lvalue = exprResolver.resolveLValue(stmt.lhv) ?: return + val expr = exprResolver.resolveJcExpr(stmt.rhv, stmt.lhv.type) ?: return + + val nextStmt = stmt.nextStmt + scope.doWithState { + memory.write(lvalue, expr) + newStmt(nextStmt) + } + } + + private fun visitIfStmt(scope: JcStepScope, stmt: JcIfInst) { + val exprResolver = exprResolverWithScope(scope) + + val boolExpr = exprResolver + .resolveJcExpr(stmt.condition) + ?.asExpr(ctx.boolSort) + ?: return + + val instList = stmt.location.method.instList + val (posStmt, negStmt) = instList[stmt.trueBranch.index] to instList[stmt.falseBranch.index] + + scope.fork( + boolExpr, + blockOnTrueState = { newStmt(posStmt) }, + blockOnFalseState = { newStmt(negStmt) } + ) + } + + private fun visitReturnStmt(scope: JcStepScope, stmt: JcReturnInst) { + val exprResolver = exprResolverWithScope(scope) + val method = requireNotNull(scope.calcOnState { callStack.lastMethod() }) + val returnType = with(applicationGraph) { method.typed }.returnType + + val valueToReturn = stmt.returnValue + ?.let { exprResolver.resolveJcExpr(it, returnType) ?: return } + ?: ctx.mkVoidValue() + + scope.doWithState { + returnValue(valueToReturn) + } + } + + private fun visitGotoStmt(scope: JcStepScope, stmt: JcGotoInst) { + val nextStmt = stmt.location.method.instList[stmt.target.index] + scope.doWithState { newStmt(nextStmt) } + } + + @Suppress("UNUSED_PARAMETER") + private fun visitCatchStmt(scope: JcStepScope, stmt: JcCatchInst) { + error("The catch instruction must be unfolded during processing of the instructions led to it. Encountered inst: $stmt") + } + + private fun visitSwitchStmt(scope: JcStepScope, stmt: JcSwitchInst) { + val exprResolver = exprResolverWithScope(scope) + + val switchKey = stmt.key + // Note that the switch key can be an rvalue, for example, a simple int constant. + val instList = stmt.location.method.instList + + with(ctx) { + val caseStmtsWithConditions = stmt.branches.map { (caseValue, caseTargetStmt) -> + val nextStmt = instList[caseTargetStmt] + val jcEqExpr = JcEqExpr(cp.boolean, switchKey, caseValue) + val caseCondition = exprResolver.resolveJcExpr(jcEqExpr)?.asExpr(boolSort) ?: return + + caseCondition to { state: JcState -> state.newStmt(nextStmt) } + } + + // To make the default case possible, we need to ensure that all case labels are unsatisfiable + val defaultCaseWithCondition = mkAnd( + caseStmtsWithConditions.map { it.first.not() } + ) to { state: JcState -> state.newStmt(instList[stmt.default]) } + + scope.forkMulti(caseStmtsWithConditions + defaultCaseWithCondition) + } + } + + private fun visitThrowStmt(scope: JcStepScope, stmt: JcThrowInst) { + val resolver = exprResolverWithScope(scope) + val address = resolver.resolveJcExpr(stmt.throwable)?.asExpr(ctx.addressSort) ?: return + + scope.calcOnState { + throwExceptionWithoutStackFrameDrop(address, stmt.throwable.type) + } + } + + private fun visitCallStmt(scope: JcStepScope, stmt: JcCallInst) { + val exprResolver = exprResolverWithScope(scope) + exprResolver.resolveJcExpr(stmt.callExpr) ?: return + + scope.doWithState { + val nextStmt = stmt.nextStmt + newStmt(nextStmt) + } + } + + private fun visitMonitorEnterStmt(scope: JcStepScope, stmt: JcEnterMonitorInst) { + val exprResolver = exprResolverWithScope(scope) + exprResolver.resolveJcNotNullRefExpr(stmt.monitor, stmt.monitor.type) ?: return + + // Monitor enter makes sense only in multithreaded environment + + scope.doWithState { + newStmt(stmt.nextStmt) + } + } + + private fun visitMonitorExitStmt(scope: JcStepScope, stmt: JcExitMonitorInst) { + val exprResolver = exprResolverWithScope(scope) + exprResolver.resolveJcNotNullRefExpr(stmt.monitor, stmt.monitor.type) ?: return + + // Monitor exit makes sense only in multithreaded environment + + scope.doWithState { + newStmt(stmt.nextStmt) + } + } + + private fun exprResolverWithScope(scope: JcStepScope) = + JcExprResolver( + ctx, + scope, + ::mapLocalToIdxMapper, + ::typeInstanceAllocator, + ::stringConstantAllocator, + ::classInitializerAlwaysAnalysisRequiredForType + ) + + private val localVarToIdx = mutableMapOf>() // (method, localName) -> idx + + // TODO: now we need to explicitly evaluate indices of registers, because we don't have specific ULValues + private fun mapLocalToIdxMapper(method: JcMethod, local: JcLocal) = + when (local) { + is JcLocalVar -> localVarToIdx + .getOrPut(method) { mutableMapOf() } + .run { + getOrPut(local.name) { method.parametersWithThisCount + size } + } + + is JcThis -> 0 + is JcArgument -> method.localIdx(local.index) + else -> error("Unexpected local: $local") + } + + private val JcInst.nextStmt get() = location.method.instList[location.index + 1] + private operator fun JcInstList.get(instRef: JcInstRef): JcInst = this[instRef.index] + + private val stringConstantAllocatedRefs = mutableMapOf() + + // Equal string constants must have equal references + private fun stringConstantAllocator(value: String, state: JcState): UConcreteHeapRef = + stringConstantAllocatedRefs.getOrPut(value) { + // Allocate globally unique ref with a negative address + state.memory.allocateStaticRef() + } + + private val typeInstanceAllocatedRefs = mutableMapOf() + + private fun typeInstanceAllocator(type: JcType, state: JcState): UConcreteHeapRef { + val typeInfo = resolveTypeInfo(type) + return typeInstanceAllocatedRefs.getOrPut(typeInfo) { + // Allocate globally unique ref with a negative address + state.memory.allocateStaticRef() + } + } + + private fun classInitializerAlwaysAnalysisRequiredForType(type: JcRefType): Boolean { + // Always analyze a static initializer for enums + return type.jcClass.isEnum + } + + private fun resolveTypeInfo(type: JcType): JcTypeInfo = when (type) { + is JcClassType -> JcClassTypeInfo(type.jcClass) + is JcPrimitiveType -> JcPrimitiveTypeInfo(type) + is JcArrayType -> JcArrayTypeInfo(resolveTypeInfo(type.elementType)) + else -> error("Unexpected type: $type") + } + + private sealed interface JcTypeInfo + + private data class JcClassTypeInfo(val className: String) : JcTypeInfo { + // Don't use type.typeName here, because it contains generic parameters + constructor(cls: JcClassOrInterface) : this(cls.name) + } + + private data class JcPrimitiveTypeInfo(val type: JcPrimitiveType) : JcTypeInfo + + private data class JcArrayTypeInfo(val element: JcTypeInfo) : JcTypeInfo + + private fun resolveVirtualInvoke( + methodCall: JcVirtualMethodCallInst, + scope: JcStepScope, + typeSelector: JcTypeSelector, + forkOnRemainingTypes: Boolean, + ): Unit = with(methodCall) { + val instance = arguments.first().asExpr(ctx.addressSort) + val concreteRef = scope.calcOnState { models.first().eval(instance) } as UConcreteHeapRef + + if (isAllocatedConcreteHeapRef(concreteRef) || isStaticHeapRef(concreteRef)) { + // We have only one type for allocated and static heap refs + val type = scope.calcOnState { memory.typeStreamOf(concreteRef) }.first() + + val concreteMethod = type.findMethod(method) + ?: error("Can't find method $method in type ${type.typeName}") + + scope.doWithState { + val concreteCall = methodCall.toConcreteMethodCall(concreteMethod.method) + newStmt(concreteCall) + } + + return@with + } + + val typeStream = scope.calcOnState { models.first().typeStreamOf(concreteRef) } + + val inheritors = typeSelector.choose(method, typeStream) + val typeConstraints = inheritors.map { type -> + scope.calcOnState { + memory.types.evalTypeEquals(instance, type) + } + } + + val typeConstraintsWithBlockOnStates = mutableListOf Unit>>() + + inheritors.mapIndexedTo(typeConstraintsWithBlockOnStates) { idx, type -> + val isExpr = typeConstraints[idx] + + val block = { state: JcState -> + val concreteMethod = type.findMethod(method) + ?: error("Can't find method $method in type ${type.typeName}") + + val concreteCall = methodCall.toConcreteMethodCall(concreteMethod.method) + state.newStmt(concreteCall) + } + + isExpr to block + } + + if (forkOnRemainingTypes) { + val excludeAllTypesConstraint = ctx.mkAnd(typeConstraints.map { ctx.mkNot(it) }) + typeConstraintsWithBlockOnStates += excludeAllTypesConstraint to { } // do nothing, just exclude types + } + + scope.forkMulti(typeConstraintsWithBlockOnStates) + } + + private val approximationResolver = JcMethodApproximationResolver(ctx, applicationGraph.jcApplicationGraph) + + private fun approximateMethod(scope: JcStepScope, methodCall: JcMethodCall): Boolean { + val exprResolver = exprResolverWithScope(scope) + return approximationResolver.approximate(scope, exprResolver, methodCall) + } + + private fun mockNativeMethod( + scope: JcStepScope, + methodCall: JcConcreteMethodCallInst + ) = with(methodCall) { + logger.warn { "Mocked: ${method.enclosingClass.name}::${method.name}" } + + val returnType = with(applicationGraph) { method.typed }.returnType + + if (returnType == ctx.cp.void) { + scope.doWithState { skipMethodInvocationWithValue(methodCall, ctx.voidValue) } + return@with + } + + val mockSort = ctx.typeToSort(returnType) + val mockValue = scope.calcOnState { + memory.mock { call(method, arguments.asSequence(), mockSort) } + } + + if (mockSort == ctx.addressSort) { + val constraint = scope.calcOnState { + memory.types.evalIsSubtype(mockValue.asExpr(ctx.addressSort), returnType) + } + scope.assert(constraint) ?: return + } + + scope.doWithState { + skipMethodInvocationWithValue(methodCall, mockValue) + } + } +} \ No newline at end of file diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/machine/MLJcMachine.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/machine/MLJcMachine.kt new file mode 100644 index 0000000000..21dd757c34 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/jvm/machine/MLJcMachine.kt @@ -0,0 +1,110 @@ +package org.usvm.jvm.machine + +import mu.KLogging +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcMethod +import org.jacodb.api.cfg.JcInst +import org.jacodb.api.ext.methods +import org.usvm.* +import org.usvm.api.targets.JcTarget +import org.usvm.jvm.JcApplicationBlockGraph +import org.usvm.jvm.interpreter.JcBlockInterpreter +import org.usvm.machine.JcComponents +import org.usvm.machine.JcContext +import org.usvm.machine.JcTypeSystem +import org.usvm.machine.state.JcMethodResult +import org.usvm.machine.state.JcState +import org.usvm.machine.state.lastStmt +import org.usvm.statistics.* +import org.usvm.statistics.collectors.CoveredNewStatesCollector +import org.usvm.statistics.collectors.TargetsReachedStatesCollector +import org.usvm.stopstrategies.createStopStrategy + +val logger = object : KLogging() {}.logger + +class MLJcMachine( + cp: JcClasspath, + private val options: MLMachineOptions +) : UMachine() { + private val applicationGraph = JcApplicationBlockGraph(cp) + + private val typeSystem = JcTypeSystem(cp) + private val components = JcComponents(typeSystem, options.basicOptions.solverType) + private val ctx = JcContext(cp, components) + + private val interpreter = JcBlockInterpreter(ctx, applicationGraph) + fun analyze(method: JcMethod, targets: List = emptyList()): List { + logger.debug("{}.analyze({}, {})", this, method, targets) + val initialState = interpreter.getInitialState(method, targets) + applicationGraph.initialStatement = initialState.currentStatement + + val methodsToTrackCoverage = + when (options.basicOptions.coverageZone) { + CoverageZone.METHOD -> setOf(method) + CoverageZone.TRANSITIVE -> setOf(method) + // TODO: more adequate method filtering. !it.isConstructor is used to exclude default constructor which is often not covered + CoverageZone.CLASS -> method.enclosingClass.methods.filter { + it.enclosingClass == method.enclosingClass && !it.isConstructor + }.toSet() + } + + val coverageStatistics: CoverageStatistics = CoverageStatistics( + methodsToTrackCoverage, + applicationGraph.jcApplicationGraph + ) + + val stateVisitsStatistics: StateVisitsStatistics = StateVisitsStatistics() + + val pathSelector = + createPathSelector(initialState, options, applicationGraph, stateVisitsStatistics, coverageStatistics) + + val statesCollector = + when (options.basicOptions.stateCollectionStrategy) { + StateCollectionStrategy.COVERED_NEW -> CoveredNewStatesCollector(coverageStatistics) { + it.methodResult is JcMethodResult.JcException + } + + StateCollectionStrategy.REACHED_TARGET -> TargetsReachedStatesCollector() + } + + val stopStrategy = createStopStrategy( + options.basicOptions, + targets, + coverageStatistics = { coverageStatistics }, + getCollectedStatesCount = { statesCollector.collectedStates.size } + ) + + val observers = mutableListOf>(coverageStatistics) + observers.add(TerminatedStateRemover()) + + if (options.basicOptions.coverageZone == CoverageZone.TRANSITIVE) { + observers.add( + TransitiveCoverageZoneObserver( + initialMethod = method, + methodExtractor = { state -> state.lastStmt.location.method }, + addCoverageZone = { coverageStatistics.addCoverageZone(it) }, + ignoreMethod = { false } // TODO replace with a configurable setting + ) + ) + } + observers.add(statesCollector) + + run( + interpreter, + pathSelector, + observer = CompositeUMachineObserver(observers), + isStateTerminated = ::isStateTerminated, + stopStrategy = stopStrategy, + ) + + return statesCollector.collectedStates + } + + private fun isStateTerminated(state: JcState): Boolean { + return state.callStack.isEmpty() + } + + override fun close() { + components.close() + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt new file mode 100644 index 0000000000..ae94deeced --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/GNNPathSelector.kt @@ -0,0 +1,343 @@ +package org.usvm.ps + +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtSession +import org.usvm.ApplicationBlockGraph +import org.usvm.StateId +import org.usvm.UPathSelector +import org.usvm.UState +import org.usvm.statistics.CoverageStatistics +import org.usvm.statistics.StateVisitsStatistics +import java.nio.FloatBuffer +import java.nio.LongBuffer + +data class GameEdgeLabel( + val token: Int +) + +data class GameMapEdge( + val vertexFrom: Int, + val vertexTo: Int, + val label: GameEdgeLabel = GameEdgeLabel(0), +) + +data class BlockFeatures( + val uid: Int = 0, + val id: Int, + val basicBlockSize: Int, + val inCoverageZone: Boolean, + val coveredByTest: Boolean, + val visitedByState: Boolean, + val touchedByState: Boolean, + val states: List, +) + +data class StateHistoryElem( + val graphVertexId: Int, + val numOfVisits: Int, +) + +data class StateFeatures( + val id: StateId, + val position: Int = 0, + val predictedUsefulness: Int = 42, + val pathConditionSize: Int, + val visitedAgainVertices: Int, + val visitedNotCoveredVerticesInZone: Int, + val visitedNotCoveredVerticesOutOfZone: Int, + val history: List, + val children: List, +) + +data class GraphNative( + val gameVertex: List>, + val stateVertex: List>, + val gameVertexToGameVertex: List>, + val gameVertexHistoryStateVertexIndex: List>, + val gameVertexHistoryStateVertexAttrs: List, + val gameVertexInStateVertex: List>, + val stateVertexParentOfStateVertex: List>, + val stateMap: Map +) + +class GNNPathSelector>( + private val coverageStatistics: CoverageStatistics, + private val stateVisitsStatistics: StateVisitsStatistics, + private val applicationBlockGraph: ApplicationBlockGraph, + heteroGNNModelPath: String, +) : UPathSelector { + private val states: MutableList = mutableListOf() + + private var env: OrtEnvironment = OrtEnvironment.getEnvironment() + private var gnnSession: OrtSession = env.createSession(heteroGNNModelPath) + + override fun isEmpty() = states.isEmpty() + override fun peek(): State { + val nativeInput = createNativeInput() + if (nativeInput.gameVertexInStateVertex.isEmpty()) { + return states.first() + } + + val gameVertexTensor = onnxFloatTensor(nativeInput.gameVertex) + val stateVertexTensor = onnxFloatTensor(nativeInput.stateVertex) + val gameVertexToGameVertexTensor = onnxLongTensor(nativeInput.gameVertexToGameVertex.transpose()) + val gameVertexHistoryStateVertexIndexTensor = + onnxLongTensor(nativeInput.gameVertexHistoryStateVertexIndex.transpose()) + val gameVertexHistoryStateVertexAttrsTensor = + onnxLongTensor(nativeInput.gameVertexHistoryStateVertexAttrs.map { listOf(it) }) // this rn can be of shape [0, 0], we need to fix this + + val gameVertexInStateVertexTensor = onnxLongTensor(nativeInput.gameVertexInStateVertex.transpose()) + val stateVertexParentOfStateVertexTensor = + onnxLongTensor(nativeInput.stateVertexParentOfStateVertex.transpose()) + + val res = gnnSession.run( + mapOf( + "game_vertex" to gameVertexTensor, + "state_vertex" to stateVertexTensor, + "game_vertex to game_vertex" to gameVertexToGameVertexTensor, + "game_vertex history state_vertex index" to gameVertexHistoryStateVertexIndexTensor, + "game_vertex history state_vertex attrs" to gameVertexHistoryStateVertexAttrsTensor, + "game_vertex in state_vertex" to gameVertexInStateVertexTensor, + "state_vertex parent_of state_vertex" to stateVertexParentOfStateVertexTensor + ) + ) + + val predictedStatesRanks = + (res["out"].get().value as Array<*>).map { it as FloatArray }.map { it.toList() }.toList() + val chosenStateId = predictState(predictedStatesRanks, nativeInput.stateMap) + + return states.single { state -> state.id.toInt() == chosenStateId } + } + + override fun update(state: State) {} + + override fun add(states: Collection) { + this.states += states + } + + override fun remove(state: State) { + states.remove(state) + } + + private fun createNativeInput(): GraphNative { + val nodesState = mutableListOf>() + val nodesVertex = mutableListOf>() + val edgesIndexVSHistory = mutableListOf>() + val edgesAttrVS = mutableListOf() + val edgesIndexSS = mutableListOf>() + val edgesIndexVSIn = mutableListOf>() + + val stateMap = mutableMapOf() + val vertexMap = mutableMapOf() + + val blockIdGenerator = IDGenerator() + val rawBlocks = applicationBlockGraph.blocks() + + val blocksFeatures = mutableListOf() + + rawBlocks.forEach { block -> + val blockId = blockIdGenerator.issue() + vertexMap[block] = blockId + blocksFeatures.add( + createBlockFeatures( + id = blockId, + block = block + ) + ) + } + + val statesFeatures = + states.map { state -> getStateFeatures(state, rawBlocks) { vertexMap[it]!! } } + for ((stateIndexOrder, stateFeatures) in statesFeatures.withIndex()) { + stateMap[stateFeatures.id.toInt()] = stateIndexOrder + nodesState.add(stateFeatures.toList()) + } + + blocksFeatures.forEach { blockFeatures -> + nodesVertex.add(blockFeatures.toList()) + } + + val edgesIndexVV = applicationBlockGraph + .edges(rawBlocks) { bb -> vertexMap[bb]!! } + .map { edge -> listOf(edge.vertexFrom, edge.vertexTo) } + .toList() + + for ((stateIndexOrder, stateFeatures) in statesFeatures.withIndex()) { + for (historyEdge in stateFeatures.history) { + val vertexTo = historyEdge.graphVertexId + edgesIndexVSHistory.add(listOf(vertexTo, stateIndexOrder)) + edgesAttrVS.add(historyEdge.numOfVisits) + } + } + + for (stateFeatures in statesFeatures) { + for (childId in stateFeatures.children) { + if (childId.toInt() in stateMap.keys) + edgesIndexSS.add(listOf(stateMap[stateFeatures.id.toInt()]!!, stateMap[childId.toInt()]!!)) + } + } + + for (vertex in blocksFeatures) { + for (state in vertex.states) { + edgesIndexVSIn.add(listOf(vertex.id, stateMap[state.toInt()]!!)) + } + } + + return GraphNative( + gameVertex = nodesVertex, + stateVertex = nodesState, + gameVertexToGameVertex = edgesIndexVV, + gameVertexHistoryStateVertexIndex = edgesIndexVSHistory, + gameVertexHistoryStateVertexAttrs = edgesAttrVS, + gameVertexInStateVertex = edgesIndexVSIn, + stateVertexParentOfStateVertex = edgesIndexSS, + stateMap = stateMap + ) + } + + private fun get2DShape(data: List>): Pair { + if (data.isEmpty()) { + return Pair(0, 0) + } + if (data[0].isEmpty()) { + return Pair(data.size, 0) + } + return Pair(data.size, data[0].size) + } + + + private fun List.create2DFloatBuffer(shape: Pair): OnnxTensor { + val longArrayOfShape = longArrayOf(shape.first.toLong(), shape.second.toLong()) + return OnnxTensor.createTensor( + env, + FloatBuffer.wrap(this.toFloatArray()), + longArrayOfShape + ) + } + + private fun List.create2DLongBuffer(shape: Pair): OnnxTensor { + val longArrayOfShape = longArrayOf(shape.first.toLong(), shape.second.toLong()) + return OnnxTensor.createTensor( + env, + LongBuffer.wrap(this.toLongArray()), + longArrayOfShape + ) + } + + private fun onnxFloatTensor(data: List>): OnnxTensor { + return data.flatten().map { it.toFloat() }.create2DFloatBuffer(get2DShape(data)) + } + + private fun onnxLongTensor(data: List>): OnnxTensor { + return data.flatten().map { it.toLong() }.create2DLongBuffer(get2DShape(data)) + } + + private fun BasicBlock.inCoverageZone() = coverageStatistics.inCoverageZone(applicationBlockGraph.methodOf(this)) + + private fun BasicBlock.isVisited() = stateVisitsStatistics.isVisited(this.instructions().last()) + + private fun BasicBlock.isTouched() = + this.isVisited() || stateVisitsStatistics.isVisited(this.instructions().first()) + + private fun BasicBlock.isCovered() = coverageStatistics.isCovered(this.instructions().first()) + + private fun createBlockFeatures( + id: Int, + block: BasicBlock + ): BlockFeatures { + return BlockFeatures( + id = id, + inCoverageZone = block.inCoverageZone(), + basicBlockSize = block.instructions().count(), + coveredByTest = block.isCovered(), + visitedByState = block.isVisited(), + touchedByState = block.isTouched(), + states = states + .filter { it.currentStatement in applicationBlockGraph.instructions(block) } + .map { it.id } + ) + } + + private fun BasicBlock.instructions(): Sequence { + return applicationBlockGraph.instructions(this) + } + + private fun getStateFeatures( + state: State, + blocksSource: Sequence, + mapper: (BasicBlock) -> Int + ): StateFeatures { + val blockHistory = state.reversedPathFrom(blocksSource).toList().reversed() + + val visitedNotCoveredVerticesInZone = blockHistory.count { it.isVisited() && it.inCoverageZone() } + val visitedNotCoveredVerticesOutOfZone = blockHistory.count { it.isVisited() && !it.inCoverageZone() } + + return StateFeatures( + id = state.id, + pathConditionSize = state.pathConstraints.size(), + visitedAgainVertices = state.reversedPath.asSequence().count() - state.reversedPath.asSequence().distinct() + .count(), + visitedNotCoveredVerticesInZone = visitedNotCoveredVerticesInZone, + visitedNotCoveredVerticesOutOfZone = visitedNotCoveredVerticesOutOfZone, + history = blockHistory.map { block -> + StateHistoryElem( + mapper(block), + blockHistory.count { other -> mapper(block) == mapper(other) }) + }, + children = state.pathLocation.accumulatedForks.map { it.id } + ) + } + + private fun UState<*, *, Statement, *, *, State>.reversedPathFrom(sourceBlocks: Sequence): Sequence { + val blocks = mutableSetOf() + for (instruction in this.reversedPath) { + if (instruction in sourceBlocks.map { block -> applicationBlockGraph.instructions(block) }.flatten()) + blocks.add(applicationBlockGraph.blockOf(instruction)) + // else: external call + } + return blocks.asSequence() + } + + fun ApplicationBlockGraph.edges( + rawBlocks: Sequence, + mapper: (BasicBlock) -> Int + ): Sequence { + return rawBlocks.flatMap { basicBlock -> + this.successors(basicBlock).map { otherBasicBlock -> + GameMapEdge(mapper(basicBlock), mapper(otherBasicBlock)) + } + } + } +} + +fun predictState(stateRank: List>, stateMap: Map): Int { + val reverseStateMap = stateMap.entries.associate { (k, v) -> v to k } + + val stateRankMapping = stateRank.mapIndexed { orderIndex, rank -> + reverseStateMap[orderIndex]!! to rank + } + + return stateRankMapping.maxBy { it.second.sum() }.first +} + +fun BlockFeatures.toList(): List { + return listOf( + this.inCoverageZone.toInt(), + this.basicBlockSize, + this.coveredByTest.toInt(), + this.visitedByState.toInt(), + this.touchedByState.toInt() + ) +} + +fun StateFeatures.toList(): List { + return listOf( + this.position, + this.predictedUsefulness, + this.pathConditionSize, + this.visitedAgainVertices, + this.visitedNotCoveredVerticesInZone, + this.visitedNotCoveredVerticesOutOfZone + ) +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/Utils.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/Utils.kt new file mode 100644 index 0000000000..d130748e59 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/Utils.kt @@ -0,0 +1,61 @@ +package org.usvm.ps + +import ai.onnxruntime.OnnxTensor +import kotlinx.coroutines.runBlocking +import org.usvm.constraints.UPathConstraints + + +fun get2DShape(data: List>): Pair { + if (data.isEmpty()) { + return Pair(0, 0) + } + if (data[0].isEmpty()) { + return Pair(data.size, 0) + } + return Pair(data.size, data[0].size) +} + +fun List>.tensorNullIfEmpty(toTensor: (List>) -> OnnxTensor): OnnxTensor { + if (this.isEmpty()) { + return toTensor(listOf(listOf(), listOf())) + } + + return toTensor(this) +} + +fun List>.transpose(): List> { + if (this.isEmpty()) { + return listOf(listOf(), listOf()) + } + + val (rows, cols) = get2DShape(this) + return List(cols) { j -> + List(rows) { i -> + this[i][j] + } + } +} + +fun Boolean.toInt(): Int = if (this) 1 else 0 + +class IDGenerator { + @Volatile + var current: Int = 0 + + fun issue(): Int { + val result: Int + runBlocking { result = current++ } + return result + } +} + +fun UPathConstraints.size(): Int { + return numericConstraints.constraints().count() + + this.equalityConstraints.distinctReferences.count() + + this.equalityConstraints.equalReferences.count() + + this.equalityConstraints.referenceDisequalities.count() + + this.equalityConstraints.nullableDisequalities.count() + + this.logicalConstraints.count() + + this.typeConstraints.symbolicRefToTypeRegion.count() // TODO: maybe throw out? +} + diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/StateVisitsStatistics.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/StateVisitsStatistics.kt new file mode 100644 index 0000000000..113b16225b --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/StateVisitsStatistics.kt @@ -0,0 +1,14 @@ +package org.usvm.statistics + +import org.usvm.UState + +class StateVisitsStatistics> : + UMachineObserver { + private val visitedStatements = HashSet() + + fun isVisited(statement: Statement) = visitedStatements.contains(statement) + + override fun onState(parent: State, forks: Sequence) { + visitedStatements.add(parent.currentStatement) + } +} diff --git a/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JacoDBContainer.kt b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JacoDBContainer.kt new file mode 100644 index 0000000000..9b41dd101c --- /dev/null +++ b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JacoDBContainer.kt @@ -0,0 +1,53 @@ +package org.usvm.jvm + +import kotlinx.coroutines.runBlocking +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcDatabase +import org.jacodb.approximation.Approximations +import org.jacodb.impl.JcSettings +import org.jacodb.impl.features.InMemoryHierarchy +import org.jacodb.impl.jacodb +import org.usvm.org.usvm.jvm.util.allClasspath +import java.io.File + +class JacoDBContainer( + classpath: List, + builder: JcSettings.() -> Unit, +) { + val db: JcDatabase + val cp: JcClasspath + + init { + val (db, cp) = runBlocking { + val db = jacodb(builder) + val cp = db.classpath(classpath) + db to cp + } + this.db = db + this.cp = cp + runBlocking { + db.awaitBackgroundJobs() + } + } + + companion object { + private val keyToJacoDBContainer = HashMap() + + operator fun invoke( + key: Any?, + classpath: List = samplesClasspath, + builder: JcSettings.() -> Unit = defaultBuilder, + ): JacoDBContainer = + keyToJacoDBContainer.getOrPut(key) { JacoDBContainer(classpath, builder) } + + private val samplesClasspath = allClasspath.filter { + it.name.contains("samples") || it.name.contains("tests") + } + + private val defaultBuilder: JcSettings.() -> Unit = { + useProcessJavaRuntime() + installFeatures(InMemoryHierarchy, Approximations) + loadByteCode(samplesClasspath) + } + } +} diff --git a/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JavaMethodTestRunner.kt b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JavaMethodTestRunner.kt new file mode 100644 index 0000000000..4f8e74f31e --- /dev/null +++ b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/JavaMethodTestRunner.kt @@ -0,0 +1,95 @@ +package org.usvm.jvm + +import org.jacodb.api.JcClassOrInterface +import org.jacodb.api.JcMethod +import org.jacodb.api.ext.packageName +import org.usvm.CoverageZone +import org.usvm.MLMachineOptions +import org.usvm.MLPathSelectionStrategy +import org.usvm.UMachineOptions +import org.usvm.jvm.machine.MLJcMachine +import kotlin.io.path.Path +import kotlin.system.measureTimeMillis + +fun jarLoad(jars: Set, classes: MutableMap>) { + jars.forEach { filePath -> + val file = Path(filePath).toFile() + val container = JacoDBContainer(key = filePath, classpath = listOf(file)) + val classNames = container.db.locations.flatMap { it.jcLocation?.classNames ?: listOf() } + classes[filePath] = mutableListOf() + classNames.forEach { className -> + container.cp.findClassOrNull(className)?.let { + classes[filePath]?.add(it) + } + } + } +} + +fun getMethodFullName(method: Any?): String { + return if (method is JcMethod) { + "${method.enclosingClass.name}#${method.name}(${method.parameters.joinToString { it.type.typeName }})" + } else { + method.toString() + } +} + +val baseOptions = UMachineOptions( + coverageZone = CoverageZone.TRANSITIVE, + exceptionsPropagation = true, + timeoutMs = 60_000, + stepsFromLastCovered = 3500L, +) + +open class JavaMethodTestRunner { + private val options: MLMachineOptions = MLMachineOptions( + baseOptions, + MLPathSelectionStrategy.GNN, + heteroGNNModelPath = "Game_env/test_model.onnx" + ) + + fun runner(method: JcMethod, jarKey: String) { + MLJcMachine(JacoDBContainer(jarKey).cp, options).use { machine -> + machine.analyze(method, emptyList()) + } + } +} + + +fun main() { + val inputJars = mapOf( + Pair("Game_env/guava-28.2-jre.jar", listOf("com.google.common")) + ) + + val jarClasses = mutableMapOf>() + jarLoad(inputJars.keys, jarClasses) + println("\nLOADING COMPLETE\n") + + jarClasses.forEach { (key, classesList) -> + println("RUNNING TESTS FOR $key") + val allMethods = classesList.filter { cls -> + !cls.isAnnotation && !cls.isInterface && + inputJars.getValue(key).any { cls.packageName.contains(it) } && + !cls.name.contains("Test") + }.flatMap { cls -> + cls.declaredMethods.filter { method -> + method.enclosingClass == cls && !method.isConstructor + } + }.sortedBy { getMethodFullName(it).hashCode() }.distinctBy { getMethodFullName(it) } + + println(" RUNNING TESTS WITH ${baseOptions.timeoutMs}ms TIME LIMIT") + println(" RUNNING TESTS WITH ${MLPathSelectionStrategy.GNN} PATH SELECTOR") + + val testRunner = JavaMethodTestRunner() + for (method in allMethods.shuffled()) { + try { + println(" Running test ${method.name}") + val time = measureTimeMillis { + testRunner.runner(method, key) + } + println(" Test ${method.name} finished after ${time}ms") + } catch (e: NotImplementedError) { + println(" $e, ${e.message}") + } + } + } +} diff --git a/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/util/Util.kt b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/util/Util.kt new file mode 100644 index 0000000000..efa7adc826 --- /dev/null +++ b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jvm/util/Util.kt @@ -0,0 +1,15 @@ +package org.usvm.org.usvm.jvm.util + +import java.io.File + +val allClasspath: List + get() { + return classpath.map { File(it) } + } + +private val classpath: List + get() { + val classpath = System.getProperty("java.class.path") + return classpath.split(File.pathSeparatorChar) + .toList() + }