Skip to content

Commit

Permalink
Initial fp support
Browse files Browse the repository at this point in the history
  • Loading branch information
CaelmBleidd committed Sep 9, 2022
1 parent 45ce2b5 commit ec872b4
Show file tree
Hide file tree
Showing 17 changed files with 1,140 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ import org.ksmt.sort.KBv32Sort
import org.ksmt.sort.KBv64Sort
import org.ksmt.sort.KBv8Sort
import org.ksmt.sort.KBvSort
import org.ksmt.sort.KFpSort
import org.ksmt.sort.KIntSort
import org.ksmt.sort.KRealSort
import org.ksmt.sort.KSort
Expand Down Expand Up @@ -587,6 +588,9 @@ open class KBitwuzlaExprInternalizer(

override fun visit(sort: KUninterpretedSort): BitwuzlaSort =
throw KSolverUnsupportedFeatureException("Unsupported sort $sort")

override fun <S : KFpSort> visit(sort: S): BitwuzlaSort =
TODO("We do not support KFP sort yet")
}

open class FunctionSortInternalizer(
Expand Down
281 changes: 273 additions & 8 deletions ksmt-core/src/main/kotlin/org/ksmt/KContext.kt

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions ksmt-core/src/main/kotlin/org/ksmt/cache/Cache.kt
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,18 @@ class Cache4<T, A0, A1, A2, A3>(val builder: (A0, A1, A2, A3) -> T) : AutoClosea
}
}

class Cache5<T, A0, A1, A2, A3, A4>(val builder: (A0, A1, A2, A3, A4) -> T) {
private val cache = HashMap<List<*>, T>()

@Suppress("unused")
fun create(a0: A0, a1: A1, a2: A2, a3: A3, a4: A4): T = cache.getOrPut(listOf(a0, a1, a2, a3, a4)) { builder(a0, a1, a2, a3, a4) }
}

fun <T> mkCache(builder: () -> T) = Cache0(builder)
fun <T, A0> mkCache(builder: (A0) -> T) = Cache1(builder)
fun <T, A0, A1> mkCache(builder: (A0, A1) -> T) = Cache2(builder)
fun <T, A0, A1, A2> mkCache(builder: (A0, A1, A2) -> T) = Cache3(builder)

@Suppress("unused")
fun <T, A0, A1, A2, A3> mkCache(builder: (A0, A1, A2, A3) -> T) = Cache4(builder)
fun <T, A0, A1, A2, A3, A4> mkCache(builder: (A0, A1, A2, A3, A4) -> T) = Cache5(builder)
116 changes: 116 additions & 0 deletions ksmt-core/src/main/kotlin/org/ksmt/decl/KFloatingPointDecl.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package org.ksmt.decl

import org.ksmt.KContext
import org.ksmt.expr.KApp
import org.ksmt.expr.KExpr
import org.ksmt.sort.KFp128Sort
import org.ksmt.sort.KFp16Sort
import org.ksmt.sort.KFp32Sort
import org.ksmt.sort.KFp64Sort
import org.ksmt.sort.KFpSort
import org.ksmt.utils.getHalfPrecisionExponent
import org.ksmt.utils.booleanSignBit
import org.ksmt.utils.getExponent
import org.ksmt.utils.halfPrecisionSignificand
import org.ksmt.utils.significand
import org.ksmt.utils.toBinary

abstract class KFpDecl<T : KFpSort, N : Number> internal constructor(
ctx: KContext,
sort: T,
val sign: Boolean,
val significand: N,
val exponent: N
) : KConstDecl<T>(
ctx,
constructNameForDeclaration(sign, sort, exponent, significand),
sort
)

private fun <N : Number, T : KFpSort> constructNameForDeclaration(
sign: Boolean,
sort: T,
exponent: N,
significand: N
): String {
val exponentBits = sort.exponentBits
val binaryExponent = exponent.toBinary().takeLast(exponentBits.toInt())
val significandBits = sort.significandBits
val binarySignificand = significand
.toBinary()
.takeLast(significandBits.toInt() - 1)
.let { it.padStart(significandBits.toInt() - 1, it[0]) }

return "FP (sign $sign) ($exponentBits $binaryExponent) ($significandBits $binarySignificand)"
}

class KFp16Decl internal constructor(ctx: KContext, val value: Float) :
KFpDecl<KFp16Sort, Int>(
ctx,
ctx.mkFp16Sort(),
value.booleanSignBit,
value.halfPrecisionSignificand,
value.getHalfPrecisionExponent(isBiased = false)
) {
override fun apply(args: List<KExpr<*>>): KApp<KFp16Sort, *> = ctx.mkFp16(value)

override fun <R> accept(visitor: KDeclVisitor<R>): R = visitor.visit(this)
}

