ML path selection #60
base: main
Changes from 92 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
## Machine Learning Path Selector | ||
|
||
### Entry point | ||
|
||
To run tests with this path selector, use `jarRunner.kt`. You can pass the path to a JSON configuration file as the first argument. The gathered statistics are written to the folder specified in your configuration. | ||
|
||
### Config | ||
|
||
A config object is declared inside `MLConfig.kt`. A detailed description of all the options is listed below: | ||
|
||
- `gameEnvPath` - a path to a folder that contains the trained models (`rnn_cell.onnx`, `gnn_model.onnx`, `actor_model.onnx`) and a blacklist of tests to be skipped (`blacklist.txt`); some logs are also saved to this folder | ||
- `dataPath` - a path to a folder to save all statistics into | ||
- `defaultAlgorithm` - an algorithm to use if a trained model is not found, must be one of: `BFS`, `ForkDepthRandom` | ||
- `postprocessing` - how the actor model's outputs should be processed, must be one of: `Argmax` (choose the id of the maximum value), `Softmax` (sample from a distribution derived from the outputs via softmax), `None` (sample directly from the outputs, which is valid only when they already form a distribution) | ||
- `mode` - a mode for `jarRunner.kt`, must be one of: `Calculation` (to calculate statistics used to train models), `Aggregation` (to aggregate statistics for different tests into one file), `Both` (to both calculate statistics and aggregate them), `Test` (to test this path selector with different time limits and compare it to other path selectors) | ||
- `logFeatures` - whether to save statistics used to train models | ||
- `shuffleTests` - whether to shuffle tests before running (affects the tests being run if the `dataConsumption` option is less than 100) | ||
- `discounts` - time discounts used when testing path selectors | ||
- `inputShape` - an input shape of an actor model | ||
- `maxAttentionLength` - a maximum attention length of a PPO actor model | ||
- `useGnn` - whether to use a GNN model | ||
- `dataConsumption` - a percentage of tests to run | ||
- `hardTimeLimit` - a time limit for one test | ||
- `solverTimeLimit` - a time limit for one solver call | ||
- `maxConcurrency` - a maximum number of threads running different tests concurrently | ||
- `graphUpdate` - when to update block graph data, must be one of: `Once` (at the beginning of a test), `TestGeneration` (every time a new test is generated) | ||
- `logGraphFeatures` - whether to save graph statistics used to train a GNN model to a dataset file | ||
- `gnnFeaturesCount` - a number of features that a GNN model returns | ||
- `useRnn` - whether to use an RNN model | ||
- `rnnStateShape` - a shape of an RNN state | ||
- `rnnFeaturesCount` - a number of features that an RNN model returns | ||
- `inputJars` - jars and their packages to run tests on | ||
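
The three `postprocessing` modes listed above can be illustrated with a minimal sketch. It is not the code from this pull request: `PostprocessingSketch`, `softmax`, `sample`, and `chooseStateId` are hypothetical names, and the sketch assumes the actor model's outputs arrive as a flat list of floats, one per candidate state.

```kotlin
import kotlin.math.exp
import kotlin.random.Random

// Hypothetical stand-in for the `Postprocessing` enum declared in MLConfig.kt.
enum class PostprocessingSketch { Argmax, Softmax, None }

// Numerically stable softmax: subtract the maximum before exponentiating.
fun softmax(outputs: List<Float>): List<Float> {
    val max = outputs.maxOf { it }
    val exps = outputs.map { exp(it - max) }
    val sum = exps.sum()
    return exps.map { it / sum }
}

// Samples an index from a distribution (assumed non-negative, summing to 1).
fun sample(distribution: List<Float>, random: Random): Int {
    var threshold = random.nextFloat()
    for ((index, probability) in distribution.withIndex()) {
        threshold -= probability
        if (threshold < 0f) return index
    }
    return distribution.lastIndex // guards against floating-point rounding
}

// Picks the id of the state to explore next, according to the chosen mode.
fun chooseStateId(
    outputs: List<Float>,
    mode: PostprocessingSketch,
    random: Random = Random.Default,
): Int {
    require(outputs.isNotEmpty()) { "The model must return at least one output" }
    return when (mode) {
        PostprocessingSketch.Argmax ->
            outputs.indices.maxByOrNull { outputs[it] } ?: error("unreachable: outputs is not empty")
        PostprocessingSketch.Softmax -> sample(softmax(outputs), random)
        PostprocessingSketch.None -> sample(outputs, random)
    }
}
```

`Argmax` is deterministic, while the two sampling modes keep some exploration; `None` skips normalization and therefore only makes sense when the model already ends in a softmax layer.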
|
||
### How to modify the metric | ||
|
||
To modify the metric, change the values of the `reward` property of the `ActionData` objects. These objects are stored in the `path` property of `FeaturesLoggingPathSelector`. Currently, the metric is calculated in the `remove` method of `FeaturesLoggingPathSelector`. | ||
|
||
### Training environment | ||
|
||
The training environment and its description are inside `environment.zip`. | ||
|
||
### "Modified" files | ||
|
||
Source files whose names start with "Modified" are modified copies of files from other modules. They were modified to support this path selector. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
object MLVersions { | ||
const val serialization = "1.5.1" | ||
const val onnxruntime = "1.15.1" | ||
const val dotlin = "1.0.2" | ||
} | ||
|
||
plugins { | ||
id("usvm.kotlin-conventions") | ||
kotlin("plugin.serialization") version "1.8.21" | ||
} | ||
|
||
dependencies { | ||
implementation(project(":usvm-jvm")) | ||
implementation(project(":usvm-core")) | ||
|
||
implementation("org.jacodb:jacodb-analysis:${Versions.jcdb}") | ||
implementation("ch.qos.logback:logback-classic:${Versions.logback}") | ||
|
||
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:${MLVersions.serialization}") | ||
implementation("io.github.rchowell:dotlin:${MLVersions.dotlin}") | ||
implementation("com.microsoft.onnxruntime:onnxruntime:${MLVersions.onnxruntime}") | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package org.usvm | ||
|
||
import kotlinx.serialization.Serializable | ||
import kotlinx.serialization.json.Json | ||
import kotlinx.serialization.json.JsonObject | ||
import kotlinx.serialization.json.encodeToJsonElement | ||
import kotlinx.serialization.json.jsonObject | ||
import java.util.concurrent.ConcurrentHashMap | ||
|
||
class CoverageCounter( | ||
private val mlConfig: MLConfig | ||
) { | ||
private val testCoverages = ConcurrentHashMap<String, List<Float>>() | ||
private val testStatementsCounts = ConcurrentHashMap<String, Float>() | ||
private val testDiscounts = ConcurrentHashMap<String, List<Float>>() | ||
private val testFinished = ConcurrentHashMap<String, Boolean>() | ||
|
||
fun addTest(testName: String, statementsCount: Float) { | ||
testCoverages[testName] = List(mlConfig.discounts.size) { 0.0f } | ||
testStatementsCounts[testName] = statementsCount | ||
testDiscounts[testName] = List(mlConfig.discounts.size) { 1.0f } | ||
testFinished[testName] = false | ||
} | ||
|
||
fun updateDiscounts(testName: String) { | ||
testDiscounts[testName] = testDiscounts.getValue(testName) | ||
.mapIndexed { id, currentDiscount -> mlConfig.discounts[id] * currentDiscount } | ||
} | ||
|
||
fun updateResults(testName: String, newCoverage: Float) { | ||
val currentDiscounts = testDiscounts.getValue(testName) | ||
testCoverages[testName] = testCoverages.getValue(testName) | ||
.mapIndexed { id, currentCoverage -> currentCoverage + currentDiscounts[id] * newCoverage } | ||
} | ||
|
||
fun finishTest(testName: String) { | ||
testFinished[testName] = true | ||
} | ||
|
||
fun reset() { | ||
testCoverages.clear() | ||
testStatementsCounts.clear() | ||
testDiscounts.clear() | ||
testFinished.clear() | ||
} | ||
|
||
private fun getTotalCoverages(): List<Float> { | ||
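// fold instead of reduce: reduce throws when no tests have been added yet | ||
return testCoverages.values.fold(List(mlConfig.discounts.size) { 0.0f }) { acc, floats -> | ||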
acc.zip(floats).map { (total, value) -> total + value } | ||
} | ||
} | ||
|
||
@Serializable | ||
private data class TestStatistics( | ||
private val discounts: Map<String, Float>, | ||
private val statementsCount: Float, | ||
private val finished: Boolean, | ||
) | ||
|
||
@Serializable | ||
private data class Statistics( | ||
private val tests: Map<String, TestStatistics>, | ||
private val totalDiscounts: Map<String, Float>, | ||
private val totalStatementsCount: Float, | ||
private val finishedTestsCount: Float, | ||
) | ||
|
||
fun getStatistics(): JsonObject { | ||
val discountStrings = mlConfig.discounts.map { it.toString() } | ||
val testStatistics = testCoverages.mapValues { (test, coverages) -> | ||
TestStatistics( | ||
discountStrings.zip(coverages).toMap(), | ||
testStatementsCounts.getValue(test), | ||
testFinished.getValue(test), | ||
) | ||
} | ||
val statistics = Statistics( | ||
testStatistics, | ||
discountStrings.zip(getTotalCoverages()).toMap(), | ||
testStatementsCounts.values.sum(), | ||
testFinished.values.sumOf { if (it) 1.0 else 0.0 }.toFloat(), | ||
) | ||
return Json.encodeToJsonElement(statistics).jsonObject | ||
} | ||
} |
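
The discounting done by `updateDiscounts` and `updateResults` above can be read as follows: after `updateDiscounts` has been called k times, a coverage gain is weighted by the k-th power of each configured discount. The sketch below is a simplified single-test version of that accumulation. `DiscountedCoverageSketch` is a hypothetical name, and this diff does not show when the real `updateDiscounts` is called, so the call order in the usage example is an assumption.

```kotlin
// Simplified, single-test version of the accumulation performed by CoverageCounter.
class DiscountedCoverageSketch(private val discounts: List<Float>) {
    // One running discount per configured factor, starting at 1.
    private var currentDiscounts = List(discounts.size) { 1.0f }

    // One accumulated (discounted) coverage value per configured factor.
    var coverages = List(discounts.size) { 0.0f }
        private set

    // Mirrors updateDiscounts: multiplies each running discount by its base factor.
    fun updateDiscounts() {
        currentDiscounts = currentDiscounts.mapIndexed { id, current -> discounts[id] * current }
    }

    // Mirrors updateResults: adds the new coverage weighted by the running discounts.
    fun updateResults(newCoverage: Float) {
        coverages = coverages.mapIndexed { id, coverage -> coverage + currentDiscounts[id] * newCoverage }
    }
}
```

For example, with discounts `[1.0, 0.5]`, calling `updateResults(10f)`, then `updateDiscounts()`, then `updateResults(10f)` again yields coverages `[20.0, 15.0]`: a discount of 1.0 counts both gains in full, while 0.5 halves the later one.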
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package org.usvm | ||
|
||
enum class Postprocessing { | ||
Argmax, | ||
Softmax, | ||
None, | ||
} | ||
|
||
enum class Mode { | ||
Calculation, | ||
Aggregation, | ||
Both, | ||
Test, | ||
} | ||
|
||
enum class Algorithm { | ||
BFS, | ||
ForkDepthRandom, | ||
} | ||
|
||
enum class GraphUpdate { | ||
Once, | ||
TestGeneration, | ||
} | ||
|
||
data class MLConfig ( | ||
val gameEnvPath: String = "../Game_env", | ||
val dataPath: String = "../Data", | ||
val defaultAlgorithm: Algorithm = Algorithm.BFS, | ||
val postprocessing: Postprocessing = Postprocessing.Argmax, | ||
val mode: Mode = Mode.Both, | ||
val logFeatures: Boolean = true, | ||
val shuffleTests: Boolean = true, | ||
val discounts: List<Float> = listOf(1.0f, 0.998f, 0.99f), | ||
val inputShape: List<Long> = listOf(1, -1, 77), | ||
val maxAttentionLength: Int = -1, | ||
val useGnn: Boolean = true, | ||
val dataConsumption: Float = 100.0f, | ||
val hardTimeLimit: Int = 30000, // in ms | ||
val solverTimeLimit: Int = 10000, // in ms | ||
val maxConcurrency: Int = 64, | ||
val graphUpdate: GraphUpdate = GraphUpdate.Once, | ||
val logGraphFeatures: Boolean = false, | ||
val gnnFeaturesCount: Int = 8, | ||
val useRnn: Boolean = true, | ||
val rnnStateShape: List<Long> = listOf(4, 1, 512), | ||
val rnnFeaturesCount: Int = 33, | ||
val inputJars: Map<String, List<String>> = mapOf( | ||
Pair("../Game_env/jars/usvm-jvm-new.jar", listOf("org.usvm.samples", "com.thealgorithms")) | ||
Review comment: Looks like a hardcoded path. Maybe we should pass this as an environment variable or via configuration?
Reply: It is a default value inside a configuration object; it can be changed with a configuration file.
|
||
) // path to jar file -> list of package names | ||
) |
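
The `shuffleTests` and `dataConsumption` options declared above interact: when fewer than 100 percent of the tests are run, shuffling changes which tests end up in the selected subset. The sketch below shows one plausible reading of that behavior. The selection logic itself is not part of this diff, so `selectTests` is a hypothetical helper, and taking the first N percent of the (optionally shuffled) list is an assumption.

```kotlin
import kotlin.random.Random

// Hypothetical helper; the real selection logic is not shown in this diff.
fun <T> selectTests(
    tests: List<T>,
    dataConsumption: Float, // percentage of tests to run, in [0, 100]
    shuffleTests: Boolean,
    random: Random = Random.Default,
): List<T> {
    require(dataConsumption in 0.0f..100.0f) { "dataConsumption must be a percentage" }
    // Shuffling first means the kept prefix is a random subset, not always the same tests.
    val ordered = if (shuffleTests) tests.shuffled(random) else tests
    // Assumption: the count is rounded down.
    val count = (ordered.size * dataConsumption / 100.0f).toInt()
    return ordered.take(count)
}
```

With `shuffleTests = false` and `dataConsumption = 50`, this sketch always keeps the first half of the list; with shuffling enabled, each run may keep a different half.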
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package org.usvm | ||
|
||
enum class ModifiedPathSelectionStrategy { | ||
/** | ||
* Collects features according to states selected by any other path selector. | ||
*/ | ||
FEATURES_LOGGING, | ||
/** | ||
* Collects features and feeds them to the ML model to select states. | ||
* Extends FEATURES_LOGGING path selector. | ||
*/ | ||
MACHINE_LEARNING, | ||
} | ||
|
||
data class ModifiedUMachineOptions( | ||
|
||
val basicOptions: UMachineOptions = UMachineOptions(), | ||
val pathSelectionStrategies: List<ModifiedPathSelectionStrategy> = | ||
listOf(ModifiedPathSelectionStrategy.MACHINE_LEARNING) | ||
) |
Review comment: It is better to return a `statistics` object, avoiding serialization to JSON.
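
One way to read this suggestion, sketched under assumptions: the statistics classes become plain public data classes, and serialization is left to whichever caller needs JSON. `TestStatisticsSketch`, `StatisticsSketch`, and `buildStatistics` are hypothetical names, not code from this pull request. The sketch also stores the finished-test count as an `Int`, which differs from the `Float` field in the diff.

```kotlin
// Hypothetical public counterparts of the private TestStatistics and Statistics classes.
data class TestStatisticsSketch(
    val discounts: Map<String, Float>,
    val statementsCount: Float,
    val finished: Boolean,
)

data class StatisticsSketch(
    val tests: Map<String, TestStatisticsSketch>,
    val totalDiscounts: Map<String, Float>,
    val totalStatementsCount: Float,
    val finishedTestsCount: Int,
)

// Builds the statistics object directly; no JSON encoding happens here.
fun buildStatistics(
    discounts: List<Float>,
    testCoverages: Map<String, List<Float>>,
    testStatementsCounts: Map<String, Float>,
    testFinished: Map<String, Boolean>,
): StatisticsSketch {
    val discountStrings = discounts.map { it.toString() }
    val tests = testCoverages.mapValues { (test, coverages) ->
        TestStatisticsSketch(
            discountStrings.zip(coverages).toMap(),
            testStatementsCounts.getValue(test),
            testFinished.getValue(test),
        )
    }
    // fold rather than reduce, so an empty set of tests yields zeros instead of throwing.
    val totals = testCoverages.values.fold(List(discounts.size) { 0.0f }) { acc, floats ->
        acc.zip(floats).map { (total, value) -> total + value }
    }
    return StatisticsSketch(
        tests,
        discountStrings.zip(totals).toMap(),
        testStatementsCounts.values.sum(),
        testFinished.values.count { it },
    )
}
```

A caller that still needs JSON could annotate these classes with `@Serializable` and encode the returned object at the point of output, which keeps the counter itself free of serialization concerns.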