Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common ml base #74

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ buildSrc/.gradle

# Ignore Idea directory
.idea

# Ignore MacOS-specific files
/**/.DS_Store
1 change: 1 addition & 0 deletions buildSrc/src/main/kotlin/Versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ object Versions {
const val mockk = "1.13.4"
const val junitParams = "5.9.3"
const val logback = "1.4.8"
const val onnxruntime = "1.15.1"

// versions for jvm samples
const val samplesLombok = "1.18.20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {

implementation(kotlin("stdlib-jdk8"))
implementation(kotlin("reflect"))
implementation("com.microsoft.onnxruntime", "onnxruntime", Versions.onnxruntime)
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}")

testImplementation(kotlin("test"))
Expand Down
1 change: 1 addition & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ include("usvm-core")
include("usvm-jvm")
include("usvm-util")
include("usvm-sample-language")
include("usvm-ml-path-selection")
11 changes: 11 additions & 0 deletions usvm-core/src/main/kotlin/org/usvm/PathTrieNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ sealed class PathsTrieNode<State : UState<*, *, Statement, *, *, State>, Stateme
*/
abstract val depth: Int

/**
* States that were forked from this node
*/
abstract val accumulatedForks: MutableCollection<State>

/**
* Adds a new label to [labels] collection.
*/
Expand Down Expand Up @@ -80,6 +85,7 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
statement = statement
) {
parentNode.children[statement] = this
parentNode.accumulatedForks.addAll(this.states)
}

internal constructor(parentNode: PathsTrieNodeImpl<State, Statement>, statement: Statement, state: State) : this(
Expand All @@ -89,11 +95,14 @@ class PathsTrieNodeImpl<State : UState<*, *, Statement, *, *, State>, Statement>
statement = statement
) {
parentNode.children[statement] = this
parentNode.accumulatedForks.addAll(this.states)
parentNode.states -= state
}

override val labels: MutableSet<Any> = hashSetOf()

override val accumulatedForks: MutableCollection<State> = mutableSetOf()

override fun addLabel(label: Any) {
labels.add(label)
}
Expand All @@ -115,6 +124,8 @@ class RootNode<State : UState<*, *, Statement, *, *, State>, Statement> : PathsT

override val labels: MutableSet<Any> = hashSetOf()

override val accumulatedForks: MutableCollection<State> = mutableSetOf()

override val depth: Int = 0

override fun addLabel(label: Any) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class CoverageStatistics<Method, Statement, State : UState<*, Method, Statement,
return uncoveredStatements.values.flatten()
}

fun isCovered(statement: Statement) = statement in coveredStatements.values.flatten()

fun inCoverageZone(method: Method) = coveredStatements.containsKey(method) || uncoveredStatements.containsKey(method)

/**
* Adds a listener triggered when a new statement is covered.
*/
Expand Down
11 changes: 11 additions & 0 deletions usvm-ml-path-selection/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
plugins {
id("usvm.kotlin-conventions")
}