class KFp32Decl internal constructor(ctx: KContext, val value: Float) :
KFpDecl<KFp32Sort, Int>(
ctx,
ctx.mkFp32Sort(),
value.booleanSignBit,
value.significand,
value.getExponent(isBiased = false)
) {
override fun apply(args: List<KExpr<*>>): KApp<KFp32Sort, *> = ctx.mkFp32(value)

override fun <R> accept(visitor: KDeclVisitor<R>): R = visitor.visit(this)
}

class KFp64Decl internal constructor(ctx: KContext, val value: Double) :
KFpDecl<KFp64Sort, Long>(
ctx,
ctx.mkFp64Sort(),
value.booleanSignBit,
value.significand,
value.getExponent(isBiased = false)
) {
override fun apply(args: List<KExpr<*>>): KApp<KFp64Sort, *> = ctx.mkFp64(value)

override fun <R> accept(visitor: KDeclVisitor<R>): R = visitor.visit(this)
}

// TODO replace significand with bit vector and change KFpDecl accordingly
class KFp128Decl internal constructor(
ctx: KContext,
significand: Long,
exponent: Long,
signBit: Boolean
) : KFpDecl<KFp128Sort, Long>(ctx, ctx.mkFp128Sort(), signBit, significand, exponent) {
override fun apply(args: List<KExpr<*>>): KApp<KFp128Sort, *> = ctx.mkFp128(significand, exponent, sign)

override fun <R> accept(visitor: KDeclVisitor<R>): R = visitor.visit(this)
}

class KFpCustomSizeDecl internal constructor(
ctx: KContext,
significandSize: UInt,
exponentSize: UInt,
significand: Long,
exponent: Long,
signBit: Boolean
) : KFpDecl<KFpSort, Long>(ctx, ctx.mkFpSort(exponentSize, significandSize), signBit, significand, exponent) {
override fun apply(args: List<KExpr<*>>): KApp<KFpSort, *> =
ctx.mkFpCustomSize(
sort.exponentBits,
sort.significandBits,
exponent,
significand,
sign
)

override fun <R> accept(visitor: KDeclVisitor<R>): R = visitor.visit(this)
}
11 changes: 10 additions & 1 deletion ksmt-core/src/main/kotlin/org/ksmt/expr/KBitVecExprs.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,21 @@ import org.ksmt.sort.KBv64Sort
import org.ksmt.sort.KBv8Sort
import org.ksmt.sort.KBvSort
import org.ksmt.sort.KIntSort
import org.ksmt.utils.toBinary

abstract class KBitVecValue<S : KBvSort>(
ctx: KContext
) : KApp<S, KExpr<*>>(ctx) {
override val args: List<KExpr<*>> = emptyList()

abstract val stringValue: String
}

class KBitVec1Value internal constructor(ctx: KContext, val value: Boolean) : KBitVecValue<KBv1Sort>(ctx) {
override fun accept(transformer: KTransformer): KExpr<KBv1Sort> = transformer.transform(this)

override val stringValue: String = if (value) "1" else "0"

override fun decl(): KDecl<KBv1Sort> = ctx.mkBvDecl(value)

override fun sort(): KBv1Sort = ctx.mkBv1Sort()
Expand All @@ -29,7 +34,9 @@ class KBitVec1Value internal constructor(ctx: KContext, val value: Boolean) : KB
abstract class KBitVecNumberValue<S : KBvSort, N : Number>(
ctx: KContext,
val numberValue: N
) : KBitVecValue<S>(ctx)
) : KBitVecValue<S>(ctx) {
override val stringValue: String = numberValue.toBinary()
}

