ML path selection #60
Open
sergeyrid
wants to merge
95
commits into
main
from
ml-path-selection
95 commits
c28ceb6
Add BfsWithLoggingPathSelector
sergeyrid 26be358
Add features to BfsWithLoggingPathSelector
sergeyrid 7df7145
Add jsonAggregator
sergeyrid c8f0dfa
Add InferencePathSelector
sergeyrid a93a138
Fix jsonAggregator and InferencePathSelector
sergeyrid 11ee239
Change jsonAggregator and fix path selectors
sergeyrid 8987aa7
Add graph visualization and alternative reward
sergeyrid 5a33d7a
Fix bugs and paths, add hashes to final json
sergeyrid 5310c7c
Add BlockGraph to BfsWithLoggingPathSelector
sergeyrid da8e9dd
Add BlockGraph logging
sergeyrid fb8f93b
Add concurrency and fix bugs
sergeyrid 3523c44
Add features, fix bugs and refactor
sergeyrid 575a266
Add features, fix bugs and refactor
sergeyrid f85016c
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid 4a4bcd0
Add PPO support
sergeyrid 1932d37
Fix path selectors
sergeyrid c731f02
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid f21c385
Change jsonAggregator and add options
sergeyrid 2093db0
Add options and fix block graph
sergeyrid 03b93a5
Add graph features and refactor
sergeyrid cfacb55
Merge remote-tracking branch 'origin/main' into ml-path-selection
sergeyrid 211ae65
Support string and class constants
Saloed d3fc3a7
Update tests ignore reasons
Saloed 2f5219c
Add gnn inference and refactor main
sergeyrid 52bdecf
Close machine (solver) after test
Saloed 337c494
Mock native methods + few approximations
Saloed df92f73
Update test ignore reasons
Saloed db0ce16
Add graph features and probabilities logging
sergeyrid 7860242
Merge remote-tracking branch 'origin/main' into ml-path-selection
sergeyrid d6dd569
Merge remote-tracking branch 'origin/saloed/native-mocks' into ml-pat…
sergeyrid 93270eb
Add features, jar processing and other
sergeyrid 71ea504
Merge remote-tracking branch 'origin/main' into ml-path-selection
sergeyrid f8491f5
Update tests ignore reasons
Saloed b69df4f
Mock native methods + few approximations
Saloed 9f22c2e
Update test ignore reasons
Saloed ed2579b
Fix mocks type constraints
Saloed f965c98
Update test ignore reasons
Saloed cb2ce74
Safe test resolver
Saloed f7e5106
Change test disabling method
Saloed e71e824
More approximations
Saloed 925a7c9
Scoring in type selector
Saloed b0da1a4
Timeouts
Saloed 943c7e3
Fix ignore reasons
Saloed d526ed1
Fix sc test
Saloed e64a46f
Disable more tests
Saloed 38905e4
Use yices by default
Saloed 3565713
Ensure thread stopped after timeout
Saloed 13823e6
Support IOB in arraycopy approximation
Saloed b7aa3f9
Enable few tests
Saloed 739223f
Safe logging
Saloed 01b2245
Better timeout handling
Saloed 5da1c9b
tmp
Saloed a7b4153
Change graphs and refactor
sergeyrid cfd157b
Merge remote-tracking branch 'origin/main' into ml-path-selection
sergeyrid d844878
Merge remote-tracking branch 'origin/saloed/native-mocks' into ml-pat…
sergeyrid 2f93c79
Disable unnecessary logging
sergeyrid 723d705
Add RNN
sergeyrid b83d742
Make main take jars as input
sergeyrid ac012a7
Fix, refactor and add options
sergeyrid acc5780
Add statistics calculation
sergeyrid f6d4257
Change queue to lru and fix statistics
sergeyrid cfe7cd1
Fix reward and statistics
sergeyrid 2470cfa
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid 652f180
Revert native-mocks changes
sergeyrid effd951
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid b2615f7
Move ml path selector to separate module
sergeyrid e4f3704
Separate ml from other modules
sergeyrid b50f08f
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid 8a477ee
Add OtherUMachine
sergeyrid acfab1b
Refactor
sergeyrid d7ff815
Add ml environment
sergeyrid fdc3d7a
Refactor
sergeyrid 4084382
Add README.md
sergeyrid 58c1e16
Change README.md
sergeyrid 736485c
Change README.md
sergeyrid eb21924
Change README.md
sergeyrid fe2edde
Uncomment code and fix gradle
sergeyrid fb5791f
Refactor
sergeyrid 5baa20a
Add test dependencies
sergeyrid 01a795b
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid 8278b6e
Fix ModifiedUMachine
sergeyrid 52e2385
Fix ModifiedUMachine
sergeyrid 9625b22
Refactor
sergeyrid 389c5a3
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid 587321f
Fix merge
sergeyrid e7cd950
Merge branch 'main' of https://github.com/UnitTestBot/usvm into ml-pa…
sergeyrid 65a2876
Fix merge
sergeyrid 50f6f1c
Merge remote-tracking branch 'origin/main' into ml-path-selection
sergeyrid fe38cea
Fix merge
sergeyrid 629f26b
Refactor BlockGraph
sergeyrid c689de1
Refactor CoverageCounter
sergeyrid f837ef3
Refactor CoverageCounterStatistics
sergeyrid cac5201
Refactor
sergeyrid 542c11e
Merge remote-tracking branch 'origin/main' into ml-path-selection
sergeyrid 72e1a72
Fix merge
sergeyrid
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
## Machine Learning Path Selector | ||
|
||
### Entry point | ||
|
||
To run tests with this path selector, use `jarRunner.kt`. You can pass the path to a configuration JSON file as the first argument. Gathered statistics are written to the folder specified in your configuration. | ||
|
||
### Config | ||
|
||
A config object is declared inside `MLConfig.kt`. A detailed description of all the options is listed below: | ||
|
||
- `gameEnvPath` - a path to a folder that contains trained models (`rnn_cell.onnx`, `gnn_model.onnx`, `actor_model.onnx`) and a blacklist of tests to be skipped (`blacklist.txt`); some logs are also saved to this folder | ||
- `dataPath` - a path to a folder to save all statistics into | ||
- `defaultAlgorithm` - an algorithm to use if a trained model is not found, must be one of: `BFS`, `ForkDepthRandom` | ||
- `postprocessing` - how actor model's outputs should be processed, must be one of: `Argmax` (choose an id of the maximum value), `Softmax` (sample from a distribution derived from the outputs via the softmax), `None` (sample from the outputs — only when they form a distribution) | ||
- `mode` - a mode for `jarRunner.kt`, must be one of: `Calculation` (to calculate statistics used to train models), `Aggregation` (to aggregate statistics for different tests into one file), `Both` (to both calculate statistics and aggregate them), `Test` (to test this path selector with different time limits and compare it to other path selectors) | ||
- `logFeatures` - whether to save statistics used to train models | ||
- `shuffleTests` - whether to shuffle tests before running (affects the tests being run if the `dataConsumption` option is less than 100) | ||
- `discounts` - time discounts used when testing path selectors | ||
- `inputShape` - an input shape of an actor model | ||
- `maxAttentionLength` - a maximum attention length of a PPO actor model | ||
- `useGnn` - whether to use a GNN model | ||
- `dataConsumption` - a percentage of tests to run | ||
- `hardTimeLimit` - a time limit for one test | ||
- `solverTimeLimit` - a time limit for one solver call | ||
- `maxConcurrency` - a maximum number of threads running different tests concurrently | ||
- `graphUpdate` - when to update block graph data, must be one of: `Once` (at the beginning of a test), `TestGeneration` (every time a new test is generated) | ||
- `logGraphFeatures` - whether to save graph statistics used to train a GNN model to a dataset file | ||
- `gnnFeaturesCount` - a number of features that a GNN model returns | ||
- `useRnn` - whether to use an RNN model | ||
- `rnnStateShape` - a shape of an RNN state | ||
- `rnnFeaturesCount` - a number of features that an RNN model returns | ||
- `inputJars` - jars and their packages to run tests on | ||
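The three `postprocessing` modes described in the list above can be sketched as follows. This is a minimal illustration, not the code from this pull request: `PostprocessingSketch`, `argmax`, `softmax`, `sample` and `chooseState` are hypothetical names, and the sketch assumes the actor model's outputs arrive as a flat `FloatArray` with one value per candidate state.

```kotlin
import kotlin.math.exp
import kotlin.random.Random

// Hypothetical stand-in for the `Postprocessing` enum declared in MLConfig.kt.
enum class PostprocessingSketch { Argmax, Softmax, None }

// Index of the largest output.
fun argmax(outputs: FloatArray): Int {
    require(outputs.isNotEmpty()) { "outputs must not be empty" }
    var best = 0
    for (i in 1 until outputs.size) if (outputs[i] > outputs[best]) best = i
    return best
}

// Numerically stable softmax: subtract the maximum before exponentiating.
fun softmax(outputs: FloatArray): DoubleArray {
    val max = outputs[argmax(outputs)]
    val exps = DoubleArray(outputs.size) { exp((outputs[it] - max).toDouble()) }
    val sum = exps.sum()
    return DoubleArray(exps.size) { exps[it] / sum }
}

// Draws an index with probability proportional to its weight.
fun sample(weights: DoubleArray, random: Random): Int {
    var threshold = random.nextDouble() * weights.sum()
    for (i in weights.indices) {
        threshold -= weights[i]
        if (threshold < 0.0) return i
    }
    return weights.lastIndex // guards against floating-point round-off
}

// Picks the id of the next state according to the chosen mode.
fun chooseState(
    outputs: FloatArray,
    mode: PostprocessingSketch,
    random: Random = Random.Default,
): Int = when (mode) {
    PostprocessingSketch.Argmax -> argmax(outputs)
    PostprocessingSketch.Softmax -> sample(softmax(outputs), random)
    // Assumes the outputs already form a probability distribution.
    PostprocessingSketch.None -> sample(DoubleArray(outputs.size) { outputs[it].toDouble() }, random)
}
```

For example, `chooseState(floatArrayOf(0.1f, 2.5f, -1.0f), PostprocessingSketch.Argmax)` returns `1`, while the `Softmax` mode returns each index with a probability that grows with its output value.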
|
||
### How to modify the metric | ||
|
||
To modify the metric, change the values of the `reward` property of the `ActionData` objects, which are stored in the `path` property of `FeaturesLoggingPathSelector`. Currently, the metric is calculated in the `remove` method of `FeaturesLoggingPathSelector`. | ||
|
||
### Training environment | ||
|
||
The training environment and its description are inside `environment.zip`. | ||
|
||
### "Modified" files | ||
|
||
Source files whose names start with "Modified" are modified copies of files from other modules. They were modified to support this path selector. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
object MLVersions { | ||
const val serialization = "1.5.1" | ||
const val onnxruntime = "1.15.1" | ||
const val dotlin = "1.0.2" | ||
} | ||
|
||
plugins { | ||
id("usvm.kotlin-conventions") | ||
kotlin("plugin.serialization") version "1.8.21" | ||
} | ||
|
||
dependencies { | ||
implementation(project(":usvm-jvm")) | ||
implementation(project(":usvm-core")) | ||
|
||
implementation("org.jacodb:jacodb-analysis:${Versions.jcdb}") | ||
implementation("ch.qos.logback:logback-classic:${Versions.logback}") | ||
|
||
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:${MLVersions.serialization}") | ||
implementation("io.github.rchowell:dotlin:${MLVersions.dotlin}") | ||
implementation("com.microsoft.onnxruntime:onnxruntime:${MLVersions.onnxruntime}") | ||
} |
Binary file not shown.
80 changes: 80 additions & 0 deletions
usvm-ml-path-selection/src/main/kotlin/org/usvm/CoverageCounter.kt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package org.usvm | ||
|
||
import kotlinx.serialization.Serializable | ||
import java.util.concurrent.ConcurrentHashMap | ||
|
||
class CoverageCounter( | ||
private val mlConfig: MLConfig | ||
) { | ||
private val testCoverages = ConcurrentHashMap<String, List<Float>>() | ||
private val testStatementsCounts = ConcurrentHashMap<String, Float>() | ||
private val testDiscounts = ConcurrentHashMap<String, List<Float>>() | ||
private val testFinished = ConcurrentHashMap<String, Boolean>() | ||
|
||
fun addTest(testName: String, statementsCount: Float) { | ||
testCoverages[testName] = List(mlConfig.discounts.size) { 0.0f } | ||
testStatementsCounts[testName] = statementsCount | ||
testDiscounts[testName] = List(mlConfig.discounts.size) { 1.0f } | ||
testFinished[testName] = false | ||
} | ||
|
||
fun updateDiscounts(testName: String) { | ||
testDiscounts[testName] = testDiscounts.getValue(testName) | ||
.mapIndexed { id, currentDiscount -> mlConfig.discounts[id] * currentDiscount } | ||
} | ||
|
||
fun updateResults(testName: String, newCoverage: Float) { | ||
val currentDiscounts = testDiscounts.getValue(testName) | ||
testCoverages[testName] = testCoverages.getValue(testName) | ||
.mapIndexed { id, currentCoverage -> currentCoverage + currentDiscounts[id] * newCoverage } | ||
} | ||
|
||
fun finishTest(testName: String) { | ||
testFinished[testName] = true | ||
} | ||
|
||
fun reset() { | ||
testCoverages.clear() | ||
testStatementsCounts.clear() | ||
testDiscounts.clear() | ||
testFinished.clear() | ||
} | ||
|
||
private fun getTotalCoverages(): List<Float> { | ||
return testCoverages.values.reduce { acc, floats -> | ||
acc.zip(floats).map { (total, value) -> total + value } | ||
} | ||
} | ||
|
||
@Serializable | ||
data class TestStatistics( | ||
private val discounts: Map<String, Float>, | ||
private val statementsCount: Float, | ||
private val finished: Boolean, | ||
) | ||
|
||
@Serializable | ||
data class Statistics( | ||
private val tests: Map<String, TestStatistics>, | ||
private val totalDiscounts: Map<String, Float>, | ||
private val totalStatementsCount: Float, | ||
private val finishedTestsCount: Float, | ||
) | ||
|
||
fun getStatistics(): Statistics { | ||
val discountStrings = mlConfig.discounts.map { it.toString() } | ||
val testStatistics = testCoverages.mapValues { (test, coverages) -> | ||
TestStatistics( | ||
discountStrings.zip(coverages).toMap(), | ||
testStatementsCounts.getValue(test), | ||
testFinished.getValue(test), | ||
) | ||
} | ||
return Statistics( | ||
testStatistics, | ||
discountStrings.zip(getTotalCoverages()).toMap(), | ||
testStatementsCounts.values.sum(), | ||
testFinished.values.sumOf { if (it) 1.0 else 0.0 }.toFloat(), | ||
) | ||
} | ||
} |
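The bookkeeping in `updateDiscounts` and `updateResults` amounts to a time-discounted coverage sum: for each configured discount `d`, coverage gained after `t` discount updates contributes `d^t * newCoverage`. A minimal single-test sketch of that idea follows. `DiscountedCoverageSketch`, `step` and `addCoverage` are hypothetical names, and when the real path selector calls each update is not shown in this diff, so the call order below is an assumption.

```kotlin
// Hypothetical single-test version of the bookkeeping in CoverageCounter.
class DiscountedCoverageSketch(private val discounts: List<Float>) {
    // One running discount per configured base discount, all starting at 1.
    private var currentDiscounts = List(discounts.size) { 1.0f }

    // One accumulated coverage value per configured discount.
    var coverages = List(discounts.size) { 0.0f }
        private set

    // Mirrors updateDiscounts: multiply each running discount by its base value.
    fun step() {
        currentDiscounts = currentDiscounts.mapIndexed { i, d -> discounts[i] * d }
    }

    // Mirrors updateResults: add the new coverage weighted by the running discounts.
    fun addCoverage(newCoverage: Float) {
        coverages = coverages.mapIndexed { i, c -> c + currentDiscounts[i] * newCoverage }
    }
}

// Usage with illustrative discounts 1.0 and 0.5 (the defaults in MLConfig differ).
fun discountedCoverageDemo(): List<Float> {
    val counter = DiscountedCoverageSketch(listOf(1.0f, 0.5f))
    counter.addCoverage(10.0f) // both entries become 10.0
    counter.step()             // running discounts become 1.0 and 0.5
    counter.addCoverage(10.0f) // adds 10.0 and 5.0 respectively
    return counter.coverages   // [20.0, 15.0]
}
```

A discount of `1.0` therefore reports plain accumulated coverage, while smaller discounts reward path selectors that reach coverage earlier.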
51 changes: 51 additions & 0 deletions
usvm-ml-path-selection/src/main/kotlin/org/usvm/MLConfig.kt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package org.usvm | ||
|
||
enum class Postprocessing { | ||
Argmax, | ||
Softmax, | ||
None, | ||
} | ||
|
||
enum class Mode { | ||
Calculation, | ||
Aggregation, | ||
Both, | ||
Test, | ||
} | ||
|
||
enum class Algorithm { | ||
BFS, | ||
ForkDepthRandom, | ||
} | ||
|
||
enum class GraphUpdate { | ||
Once, | ||
TestGeneration, | ||
} | ||
|
||
data class MLConfig ( | ||
val gameEnvPath: String = "../Game_env", | ||
val dataPath: String = "../Data", | ||
val defaultAlgorithm: Algorithm = Algorithm.BFS, | ||
val postprocessing: Postprocessing = Postprocessing.Argmax, | ||
val mode: Mode = Mode.Both, | ||
val logFeatures: Boolean = true, | ||
val shuffleTests: Boolean = true, | ||
val discounts: List<Float> = listOf(1.0f, 0.998f, 0.99f), | ||
val inputShape: List<Long> = listOf(1, -1, 77), | ||
val maxAttentionLength: Int = -1, | ||
val useGnn: Boolean = true, | ||
val dataConsumption: Float = 100.0f, | ||
val hardTimeLimit: Int = 30000, // in ms | ||
val solverTimeLimit: Int = 10000, // in ms | ||
val maxConcurrency: Int = 64, | ||
val graphUpdate: GraphUpdate = GraphUpdate.Once, | ||
val logGraphFeatures: Boolean = false, | ||
val gnnFeaturesCount: Int = 8, | ||
val useRnn: Boolean = true, | ||
val rnnStateShape: List<Long> = listOf(4, 1, 512), | ||
val rnnFeaturesCount: Int = 33, | ||
val inputJars: Map<String, List<String>> = mapOf( | ||
Pair("../Game_env/jars/usvm-jvm-new.jar", listOf("org.usvm.samples", "com.thealgorithms")) | ||
) // path to jar file -> list of package names | ||
) |
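Because every field of `MLConfig` has a default, a caller only needs to override the values that differ. The sketch below uses a trimmed, hypothetical copy of the class (`MLConfigSketch`) so that it compiles on its own. The `resolveConfig` helper and the `GAME_ENV_PATH` key are also assumptions made for illustration; the pull request itself reads overrides from a configuration file.

```kotlin
// Trimmed, hypothetical copy of MLConfig with only a few of its fields.
data class MLConfigSketch(
    val gameEnvPath: String = "../Game_env",
    val dataPath: String = "../Data",
    val hardTimeLimit: Int = 30000, // in ms
    val maxConcurrency: Int = 64,
)

// Returns the defaults, replacing gameEnvPath when an override is present.
// The map stands in for any external source, such as parsed settings.
fun resolveConfig(overrides: Map<String, String>): MLConfigSketch {
    val defaults = MLConfigSketch()
    val path = overrides["GAME_ENV_PATH"] ?: return defaults
    return defaults.copy(gameEnvPath = path)
}
```

Calling `resolveConfig(emptyMap())` yields the defaults unchanged, and `resolveConfig(mapOf("GAME_ENV_PATH" to "/opt/game_env"))` changes only `gameEnvPath`.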
19 changes: 19 additions & 0 deletions
usvm-ml-path-selection/src/main/kotlin/org/usvm/ModifiedUMachineOptions.kt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package org.usvm | ||
|
||
enum class ModifiedPathSelectionStrategy { | ||
/** | ||
* Collects features according to states selected by any other path selector. | ||
*/ | ||
FEATURES_LOGGING, | ||
/** | ||
* Collects features and feeds them to the ML model to select states. | ||
* Extends FEATURES_LOGGING path selector. | ||
*/ | ||
MACHINE_LEARNING, | ||
} | ||
|
||
data class ModifiedUMachineOptions( | ||
|
||
val basicOptions: UMachineOptions = UMachineOptions(), | ||
val pathSelectionStrategies: List<ModifiedPathSelectionStrategy> = | ||
listOf(ModifiedPathSelectionStrategy.MACHINE_LEARNING) | ||
) |
Looks like a hardcoded path. Maybe we should pass this as an environment variable or via configuration?
It is a default value inside a configuration object; it can be changed with a configuration file.