dependencies {
implementation(project(":usvm-core"))
implementation(project(":usvm-jvm"))

implementation("org.jacodb:jacodb-analysis:${Versions.jcdb}")
implementation("org.jacodb:jacodb-approximations:${Versions.jcdb}")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.usvm

import org.usvm.statistics.ApplicationGraph

interface ApplicationBlockGraph<Method, BasicBlock, Statement> : ApplicationGraph<Method, BasicBlock> {
fun blockOf(stmt: Statement): BasicBlock
fun instructions(block: BasicBlock): Sequence<Statement>
fun blocks(): Sequence<BasicBlock>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package org.usvm

enum class MLPathSelectionStrategy {
/**
* Collects features according to states selected by any other path selector.
*/
FEATURES_LOGGING,

/**
* Collects features and feeds them to the ML model to select states.
* Extends FEATURE_LOGGING path selector.
*/
MACHINE_LEARNING,

/**
* Selects states with best Graph Neural Network state score
*/
GNN,
}

data class MLMachineOptions(
val basicOptions: UMachineOptions,
val pathSelectionStrategy: MLPathSelectionStrategy,
val heteroGNNModelPath: String = ""
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.usvm

import org.usvm.ps.ExceptionPropagationPathSelector
import org.usvm.ps.GNNPathSelector
import org.usvm.statistics.CoverageStatistics
import org.usvm.statistics.StateVisitsStatistics

fun <Method, Statement, BasicBlock, State : UState<*, Method, Statement, *, *, State>> createPathSelector(
initialState: State,
options: MLMachineOptions,
applicationGraph: ApplicationBlockGraph<Method, BasicBlock, Statement>,
stateVisitsStatistics: StateVisitsStatistics<Method, Statement, State>,
coverageStatistics: CoverageStatistics<Method, Statement, State>,
): UPathSelector<State> {
val selector = when (options.pathSelectionStrategy) {
MLPathSelectionStrategy.GNN -> createGNNPathSelector(
stateVisitsStatistics,
coverageStatistics, applicationGraph, options.heteroGNNModelPath
)

else -> {
throw NotImplementedError()
}
}

val propagateExceptions = options.basicOptions.exceptionsPropagation

val resultSelector = selector.wrapIfRequired(propagateExceptions)
resultSelector.add(listOf(initialState))

return selector
}

private fun <State : UState<*, *, *, *, *, State>> UPathSelector<State>.wrapIfRequired(propagateExceptions: Boolean) =
if (propagateExceptions && this !is ExceptionPropagationPathSelector<State>) {
ExceptionPropagationPathSelector(this)
} else {
this
}

private fun <Method, Statement, BasicBlock, State : UState<*, Method, Statement, *, *, State>> createGNNPathSelector(
stateVisitsStatistics: StateVisitsStatistics<Method, Statement, State>,
coverageStatistics: CoverageStatistics<Method, Statement, State>,
applicationGraph: ApplicationBlockGraph<Method, BasicBlock, Statement>,
heteroGNNModelPath: String,
): UPathSelector<State> {
return GNNPathSelector(
coverageStatistics,
stateVisitsStatistics,
applicationGraph,
heteroGNNModelPath
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package org.usvm.jvm

import org.jacodb.api.JcClasspath
import org.jacodb.api.JcMethod
import org.jacodb.api.JcTypedMethod
import org.jacodb.api.cfg.JcBasicBlock
import org.jacodb.api.cfg.JcBlockGraph
import org.jacodb.api.cfg.JcInst
import org.jacodb.api.ext.toType
import org.usvm.ApplicationBlockGraph
import org.usvm.machine.JcApplicationGraph
import java.util.concurrent.ConcurrentHashMap

class JcApplicationBlockGraph(cp: JcClasspath) :
ApplicationBlockGraph<JcMethod, JcBasicBlock, JcInst> {
val jcApplicationGraph: JcApplicationGraph = JcApplicationGraph(cp)
var initialStatement: JcInst? = null

private fun initialStatement(): JcInst {
if (initialStatement == null) {
throw RuntimeException("initial statement not set")
}
return initialStatement!!
}

private fun getBlockGraph() = initialStatement().location.method.flowGraph().blockGraph()

override fun predecessors(node: JcBasicBlock): Sequence<JcBasicBlock> {
val jcBlockGraphImpl: JcBlockGraph = getBlockGraph()
return jcBlockGraphImpl.predecessors(node).asSequence()
}

override fun successors(node: JcBasicBlock): Sequence<JcBasicBlock> {
val jcBlockGraphImpl: JcBlockGraph = getBlockGraph()
return jcBlockGraphImpl.successors(node).asSequence() + jcBlockGraphImpl.throwers(node).asSequence()
}

override fun callees(node: JcBasicBlock): Sequence<JcMethod> {
val jcBlockGraphImpl: JcBlockGraph = getBlockGraph()

return jcBlockGraphImpl.instructions(node)
.map { jcApplicationGraph.callees(it) }
.reduce { acc, sequence -> acc + sequence }
.toSet()
.asSequence()
}

override fun callers(method: JcMethod): Sequence<JcBasicBlock> {
return jcApplicationGraph
.callers(method)
.map { stmt -> blockOf(stmt) }
.toSet()
.asSequence()
}

override fun entryPoints(method: JcMethod): Sequence<JcBasicBlock> =
method.flowGraph().blockGraph().entries.asSequence()

override fun exitPoints(method: JcMethod): Sequence<JcBasicBlock> =
method.flowGraph().blockGraph().exits.asSequence()

override fun methodOf(node: JcBasicBlock): JcMethod {
val firstInstruction = getBlockGraph().instructions(node).first()
return jcApplicationGraph.methodOf(firstInstruction)
}

override fun instructions(block: JcBasicBlock): Sequence<JcInst> {
return getBlockGraph().instructions(block).asSequence()
}

override fun statementsOf(method: JcMethod): Sequence<JcBasicBlock> {
return jcApplicationGraph
.statementsOf(method)
.map { stmt -> blockOf(stmt) }
.toSet()
.asSequence()
}

override fun blockOf(stmt: JcInst): JcBasicBlock {
val jcBlockGraphImpl: JcBlockGraph = stmt.location.method.flowGraph().blockGraph()
val blocks = blocks()
for (block in blocks) {
if (stmt in jcBlockGraphImpl.instructions(block)) {
return block
}
}
throw IllegalStateException("block not found for $stmt in ${jcBlockGraphImpl.jcGraph.method}")
}

override fun blocks(): Sequence<JcBasicBlock> {
return initialStatement().location.method.flowGraph().blockGraph().asSequence()
}

private val typedMethodsCache = ConcurrentHashMap<JcMethod, JcTypedMethod>()

val JcMethod.typed
get() = typedMethodsCache.getOrPut(this) {
enclosingClass.toType().declaredMethods.first { it.method == this }
}
}
Loading
Loading