Skip to content

Commit

Permalink
Internalizer optimization (#81)
Browse files Browse the repository at this point in the history
* Specialized long internalizer

* Specialized long converter

* Fix base internalizer

* Fix base converter

* Specialized Int internalizer

* Better Z3 internalizer

* Better runner serializer

* Move fastutil dependency to base

* Better z3 const decl internalizer

* Better Z3 native API

* Better Z3 context

* Fix quantified expression mover

* Use sets for leveled caches

* Bitwuzla: better sort and decl converters

* Better scoped vars

* Use specialized collections in Bitwuzla

* Fix Z3 cache

* Fix Z3 ref counting

* Fix array args builder

* Comment on dependencies size

* Extract common internalizer part

* Extract common converter part

* Fp bits normalization in tests
  • Loading branch information
Saloed authored Mar 6, 2023
1 parent 9f984e8 commit 6921c86
Show file tree
Hide file tree
Showing 32 changed files with 1,634 additions and 711 deletions.
3 changes: 3 additions & 0 deletions buildSrc/src/main/kotlin/org.ksmt.ksmt-base.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ repositories {
}

dependencies {
// Primitive collections
implementation("it.unimi.dsi:fastutil-core:8.5.11") // 6.1MB

testImplementation(kotlin("test"))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package org.ksmt.solver.bitwuzla

import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap
import org.ksmt.KContext
import org.ksmt.decl.KDecl
import org.ksmt.decl.KFuncDecl
Expand All @@ -19,6 +22,7 @@ import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native
import org.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED
import org.ksmt.sort.KArray2Sort
import org.ksmt.sort.KArray3Sort
import org.ksmt.sort.KArrayNSort
Expand All @@ -39,24 +43,27 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {

val bitwuzla = Native.bitwuzlaNew()

val trueTerm: BitwuzlaTerm by lazy { Native.bitwuzlaMkTrue(bitwuzla) }
val falseTerm: BitwuzlaTerm by lazy { Native.bitwuzlaMkFalse(bitwuzla) }
val boolSort: BitwuzlaSort by lazy { Native.bitwuzlaMkBoolSort(bitwuzla) }
val trueTerm: BitwuzlaTerm = Native.bitwuzlaMkTrue(bitwuzla)
val falseTerm: BitwuzlaTerm = Native.bitwuzlaMkFalse(bitwuzla)
val boolSort: BitwuzlaSort = Native.bitwuzlaMkBoolSort(bitwuzla)

private val exprGlobalCache = hashMapOf<KExpr<*>, BitwuzlaTerm>()
private val bitwuzlaExpressions = hashMapOf<BitwuzlaTerm, KExpr<*>>()
private val exprGlobalCache = mkTermCache<KExpr<*>>()
private val bitwuzlaExpressions = mkTermReverseCache<KExpr<*>>()

private val constantsGlobalCache = hashMapOf<KDecl<*>, BitwuzlaTerm>()
private val bitwuzlaConstants = hashMapOf<BitwuzlaTerm, KDecl<*>>()
private val constantsGlobalCache = mkTermCache<KDecl<*>>()
private val bitwuzlaConstants = mkTermReverseCache<KDecl<*>>()

private val sorts = hashMapOf<KSort, BitwuzlaSort>()
private val bitwuzlaSorts = hashMapOf<BitwuzlaSort, KSort>()
private val declSorts = hashMapOf<KDecl<*>, BitwuzlaSort>()
private val sorts = mkTermCache<KSort>()
private val declSorts = mkTermCache<KDecl<*>>()
private val bitwuzlaSorts = mkTermReverseCache<KSort>()

private val bitwuzlaValues = hashMapOf<BitwuzlaTerm, KExpr<*>>()
private val bitwuzlaValues = mkTermReverseCache<KExpr<*>>()

private var exprCurrentLevelCache = hashMapOf<KExpr<*>, BitwuzlaTerm>()
private val exprCacheLevel = hashMapOf<KExpr<*>, Int>()
private val exprCacheLevel = Object2IntOpenHashMap<KExpr<*>>().apply {
defaultReturnValue(Int.MAX_VALUE) // Level which is greater than any possible level
}

private var exprCurrentLevelCache = hashSetOf<KExpr<*>>()
private val exprLeveledCache = arrayListOf(exprCurrentLevelCache)
private var currentLevelExprMover = ExprMover()

Expand Down Expand Up @@ -85,37 +92,58 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
* Move expression and recollect declarations.
* See [ExprMover].
* */
fun findExprTerm(expr: KExpr<*>): BitwuzlaTerm? {
val globalTerm = exprGlobalCache[expr] ?: return null
fun findExprTerm(expr: KExpr<*>): BitwuzlaTerm {
val term = exprGlobalCache.getLong(expr)
if (term == NOT_INTERNALIZED) return NOT_INTERNALIZED

val currentLevelTerm = exprCurrentLevelCache[expr]
if (currentLevelTerm != null) return currentLevelTerm
if (expr in exprCurrentLevelCache) return term

currentLevelExprMover.apply(expr)

return globalTerm
return term
}

fun saveExprTerm(expr: KExpr<*>, term: BitwuzlaTerm) {
if (exprCurrentLevelCache.putIfAbsent(expr, term) == null) {
exprGlobalCache[expr] = term
exprCacheLevel[expr] = currentLevel
if (exprCurrentLevelCache.add(expr)) {
exprGlobalCache.put(expr, term)
exprCacheLevel.put(expr, currentLevel)
}
}

operator fun get(sort: KSort): BitwuzlaSort? = sorts[sort]
fun findInternalizedSort(sort: KSort): BitwuzlaSort =
sorts.getLong(sort)

fun internalizeSort(sort: KSort, internalizer: (KSort) -> BitwuzlaSort): BitwuzlaSort =
sorts.getOrPut(sort) {
internalizer(sort).also {
bitwuzlaSorts[it] = sort
}
}
fun saveInternalizedSort(sort: KSort, native: BitwuzlaSort) {
sorts.put(sort, native)
bitwuzlaSorts.put(native, sort)
}

fun internalizeDeclSort(decl: KDecl<*>, internalizer: (KDecl<*>) -> BitwuzlaSort): BitwuzlaSort =
declSorts.getOrPut(decl) {
internalizer(decl)
}.also { registerDeclaration(decl) }
inline fun internalizeSort(sort: KSort, internalizer: (KSort) -> BitwuzlaSort): BitwuzlaSort {
val cached = findInternalizedSort(sort)
if (cached != NOT_INTERNALIZED) return cached

val internalizedSort = internalizer(sort)
saveInternalizedSort(sort, internalizedSort)
return internalizedSort
}

fun findInternalizedDeclSort(decl: KDecl<*>): BitwuzlaSort =
declSorts.getLong(decl)

fun saveInternalizedDeclSort(decl: KDecl<*>, native: BitwuzlaSort) {
declSorts.put(decl, native)
}

inline fun internalizeDeclSort(decl: KDecl<*>, internalizer: (KDecl<*>) -> BitwuzlaSort): BitwuzlaSort {
registerDeclaration(decl)

val cached = findInternalizedDeclSort(decl)
if (cached != NOT_INTERNALIZED) return cached

val internalizedDeclSort = internalizer(decl)
saveInternalizedDeclSort(decl, internalizedDeclSort)
return internalizedDeclSort
}

/**
* Internalize and reverse cache Bv value to support Bv values conversion.
Expand All @@ -125,24 +153,42 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
* expressions.
* */
fun saveInternalizedValue(expr: KExpr<*>, term: BitwuzlaTerm) {
bitwuzlaValues[term] = expr
bitwuzlaValues.put(term, expr)
}

fun findConvertedExpr(expr: BitwuzlaTerm): KExpr<*>? = bitwuzlaExpressions[expr]
fun findConvertedExpr(expr: BitwuzlaTerm): KExpr<*>? = bitwuzlaExpressions.get(expr)

fun saveConvertedExpr(expr: BitwuzlaTerm, converted: KExpr<*>) {
if (bitwuzlaExpressions.putIfAbsent(expr, converted) == null) {
exprGlobalCache.putIfAbsent(converted, expr)
}
}

fun findConvertedSort(sort: BitwuzlaSort): KSort? = bitwuzlaSorts.get(sort)

fun saveConvertedSort(sort: BitwuzlaSort, converted: KSort) {
if (bitwuzlaSorts.putIfAbsent(sort, converted) == null) {
sorts.putIfAbsent(converted, sort)
}
}

fun convertExpr(expr: BitwuzlaTerm, converter: (BitwuzlaTerm) -> KExpr<*>): KExpr<*> =
convert(exprGlobalCache, bitwuzlaExpressions, expr, converter)
inline fun convertSort(sort: BitwuzlaSort, converter: (BitwuzlaSort) -> KSort): KSort {
val cached = findConvertedSort(sort)
if (cached != null) return cached

fun convertSort(sort: BitwuzlaSort, converter: (BitwuzlaSort) -> KSort): KSort =
convert(sorts, bitwuzlaSorts, sort, converter)
val convertedSort = converter(sort)
saveConvertedSort(sort, convertedSort)
return convertedSort
}

fun convertValue(value: BitwuzlaTerm): KExpr<*>? = bitwuzlaValues[value]
fun convertValue(value: BitwuzlaTerm): KExpr<*>? = bitwuzlaValues.get(value)

// Constant is known only if it was previously internalized
fun convertConstantIfKnown(term: BitwuzlaTerm): KDecl<*>? = bitwuzlaConstants[term]
fun convertConstantIfKnown(term: BitwuzlaTerm): KDecl<*>? = bitwuzlaConstants.get(term)

// Find normal constant if it was previously internalized
fun findConstant(decl: KDecl<*>): BitwuzlaTerm? = constantsGlobalCache[decl]
fun findConstant(decl: KDecl<*>): BitwuzlaTerm? =
constantsGlobalCache.getLong(decl).takeIf { it != NOT_INTERNALIZED }

fun declarations(): Set<KDecl<*>> =
leveledDeclarations.flatMapTo(hashSetOf()) { it }
Expand All @@ -165,7 +211,7 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
* Also, if declaration sort is uninterpreted,
* register this declaration as relevant to the sort.
* */
private fun registerDeclaration(decl: KDecl<*>) {
fun registerDeclaration(decl: KDecl<*>) {
if (currentLevelDeclarations.add(decl)) {
currentLevelUninterpretedSortRegisterer.decl = decl
decl.sort.accept(currentLevelUninterpretedSortRegisterer)
Expand All @@ -180,19 +226,27 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
* Since [Native.bitwuzlaMkConst] creates fresh constant on each invocation caches are used
* to guarantee that if two constants are equal in ksmt they are also equal in Bitwuzla.
* */
fun mkConstant(decl: KDecl<*>, sort: BitwuzlaSort): BitwuzlaTerm = constantsGlobalCache.getOrPut(decl) {
Native.bitwuzlaMkConst(bitwuzla, sort, decl.name).also {
bitwuzlaConstants[it] = decl
fun mkConstant(decl: KDecl<*>, sort: BitwuzlaSort): BitwuzlaTerm {
registerDeclaration(decl)

val value = constantsGlobalCache.getLong(decl)
if (value != NOT_INTERNALIZED) return value

val term = Native.bitwuzlaMkConst(bitwuzla, sort, decl.name).also {
bitwuzlaConstants.put(it, decl)
}
}.also { registerDeclaration(decl) }
constantsGlobalCache.put(decl, term)

return term
}

/**
* Create nested declaration scope to allow [popDeclarationScope].
* Declarations scopes are used to manage set of currently asserted declarations
* and must match to the corresponding assertion level ([KBitwuzlaSolver.push]).
* */
fun createNestedDeclarationScope() {
exprCurrentLevelCache = hashMapOf()
exprCurrentLevelCache = hashSetOf()
exprLeveledCache.add(exprCurrentLevelCache)
currentLevelExprMover = ExprMover()

Expand Down Expand Up @@ -252,23 +306,6 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
check(!isClosed) { "The context is already closed." }
}

private inline fun <K, V> convert(
cache: MutableMap<K, V>,
reverseCache: MutableMap<V, K>,
key: V,
converter: (V) -> K
): K {
val current = reverseCache[key]

if (current != null) return current

val converted = converter(key)
cache.putIfAbsent(converted, key)
reverseCache[key] = converted

return converted
}

private class UninterpretedSortRegisterer(
private val register: MutableMap<KUninterpretedSort, HashSet<KDecl<*>>>
) : KSortVisitor<Unit> {
Expand Down Expand Up @@ -332,17 +369,25 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
* */
private inner class ExprMover : KNonRecursiveTransformer(ctx) {
override fun <T : KSort> transformExpr(expr: KExpr<T>): KExpr<T> {
// Move expr to current level
val term = exprGlobalCache.getValue(expr)
exprCacheLevel[expr] = currentLevel
exprCurrentLevelCache[expr] = term
if (!insideQuantifiedScope) {
/**
* Move expr to current level.
*
* Don't move quantified expression since:
* 1. Body may contain vars which can't be moved correctly
* 2. Expression caches will remain correct regardless of body moved
* */
if (exprCurrentLevelCache.add(expr)) {
exprCacheLevel.put(expr, currentLevel)
}
}

return super.transformExpr(expr)
}

override fun <T : KSort> exprTransformationRequired(expr: KExpr<T>): Boolean {
val cachedLevel = exprCacheLevel[expr]
if (cachedLevel != null && cachedLevel < currentLevel) {
val cachedLevel = exprCacheLevel.getInt(expr)
if (cachedLevel < currentLevel) {
val levelCache = exprLeveledCache[cachedLevel]
// If expr is valid on its level we don't need to move it
return expr !in levelCache
Expand Down Expand Up @@ -374,18 +419,25 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
return super.transform(expr)
}

private val quantifiedVarsScope = arrayListOf<Pair<KExpr<*>, Set<KDecl<*>>?>>()
private val quantifiedVarsScopeOwner = arrayListOf<KExpr<*>>()
private val quantifiedVarsScope = arrayListOf<Set<KDecl<*>>?>()

private val insideQuantifiedScope: Boolean
get() = quantifiedVarsScopeOwner.isNotEmpty()

private fun <T : KSort> KExpr<T>.transformQuantifier(bounds: List<KDecl<*>>, body: KExpr<*>): KExpr<T> {
if (quantifiedVarsScope.lastOrNull()?.first != this) {
quantifiedVarsScope.add(this to currentlyIgnoredDeclarations)
if (quantifiedVarsScopeOwner.lastOrNull() != this) {
quantifiedVarsScopeOwner.add(this)
quantifiedVarsScope.add(currentlyIgnoredDeclarations)

val ignoredDecls = currentlyIgnoredDeclarations?.toHashSet() ?: hashSetOf()
ignoredDecls.addAll(bounds)
currentlyIgnoredDeclarations = ignoredDecls
}
return transformExprAfterTransformed(this, body) {
currentlyIgnoredDeclarations = quantifiedVarsScope.removeLast().second
this
quantifiedVarsScopeOwner.removeLast()
currentlyIgnoredDeclarations = quantifiedVarsScope.removeLast()
transformExpr(this)
}
}

Expand All @@ -411,4 +463,14 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
override fun transform(expr: KUniversalQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)
}

companion object {
@JvmStatic
private fun <T> mkTermCache() = Object2LongOpenHashMap<T>().apply {
defaultReturnValue(NOT_INTERNALIZED)
}

@JvmStatic
private fun <T> mkTermReverseCache() = Long2ObjectOpenHashMap<T>()
}
}
Loading

0 comments on commit 6921c86

Please sign in to comment.