class KBitVec8Value internal constructor(ctx: KContext, byteValue: Byte) :
KBitVecNumberValue<KBv8Sort, Byte>(ctx, byteValue) {
Expand Down Expand Up @@ -81,6 +88,8 @@ class KBitVecCustomValue internal constructor(

override fun accept(transformer: KTransformer): KExpr<KBvSort> = transformer.transform(this)

override val stringValue: String = binaryStringValue

override fun decl(): KDecl<KBvSort> = ctx.mkBvDecl(binaryStringValue, sizeBits)

override fun sort(): KBvSort = ctx.mkBvSort(sizeBits)
Expand Down
138 changes: 138 additions & 0 deletions ksmt-core/src/main/kotlin/org/ksmt/expr/KFloatingPointExpr.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package org.ksmt.expr

import org.ksmt.KContext
import org.ksmt.decl.KDecl
import org.ksmt.sort.KBvSort
import org.ksmt.sort.KFp128Sort
import org.ksmt.sort.KFp16Sort
import org.ksmt.sort.KFp32Sort
import org.ksmt.sort.KFp64Sort
import org.ksmt.sort.KFpSort
import org.ksmt.utils.booleanSignBit
import org.ksmt.utils.getExponent
import org.ksmt.utils.getHalfPrecisionExponent
import org.ksmt.utils.halfPrecisionSignificand
import org.ksmt.utils.significand

abstract class KFpValue<T : KFpSort>(
ctx: KContext,
val significand: KBitVecValue<out KBvSort>,
val exponent: KBitVecValue<out KBvSort>,
val signBit: Boolean
) : KApp<T, KExpr<*>>(ctx) {
override val args: List<KExpr<*>> = emptyList()
}

/**
* Fp16 value. Note that [value] should has biased Fp32 exponent,
* but a constructed Fp16 will have an unbiased one.
*
* Fp32 to Fp16 transformation:
* sign exponent significand
* 0 00000000 00000000000000000000000 (1 8 23)
* x x___xxxx xxxxxxxxxx_____________ (1 5 10)
*/
class KFp16Value internal constructor(ctx: KContext, val value: Float) :
KFpValue<KFp16Sort>(
ctx,
significand = with(ctx) { value.halfPrecisionSignificand.toBv(KFp16Sort.significandBits - 1u) },
exponent = with(ctx) { value.getHalfPrecisionExponent(isBiased = false).toBv(KFp16Sort.exponentBits) },
signBit = value.booleanSignBit
) {

init {
// TODO add checks for the bounds
}

override fun decl(): KDecl<KFp16Sort> = ctx.mkFp16Decl(value)

override fun sort(): KFp16Sort = ctx.mkFp16Sort()

override fun accept(transformer: KTransformer): KExpr<KFp16Sort> = transformer.transform(this)
}

class KFp32Value internal constructor(ctx: KContext, val value: Float) :
KFpValue<KFp32Sort>(
ctx,
significand = with(ctx) { value.significand.toBv(KFp32Sort.significandBits - 1u) },
exponent = with(ctx) { value.getExponent(isBiased = false).toBv(KFp32Sort.exponentBits) },
signBit = value.booleanSignBit
) {
override fun decl(): KDecl<KFp32Sort> = ctx.mkFp32Decl(value)

override fun sort(): KFp32Sort = ctx.mkFp32Sort()

override fun accept(transformer: KTransformer): KExpr<KFp32Sort> = transformer.transform(this)
}

class KFp64Value internal constructor(ctx: KContext, val value: Double) :
KFpValue<KFp64Sort>(
ctx,
significand = with(ctx) { value.significand.toBv(KFp64Sort.significandBits - 1u) },
exponent = with(ctx) { value.getExponent(isBiased = false).toBv(KFp64Sort.exponentBits) },
signBit = value.booleanSignBit
) {
override fun decl(): KDecl<KFp64Sort> = ctx.mkFp64Decl(value)

override fun sort(): KFp64Sort = ctx.mkFp64Sort()

override fun accept(transformer: KTransformer): KExpr<KFp64Sort> = transformer.transform(this)
}

/**
* KFp128 value.
*
* Note: if [exponentValue] contains more than [KFp128Sort.exponentBits] meaningful bits,
* only the last [KFp128Sort.exponentBits] of then will be taken.
*/
class KFp128Value internal constructor(
ctx: KContext,
val significandValue: Long,
val exponentValue: Long,
signBit: Boolean
) : KFpValue<KFp128Sort>(
ctx,
significand = with(ctx) { significandValue.toBv(KFp128Sort.significandBits - 1u) },
exponent = with(ctx) { exponentValue.toBv(KFp128Sort.exponentBits) },
signBit
) {
override fun decl(): KDecl<KFp128Sort> = ctx.mkFp128Decl(significandValue, exponentValue, signBit)

override fun sort(): KFp128Sort = ctx.mkFp128Sort()

override fun accept(transformer: KTransformer): KExpr<KFp128Sort> = transformer.transform(this)
}

/**
* KFp value of custom size.
*
* Note: if [exponentValue] contains more than [KFp128Sort.exponentBits] meaningful bits,
* only the last [KFp128Sort.exponentBits] of then will be taken.
* The same is true for the significand.
*/
class KFpCustomSizeValue internal constructor(
ctx: KContext,
val significandSize: UInt,
val exponentSize: UInt,
val significandValue: Long,
val exponentValue: Long,
signBit: Boolean
) : KFpValue<KFpSort>(
ctx,
significand = with(ctx) { significandValue.toBv(significandSize - 1u) },
exponent = with(ctx) { exponentValue.toBv(exponentSize) },
signBit
) {
init {
require(exponentSize.toInt() <= 63) { "Maximum number of exponent bits is 63" }
}

override fun decl(): KDecl<KFpSort> =
ctx.mkFpCustomSizeDecl(significandSize, exponentSize, significandValue, exponentValue, signBit)

override fun sort(): KFpSort = ctx.mkFpSort(exponentSize, significandSize)

override fun accept(transformer: KTransformer): KExpr<KFpSort> = transformer.transform(this)
}


Loading

0 comments on commit ec872b4

Please sign in to comment.