Skip to content

Commit

Permalink
Initial fp support
Browse files Browse the repository at this point in the history
  • Loading branch information
CaelmBleidd committed Sep 1, 2022
1 parent d4d326f commit 56d85d3
Show file tree
Hide file tree
Showing 16 changed files with 1,064 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ import org.ksmt.sort.KBv32Sort
import org.ksmt.sort.KBv64Sort
import org.ksmt.sort.KBv8Sort
import org.ksmt.sort.KBvSort
import org.ksmt.sort.KFpSort
import org.ksmt.sort.KIntSort
import org.ksmt.sort.KRealSort
import org.ksmt.sort.KSort
Expand Down Expand Up @@ -467,6 +468,9 @@ open class KBitwuzlaExprInternalizer(

override fun visit(sort: KUninterpretedSort): BitwuzlaSort =
throw KSolverUnsupportedFeatureException("Unsupported sort $sort")

override fun <S : KFpSort> visit(sort: S): BitwuzlaSort =
throw KSolverUnsupportedFeatureException("Unsupported sort $sort")
}

open class FunctionSortInternalizer(
Expand Down
263 changes: 255 additions & 8 deletions ksmt-core/src/main/kotlin/org/ksmt/KContext.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.ksmt

import java.lang.Double.longBitsToDouble
import java.lang.Float.intBitsToFloat
import org.ksmt.cache.mkCache
import org.ksmt.decl.KAndDecl
import org.ksmt.decl.KArithAddDecl
Expand Down Expand Up @@ -179,6 +181,12 @@ import org.ksmt.sort.KBv8Sort
import org.ksmt.sort.KBvCustomSizeSort
import org.ksmt.sort.KBvSort
import org.ksmt.sort.KBoolSort
import org.ksmt.sort.KFp128Sort
import org.ksmt.sort.KFp16Sort
import org.ksmt.sort.KFp32Sort
import org.ksmt.sort.KFp64Sort
import org.ksmt.sort.KFpCustomSizeSort
import org.ksmt.sort.KFpSort
import org.ksmt.sort.KIntSort
import org.ksmt.sort.KRealSort
import org.ksmt.sort.KSort
Expand All @@ -188,6 +196,12 @@ import kotlin.reflect.KProperty
import org.ksmt.decl.KBvRotateLeftIndexedDecl
import org.ksmt.decl.KBvRotateRightIndexedDecl
import org.ksmt.decl.KBvSubNoUnderflowDecl
import org.ksmt.decl.KFp128Decl
import org.ksmt.decl.KFp16Decl
import org.ksmt.decl.KFp32Decl
import org.ksmt.decl.KFp64Decl
import org.ksmt.decl.KFpCustomSizeDecl
import org.ksmt.decl.KFpDecl
import org.ksmt.expr.KBitVec16Value
import org.ksmt.expr.KBitVec1Value
import org.ksmt.expr.KBitVec32Value
Expand All @@ -197,8 +211,18 @@ import org.ksmt.expr.KBitVecCustomValue
import org.ksmt.expr.KBvRotateLeftIndexedExpr
import org.ksmt.expr.KBvRotateRightIndexedExpr
import org.ksmt.expr.KBvSubNoUnderflowExpr
import org.ksmt.expr.KFp128Value
import org.ksmt.expr.KFp16Value
import org.ksmt.expr.KFp32Value
import org.ksmt.expr.KFp64Value
import org.ksmt.expr.KFpCustomSizeValue
import org.ksmt.expr.KFpValue
import org.ksmt.expr.KFunctionAsArray
import org.ksmt.utils.booleanSignBit
import org.ksmt.utils.cast
import org.ksmt.utils.extendWithZeros
import org.ksmt.utils.extractExponent
import org.ksmt.utils.extractSignificand
import org.ksmt.utils.toBinary
import org.ksmt.utils.uncheckedCast

Expand Down Expand Up @@ -245,6 +269,25 @@ open class KContext {
private val uninterpretedSortCache = mkCache { name: String -> KUninterpretedSort(name, this) }
fun mkUninterpretedSort(name: String): KUninterpretedSort = uninterpretedSortCache.create(name)

// floating point
private val fpSortCache = mkCache { exponentBits: UInt, significandBits: UInt ->
when {
exponentBits == KFp16Sort.exponentBits && significandBits == KFp16Sort.significandBits -> KFp16Sort(this)
exponentBits == KFp32Sort.exponentBits && significandBits == KFp32Sort.significandBits -> KFp32Sort(this)
exponentBits == KFp64Sort.exponentBits && significandBits == KFp64Sort.significandBits -> KFp64Sort(this)
exponentBits == KFp128Sort.exponentBits && significandBits == KFp128Sort.significandBits -> KFp128Sort(this)
else -> KFpCustomSizeSort(this, exponentBits, significandBits)
}
}

fun mkFp16Sort(): KFp16Sort = fpSortCache.create(KFp16Sort.exponentBits, KFp16Sort.significandBits).cast()
fun mkFp32Sort(): KFp32Sort = fpSortCache.create(KFp32Sort.exponentBits, KFp32Sort.significandBits).cast()
fun mkFp64Sort(): KFp64Sort = fpSortCache.create(KFp64Sort.exponentBits, KFp64Sort.significandBits).cast()
fun mkFp128Sort(): KFp128Sort = fpSortCache.create(KFp128Sort.exponentBits, KFp128Sort.significandBits).cast()
fun mkFpSort(exponentBits: UInt, significandBits: UInt): KFpSort =
fpSortCache.create(exponentBits, significandBits)


// utils
val boolSort: KBoolSort
get() = mkBoolSort()
Expand Down Expand Up @@ -604,33 +647,53 @@ open class KContext {
private val bvCache = mkCache { value: String, sizeBits: UInt -> KBitVecCustomValue(this, value, sizeBits) }

fun mkBv(value: Boolean): KBitVec1Value = bv1Cache.create(value)
fun mkBv(value: Boolean, sizeBits: UInt): KBitVecValue<KBvSort> = mkBv((if (value) 1 else 0) as Number, sizeBits)
fun Boolean.toBv(): KBitVec1Value = mkBv(this)
fun Boolean.toBv(sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(this, sizeBits)


fun mkBv(value: Byte): KBitVec8Value = bv8Cache.create(value)
fun mkBv(value: Byte, sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(value as Number, sizeBits)
fun Byte.toBv(): KBitVec8Value = mkBv(this)
fun Byte.toBv(sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(this, sizeBits)
fun UByte.toBv(): KBitVec8Value = mkBv(toByte())

fun mkBv(value: Short): KBitVec16Value = bv16Cache.create(value)
fun mkBv(value: Short, sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(value as Number, sizeBits)
fun Short.toBv(): KBitVec16Value = mkBv(this)
fun Short.toBv(sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(this, sizeBits)
fun UShort.toBv(): KBitVec16Value = mkBv(toShort())

fun mkBv(value: Int): KBitVec32Value = bv32Cache.create(value)
fun mkBv(value: Int, sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(value as Number, sizeBits)
fun Int.toBv(): KBitVec32Value = mkBv(this)
fun Int.toBv(sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(this, sizeBits)
fun UInt.toBv(): KBitVec32Value = mkBv(toInt())

fun mkBv(value: Long): KBitVec64Value = bv64Cache.create(value)
fun mkBv(value: Long, sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(value as Number, sizeBits)
fun Long.toBv(): KBitVec64Value = mkBv(this)
fun Long.toBv(sizeBits: UInt): KBitVecValue<KBvSort> = mkBv(this, sizeBits)
fun ULong.toBv(): KBitVec64Value = mkBv(toLong())
fun mkBv(value: Number, sizeBits: UInt): KBitVecValue<KBvSort> {
val binaryString = value.toBinary()

require(binaryString.length <= sizeBits.toInt()) {
"Cannot create a bitvector of size $sizeBits from the given number $value" +
" since its binary representation requires at least ${binaryString.length} bits"
}

/**
* Constructs a bit vector from the given [value] containing of [sizeBits] bits.
*
* Note: if [sizeBits] is less than is required to represent the [value],
* the last [sizeBits] bits of the [value] will be taken.
*
* At the same time, if [sizeBits] is greater than it is required,
* binary representation of the [value] will be padded from the start with its sign bit.
*/
private fun mkBv(value: Number, sizeBits: UInt): KBitVecValue<KBvSort> {
val binaryString = value.toBinary().takeLast(sizeBits.toInt())
val paddedString = binaryString.padStart(sizeBits.toInt(), binaryString.first())

return mkBv(paddedString, sizeBits)
}

fun Number.toBv(sizeBits: UInt) = mkBv(this, sizeBits)
private fun Number.toBv(sizeBits: UInt) = mkBv(this, sizeBits)

fun mkBv(value: String, sizeBits: UInt): KBitVecValue<KBvSort> = when (sizeBits.toInt()) {
1 -> mkBv(value.toUInt(radix = 2).toInt() != 0).cast()
Byte.SIZE_BITS -> mkBv(value.toUByte(radix = 2).toByte()).cast()
Expand Down Expand Up @@ -981,6 +1044,129 @@ open class KContext {
fun <T : KBvSort> mkBvMulNoUnderflowExpr(arg0: KExpr<T>, arg1: KExpr<T>): KBvMulNoUnderflowExpr<T> =
bvMulNoUnderflowExprCache.create(arg0.cast(), arg1.cast()).cast()

// fp values
private val fp16Cache = mkCache { value: Float -> KFp16Value(this, value) }
private val fp32Cache = mkCache { value: Float -> KFp32Value(this, value) }
private val fp64Cache = mkCache { value: Double -> KFp64Value(this, value) }
private val fp128Cache = mkCache { significand: Long, exponent: Long, signBit: Boolean ->
KFp128Value(this, significand, exponent, signBit)
}
private val fpCustomSizeCache =
mkCache { significandSize: UInt, exponentSize: UInt, significand: Long, exponent: Long, signBit: Boolean ->
KFpCustomSizeValue(this, significandSize, exponentSize, significand, exponent, signBit)
}

/**
* Creates FP16 from the [value].
*
* Important: we suppose that [value] has biased exponent, but FP16 will be created from the unbiased one.
* So, at first, we'll subtract [KFp16Sort.exponentShiftSize] from the [value]'s exponent,
* take required for FP16 bits, and this will be **unbiased** FP16 exponent.
* The same is true for other methods but [mkFpCustomSize].
* */
fun mkFp16(value: Float): KFp16Value = fp16Cache.create(value)
fun mkFp32(value: Float): KFp32Value = fp32Cache.create(value)
fun mkFp64(value: Double): KFp64Value = fp64Cache.create(value)
fun mkFp128(significand: Long, exponent: Long, signBit: Boolean): KFp128Value =
fp128Cache.create(significand, exponent, signBit)

/**
* Creates FP with a custom size.
* Important: [exponent] here is an **unbiased** value.
*/
fun <T : KFpSort> mkFpCustomSize(
exponentSize: UInt,
significandSize: UInt,
exponent: Long,
significand: Long,
signBit: Boolean
): KFpValue<T> {
val intSignBit = if (signBit) 1 else 0

return when (mkFpSort(exponentSize, significandSize)) {
is KFp16Sort -> {
val number = constructFp16Number(exponent, significand, intSignBit)

mkFp16(number).cast()
}
is KFp32Sort -> {
val number = constructFp32Number(exponent, significand, intSignBit)

mkFp32(number).cast()
}
is KFp64Sort -> {
val number = constructFp64Number(exponent, significand, intSignBit)

mkFp64(number).cast()
}
is KFp128Sort -> mkFp128(significand, exponent, signBit).cast()
else -> fpCustomSizeCache.create(significandSize, exponentSize, significand, exponent, signBit).cast()
}
}

private fun constructFp16Number(exponent: Long, significand: Long, intSignBit: Int): Float {
// get sign and `body` of the unbiased exponent
val exponentSign = (exponent.toInt() shr 4) and 1
val otherExponent = exponent.toInt() and 0b1111

// get fp16 significand part -- last teb bits (eleventh stored implicitly)
val significandBits = significand.toInt() and 0b1111_1111_11

// Transform fp16 exponent into fp32 exponent adding three zeroes between the sign and the body
// Then add the bias for fp32 and apply the mask to avoid overflow of the eight bits
val biasedFloatExponent = (((exponentSign shl 8) or otherExponent) + KFp32Sort.exponentShiftSize) and 0xff

val bits = (intSignBit shl 31) or (biasedFloatExponent shl 23) or (significandBits shl 13)

return intBitsToFloat(bits)
}

private fun constructFp32Number(exponent: Long, significand: Long, intSignBit: Int): Float {
// `and 0xff` here is to avoid overloading when we have a number greater than 255,
// and the result of the addition will affect the sign bit
val biasedExponent = (exponent.toInt() + KFp32Sort.exponentShiftSize) and 0xff
val intValue = (intSignBit shl 31) or (biasedExponent shl 23) or significand.toInt()

return intBitsToFloat(intValue)
}

private fun constructFp64Number(exponent: Long, significand: Long, intSignBit: Int): Double {
// `and 0b111_1111_1111` here is to avoid overloading when we have a number greater than 255,
// and the result of the addition will affect the sign bit
val biasedExponent = (exponent + KFp64Sort.exponentShiftSize) and 0b111_1111_1111
val longValue = (intSignBit.toLong() shl 63) or (biasedExponent shl 52) or significand

return longBitsToDouble(longValue)
}

fun <T : KFpSort> mkFp(value: Float, sort: T): KExpr<T> {
val significand = value.extractSignificand(sort)
val exponent = value.extractExponent(sort, isBiased = false).extendWithZeros()
val sign = value.booleanSignBit

return mkFpCustomSize(sort.exponentBits, sort.significandBits, exponent, significand.extendWithZeros(), sign)
}

fun <T : KFpSort> mkFp(value: Double, sort: T): KExpr<T> {
val significand = value.extractSignificand(sort)
val exponent = value.extractExponent(sort, isBiased = false)
val sign = value.booleanSignBit

return mkFpCustomSize(sort.exponentBits, sort.significandBits, exponent, significand, sign)
}

fun <T : KFpSort> mkFp(significand: Int, exponent: Int, signBit: Boolean, sort: T): KExpr<T> =
mkFpCustomSize(
sort.exponentBits,
sort.significandBits,
exponent.extendWithZeros(),
significand.extendWithZeros(),
signBit
)

fun <T : KFpSort> mkFp(significand: Long, exponent: Long, signBit: Boolean, sort: T): KExpr<T> =
mkFpCustomSize(sort.exponentBits, sort.significandBits, exponent, significand, signBit)

// quantifiers
private val existentialQuantifierCache = mkCache { body: KExpr<KBoolSort>, bounds: List<KDecl<*>> ->
ensureContextMatch(body)
Expand Down Expand Up @@ -1492,6 +1678,67 @@ open class KContext {
fun <T : KBvSort> mkBvMulNoUnderflowDecl(arg0: T, arg1: T): KBvMulNoUnderflowDecl<T> =
bvMulNoUnderflowDeclCache.create(arg0, arg1).cast()

// FP
private val fp16DeclCache = mkCache { value: Float -> KFp16Decl(this, value) }
fun mkFp16Decl(value: Float): KFp16Decl = fp16DeclCache.create(value)

private val fp32DeclCache = mkCache { value: Float -> KFp32Decl(this, value) }
fun mkFp32Decl(value: Float): KFp32Decl = fp32DeclCache.create(value)

private val fp64DeclCache = mkCache { value: Double -> KFp64Decl(this, value) }
fun mkFp64Decl(value: Double): KFp64Decl = fp64DeclCache.create(value)

private val fp128DeclCache = mkCache { significand: Long, exponent: Long, signBit: Boolean ->
KFp128Decl(this, significand, exponent, signBit)
}

fun mkFp128Decl(significandBits: Long, exponent: Long, signBit: Boolean): KFp128Decl =
fp128DeclCache.create(significandBits, exponent, signBit)

private val fpCustomSizeDeclCache =
mkCache { significandSize: UInt, exponentSize: UInt, significand: Long, exponent: Long, signBit: Boolean ->
KFpCustomSizeDecl(this, significandSize, exponentSize, significand, exponent, signBit)
}

fun <T : KFpSort> mkFpCustomSizeDecl(
significandSize: UInt,
exponentSize: UInt,
significand: Long,
exponent: Long,
signBit: Boolean
): KFpDecl<T, *> {
val sort = mkFpSort(exponentSize, significandSize)

if (sort is KFpCustomSizeSort) {
return fpCustomSizeDeclCache.create(significandSize, exponentSize, significand, exponent, signBit).cast()
}

if (sort is KFp128Sort) {
return fp128DeclCache.create(significand, exponent, signBit).cast()
}

val intSignBit = if (signBit) 1 else 0

return when (sort) {
is KFp16Sort -> {
val fp16Number = constructFp16Number(exponent, significand, intSignBit)

mkFp16Decl(fp16Number).cast()
}
is KFp32Sort -> {
val fp32Number = constructFp32Number(exponent, significand, intSignBit)

mkFp32Decl(fp32Number).cast()
}
is KFp64Sort -> {
val fp64Number = constructFp64Number(exponent, significand, intSignBit)

mkFp64Decl(fp64Number).cast()
}
else -> error("Sort declaration for an unknown $sort")
}
}

/*
* KAst
* */
Expand Down
8 changes: 8 additions & 0 deletions ksmt-core/src/main/kotlin/org/ksmt/cache/Cache.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@ class Cache4<T, A0, A1, A2, A3>(val builder: (A0, A1, A2, A3) -> T) {
fun create(a0: A0, a1: A1, a2: A2, a3: A3): T = cache.getOrPut(listOf(a0, a1, a2, a3)) { builder(a0, a1, a2, a3) }
}

class Cache5<T, A0, A1, A2, A3, A4>(val builder: (A0, A1, A2, A3, A4) -> T) {
private val cache = HashMap<List<*>, T>()

@Suppress("unused")
fun create(a0: A0, a1: A1, a2: A2, a3: A3, a4: A4): T = cache.getOrPut(listOf(a0, a1, a2, a3, a4)) { builder(a0, a1, a2, a3, a4) }
}

fun <T> mkCache(builder: () -> T) = Cache0(builder)
fun <T, A0> mkCache(builder: (A0) -> T) = Cache1(builder)
fun <T, A0, A1> mkCache(builder: (A0, A1) -> T) = Cache2(builder)
fun <T, A0, A1, A2> mkCache(builder: (A0, A1, A2) -> T) = Cache3(builder)

@Suppress("unused")
fun <T, A0, A1, A2, A3> mkCache(builder: (A0, A1, A2, A3) -> T) = Cache4(builder)
fun <T, A0, A1, A2, A3, A4> mkCache(builder: (A0, A1, A2, A3, A4) -> T) = Cache5(builder)
Loading

0 comments on commit 56d85d3

Please sign in to comment.