From 4cb64c03b57408fe90a52d871bf6235cd00f79b5 Mon Sep 17 00:00:00 2001 From: yousifpatti Date: Mon, 9 Dec 2024 11:11:35 +1000 Subject: [PATCH] Added SASI and VS --- src/main/scala/analysis/Lattice.scala | 546 +++++++++++++++++++++++++- 1 file changed, 534 insertions(+), 12 deletions(-) diff --git a/src/main/scala/analysis/Lattice.scala b/src/main/scala/analysis/Lattice.scala index 0ef98020f..415bc6ddd 100644 --- a/src/main/scala/analysis/Lattice.scala +++ b/src/main/scala/analysis/Lattice.scala @@ -2,31 +2,554 @@ package analysis import ir._ import analysis.BitVectorEval +import math.pow import util.Logger /** Basic lattice - */ + */ trait Lattice[T]: type Element = T /** The bottom element of this lattice. - */ + */ val bottom: T /** The top element of this lattice. Default: not implemented. - */ + */ def top: T = ??? /** The least upper bound of `x` and `y`. - */ + */ def lub(x: T, y: T): T /** Returns true whenever `x` <= `y`. - */ + */ def leq(x: T, y: T): Boolean = lub(x, y) == y // rarely used, but easy to implement :-) +trait StridedWrappedInterval + +case class SI(s: BigInt, l: BigInt, u: BigInt, w: BigInt) extends StridedWrappedInterval { + if (l == u) { + require(s == 0) + } + + override def toString = s"SASI $s [$l, $u] $w" +} + +case object SIBottom extends StridedWrappedInterval { + override def toString = "SASIBot" +} + +// TOP is 1[0^w, 1^w]w +case object SITop extends StridedWrappedInterval { + override def toString = "SASITop" +} + +class SASILattice extends Lattice[StridedWrappedInterval] { + val lowestPossibleValue: BigInt = 0 + val highestPossibleValue: BigInt = Long.MaxValue - 1 + + override val bottom: StridedWrappedInterval = SIBottom + + override def top: StridedWrappedInterval = SITop + + // def gamma(x: StridedWrappedInterval): Set[BitVecLiteral] = x match { + // case SIBottom => Set.empty + // case SI(s, l, u, w) => + // if (s == BitVecLiteral(0, 64)) { // singleton set + // Set(l) + // } else { + // bitVec_interval(l, u, s) + // } + // } + + def isSingleValue(x: StridedWrappedInterval): Boolean = x match { + case SI(s, l, u, w) => s == 0 && l == u + case _ => false + } + + def modularPlus(a: BigInt, b: BigInt, w: BigInt): BigInt = { + (a + b) mod BigInt(2).pow(w.toInt) + } + + def modularMinus(a: BigInt, b: BigInt, w: BigInt): BigInt = { + (a - b) mod BigInt(2).pow(w.toInt) + } + + def modularLEQ(a: BigInt, b: BigInt, x: BigInt, w: BigInt): Boolean = { + modularMinus(a, x, w) <= modularMinus(b, x, w) + } + + def membershipFunction(v: BigInt, r: StridedWrappedInterval): Boolean = { + r match { + case SIBottom => false + case SITop => true + case SI(sr, lb, ub, w) => + modularLEQ(v, ub, lb, w) && (modularMinus(v, lb, w) mod sr) == 0 + } + } + + def cardinalityFunction(r: StridedWrappedInterval, w: BigInt): BigInt = { + r match { + case SIBottom => 0 + case SITop => BigInt(2).pow(w.toInt) + case SI(sr, lb, ub, w) => ((ub - lb + 1) / sr) // TODO: this may need to be a math.floor operation + } + } + + def orderingOperator(r: StridedWrappedInterval, t: StridedWrappedInterval): Boolean = { + if (r == SITop && t != SITop) { + false + } else if (r == SIBottom || t == SITop) { + true + } else { + (r, t) match { + case (SI(sr, a, b, w1), SI(st, c, d, w2)) => + + if ((a == c) && (b == d) && ((st == 0 && sr == 0) || (st != 0 && (sr mod st) == 0))) { // added check for zero division that is not in paper + return true + } + membershipFunction(a, t) && membershipFunction(b, t) && (!membershipFunction(c, r) || !membershipFunction(d, r)) && ((a - c) mod st) == 0 && (sr mod st) == 0 + case _ => false + } + } + } + + /** S1[L1, U1] join S2[L2, U2] -> gcd(S1, S2)[min(L1, L2), max(U1, U2)] */ + override def lub(r: StridedWrappedInterval, t: StridedWrappedInterval): StridedWrappedInterval = { + (r, t) match { + case (SIBottom, t) => t + case (t, SIBottom) => t + case (SITop, _) => SITop + case (_, SITop) => SITop + case (SI(sr, a, b, w1), SI(st, c, d, w2)) => + assert(w1 == w2) + val w = w1 // TODO: should this be the largest? + if (orderingOperator(r, t)) { + return t + } + if (orderingOperator(t, r)) { + return r + } + if (membershipFunction(a, t) && membershipFunction(b, t) && membershipFunction(c, r) && membershipFunction(d, r)) { + return SITop + } + if (membershipFunction(c, r) && membershipFunction(b, t) && !membershipFunction(a, t) && !membershipFunction(d, r)) { + return SI(sr.gcd(st).gcd(modularMinus(d, a, w)), a, d, w) + } + if (membershipFunction(a, t) && membershipFunction(d, r) && !membershipFunction(c, r) && !membershipFunction(b, t)) { + return SI(sr.gcd(st).gcd(modularMinus(b, c, w)), c, b, w) + } + val sad = SI(sr.gcd(st).gcd(modularMinus(d, a, w)), a, d, w) + val scb = SI(sr.gcd(st).gcd(modularMinus(b, c, w)), c, b, w) + if (!membershipFunction(a, t) && !membershipFunction(d, r) && !membershipFunction(c, r) && !membershipFunction(b, t) && cardinalityFunction(sad, w) <= cardinalityFunction(scb, w)) { + return sad + } + return scb + case _ => ??? + } + } + + def singletonSI(v: BigInt, w: BigInt): StridedWrappedInterval = { + SI(0, v, v, w) + } + + // def valuesToSI(x: List[BigInt], w: BigInt): StridedWrappedInterval = { + // if (x.isEmpty) { + // SIBottom + // } else { + // val l = x.min + // val u = x.max + // val initialStride = u - l + // val stride = x.foldLeft(initialStride) { + // case (acc, v) => acc.gcd(v - l) + // } + // SI(stride, l, u, w) + // } + // } + + /** + * Convert a set of values to a strided interval. Assumes the widths are the same. + * @param x the set of values + * @param w the width of each value + * @return the strided interval representing the values in the set + */ + def valuesToSI(x: Set[BigInt], w: BigInt): StridedWrappedInterval = { + if (x.isEmpty) { + SIBottom + } else { + // create singleton intervals for each value and then join them + x.foldLeft(bottom) { + case (acc, v) => lub(acc, singletonSI(v, w)) + } + } + } + + /** + * s + t = + * BOT if s = BOT or t = BOT + * gcd(s, t)(|a +w c, b +w d|) if s = (|a, b|), t = (|c, d|) and #s + #t <= 2^w + * @param s + * @param t + * @return + */ + def add(s: StridedWrappedInterval, t: StridedWrappedInterval): StridedWrappedInterval = { + (s, t) match { + case (SIBottom, _) => SIBottom // TODO: is this correct? + case (_, SIBottom) => SIBottom // TODO: is this correct? + case (SI(ss, a, b, w1), SI(st, c, d, w2)) if (cardinalityFunction(s, w1) + cardinalityFunction(t, w2)) <= BigInt(2).pow(w1.toInt) => + assert(w1 == w2) + return SI(ss.gcd(st), modularPlus(a, c, w1), modularPlus(b, d, w1), w1) + case _ => SITop + } + } + + def add(s: StridedWrappedInterval, t: BigInt, w: BigInt): StridedWrappedInterval = { + (s, t) match { + case (SIBottom, _) => SIBottom // TODO: is this correct? + case (SI(ss, a, b, w1), t) => + return add(s, singletonSI(t, w)) + case _ => SITop + } + } + + def sub(s: StridedWrappedInterval, t: StridedWrappedInterval): StridedWrappedInterval = { + (s, t) match { + case (SIBottom, _) => SIBottom // TODO: is this correct? + case (_, SIBottom) => SIBottom // TODO: is this correct? + case (SI(ss, a, b, w1), SI(st, c, d, w2)) if (cardinalityFunction(s, w1) + cardinalityFunction(t, w2)) <= BigInt(2).pow(w1.toInt) => + assert(w1 == w2) + return SI(ss.gcd(st), modularMinus(a, d, w1), modularMinus(b, c, w1), w1) + case _ => SITop + } + } + + def sub(s: StridedWrappedInterval, t: BigInt, w: BigInt): StridedWrappedInterval = { + (s, t) match { + case (SIBottom, _) => SIBottom // TODO: is this correct? + case (SI(ss, a, b, w1), t) => + return sub(s, singletonSI(t, w)) + case _ => SITop + } + } +} + +sealed trait ValueSet[T] + +case class VS[T](m: Map[T, StridedWrappedInterval]) extends ValueSet[T] { // TODO: default value in map must be assumed to be SIBottom + override def toString: String = m.toString +} + +/** The lattice of integers with the standard ordering. + */ +class ValueSetLattice[T] extends Lattice[ValueSet[T]] { + + case object VSBottom extends ValueSet[T] { + override def toString = "VSBot" + } + + case object VSTop extends ValueSet[T] { + override def toString = "VSTop" + } + + override val bottom: ValueSet[T] = VSBottom + + override def top: ValueSet[T] = VSTop + + val lattice: SASILattice = SASILattice() + + override def lub(x: ValueSet[T], y: ValueSet[T]): ValueSet[T] = { + (x, y) match { + case (VSBottom, t) => t + case (t, VSBottom) => t + case (VSTop, _) => VSTop + case (_, VSTop) => VSTop + case (VS(m1), VS(m2)) => + VS(m1.keys.foldLeft(m2) { + case (acc, k) => + val v1 = m1(k) + val v2 = m2(k) + acc + (k -> lattice.lub(v1, v2)) + }) + } + } + + // def meet(x: ValueSet[String], y: ValueSet[String]): ValueSet[String] = { + // (x, y) match { + // case (VSBottom, t) => VSBottom + // case (t, VSBottom) => VSBottom + // case (VSTop, _) => y + // case (_, VSTop) => x + // case (VS(m1), VS(m2)) => + // VS(m1.keys.foldLeft(m2) { + // case (acc, k) => + // val v1 = m1(k) + // val v2 = m2(k) + // acc + (k -> lattice.meet(v1, v2)) + // }) + // } + // } + + def applyOp(op: BinOp, lhs: ValueSet[T], rhs: Either[ValueSet[T], BitVecLiteral]): ValueSet[T] = { + op match + case bvOp: BVBinOp => + bvOp match + case BVAND => ??? + case BVOR => ??? + case BVADD => rhs match + case Left(vs) => add(lhs, vs) + case Right(bitVecLiteral) => add(lhs, bitVecLiteral) + case BVMUL => ??? + case BVUDIV => ??? + case BVUREM => ??? + case BVSHL => ??? + case BVLSHR => ??? + case BVULT => ??? + case BVNAND => ??? + case BVNOR => ??? + case BVXOR => ??? + case BVXNOR => ??? + case BVCOMP => ??? + case BVSUB => rhs match + case Left(vs) => sub(lhs, vs) + case Right(bitVecLiteral) => sub(lhs, bitVecLiteral) + case BVSDIV => ??? + case BVSREM => ??? + case BVSMOD => ??? + case BVASHR => ??? + case BVULE => ??? + case BVUGT => ??? + case BVUGE => ??? + case BVSLT => ??? + case BVSLE => ??? + case BVSGT => ??? + case BVSGE => ??? + case BVEQ => ??? + case BVNEQ => ??? + case BVCONCAT => ??? + case boolOp: BoolBinOp => + boolOp match + case BoolEQ => applyOp(BVEQ, lhs, rhs) + case BoolNEQ => applyOp(BVNEQ, lhs, rhs) + case BoolAND => applyOp(BVAND, lhs, rhs) + case BoolOR => applyOp(BVOR, lhs, rhs) + case BoolIMPLIES => ??? + case BoolEQUIV => ??? + case intOp: IntBinOp => + applyOp(intOp.toBV, lhs, rhs) + case _ => ??? + } + + def applyOp(op: UnOp, rhs: ValueSet[T]): ValueSet[T] = { + op match + case bvOp: BVUnOp => + bvOp match + case BVNOT => ??? + case BVNEG => ??? + case boolOp: BoolUnOp => + boolOp match + case BoolNOT => ??? + case intOp: IntUnOp => + applyOp(intOp.toBV, rhs) + case _ => ??? + } + + def add(x: ValueSet[T], y: ValueSet[T]): ValueSet[T] = { + (x, y) match { + case (VSBottom, t) => t + case (t, VSBottom) => t + case (VSTop, _) => VSTop + case (_, VSTop) => VSTop + case (VS(m1), VS(m2)) => + VS(m1.keys.foldLeft(m2) { + case (acc, k) => + val v1 = m1(k) + val v2 = m2(k) + acc + (k -> lattice.add(v1, v2)) + }) + } + } + + def add(x: ValueSet[T], y: BitVecLiteral): ValueSet[T] = { + x match { + case VSBottom => VSBottom + case VSTop => VSTop + case VS(m) => + VS(m.map { + case (k, s) => k -> lattice.add(s, y.value, y.size) // TODO: is the size correct here? + }) + } + } + + def sub(x: ValueSet[T], y: ValueSet[T]): ValueSet[T] = { + (x, y) match { + case (VSTop, _) => VSTop // TODO: is this correct? + case (_, VSTop) => VSTop // TODO: is this correct? + case (VSBottom, t) => VSBottom + case (t, VSBottom) => t + case (VS(m1), VS(m2)) => + VS(m1.keys.foldLeft(m2) { + case (acc, k) => + val v1 = m1(k) + val v2 = m2(k) + acc + (k -> lattice.sub(v1, v2)) + }) + } + } + + def sub(x: ValueSet[T], y: BitVecLiteral): ValueSet[T] = { + x match { + case VSTop => VSTop + case VSBottom => VSBottom + case VS(m) => + VS(m.map { + case (k, s) => k -> lattice.sub(s, y.value, y.size) // TODO: is the size correct here? + }) + } + } + + // def widen(vs1: ValueSet[T], vs2: ValueSet[T]): ValueSet[T] = { + // (vs1, vs2) match { + // case (VSBottom, t) => ??? + // case (t, VSBottom) => ??? + // case (VSTop, _) => VSTop + // case (_, VSTop) => VSTop + // case (VS(m1), VS(m2)) => + // VS(m1.keys.foldLeft(m2) { + // case (acc, k) => + // val v1 = m1(k) + // val v2 = m2(k) + // acc + (k -> lattice.widen(v1, v2)) + // }) + // } + // } + + def removeLowerBounds(vs: ValueSet[T]): ValueSet[T] = { + vs match { + case VSBottom => VSBottom + case VSTop => VSTop + case VS(m) => + VS(m.map { + case (k, SI(s, l, u, w)) => k -> SI(s, lattice.lowestPossibleValue, u, w) + }) + } + } + + def removeUpperBound(vs: ValueSet[T]): ValueSet[T] = { + vs match { + case VSBottom => VSBottom + case VSTop => VSTop + case VS(m) => + VS(m.map { + case (k, SI(s, l, u, w)) => k -> SI(s, l, lattice.highestPossibleValue, w) + }) + } + } +} + +trait Bool3 + +case object BOTTOM_BOOL3 extends Bool3 { + override def toString = "BOTTOM" +} + +case object FALSE_BOOL3 extends Bool3 { + override def toString = "FALSE" +} + +case object TURE_BOOL3 extends Bool3 { + override def toString = "TRUE" +} + +case object MAYBE_BOOL3 extends Bool3 { + override def toString = "MAYBE" +} + +/** The lattice of booleans with the standard ordering. + */ +class Bool3Lattice extends Lattice[Bool3] { + + override val bottom: Bool3 = BOTTOM_BOOL3 + + override def top: Bool3 = MAYBE_BOOL3 + + override def lub(x: Bool3, y: Bool3): Bool3 = { + (x, y) match { + case (BOTTOM_BOOL3, t) => t + case (t, BOTTOM_BOOL3) => t + case (TURE_BOOL3, FALSE_BOOL3) => MAYBE_BOOL3 + case (FALSE_BOOL3, TURE_BOOL3) => MAYBE_BOOL3 + case _ => x + } + } +} + +enum Flags { + case CF // Carry Flag + case ZF // Zero Flag + case SF // Sign Flag + case PF // Parity Flag + case AF // Auxiliary Flag + case OF // Overflow Flag +} + +/** + * case CF // Carry Flag + * case ZF // Zero Flag + * case SF // Sign Flag + * case PF // Parity Flag + * case AF // Auxiliary Flag + * case OF // Overflow Flag + */ +trait Flag + +case object BOTTOM_Flag extends Flag { + override def toString = "BOTTOM_FLAG" +} + +case class FlagMap(m: Map[Flags, Bool3]) extends Flag { + override def toString: String = m.toString +} + + +/** The lattice of booleans with the standard ordering. + */ +class FlagLattice extends Lattice[Flag] { + + override val bottom: Flag = BOTTOM_Flag + + override def top: Flag = FlagMap(Map( + Flags.CF -> MAYBE_BOOL3, + Flags.ZF -> MAYBE_BOOL3, + Flags.SF -> MAYBE_BOOL3, + Flags.PF -> MAYBE_BOOL3, + Flags.AF -> MAYBE_BOOL3, + Flags.OF -> MAYBE_BOOL3 + )) + + val lattice: Bool3Lattice = Bool3Lattice() + + override def lub(x: Flag, y: Flag): Flag = { + (x, y) match { + case (BOTTOM_Flag, t) => t + case (t, BOTTOM_Flag) => t + case (FlagMap(m1), FlagMap(m2)) => + FlagMap(m1.keys.foldLeft(m2) { + case (acc, k) => + val v1 = m1(k) + val v2 = m2(k) + acc + (k -> lattice.lub(v1, v2)) + }) + } + } + + def setFlag(flag: Flags, value: Bool3): Flag = { + FlagMap(Map(flag -> value)) + } +} + /** The powerset lattice of a set of elements of type `A` with subset ordering. - */ + */ class PowersetLattice[A] extends Lattice[Set[A]] { val bottom: Set[A] = Set.empty def lub(x: Set[A], y: Set[A]): Set[A] = x.union(y) @@ -104,8 +627,8 @@ case object Top extends FlatElement[Nothing] case object Bottom extends FlatElement[Nothing] /** The flat lattice made of element of `X`. Top is greater than every other element, and Bottom is less than every - * other element. No additional ordering is defined. - */ + * other element. No additional ordering is defined. + */ class FlatLattice[X] extends Lattice[FlatElement[X]] { val bottom: FlatElement[X] = Bottom @@ -140,9 +663,8 @@ class TupleLattice[L1 <: Lattice[T1], L2 <: Lattice[T2], T1, T2](val lattice1: L override def top: (T1, T2) = (lattice1.top, lattice2.top) } - /** A lattice of maps from a set of elements of type `A` to a lattice with element `L'. Bottom is the default value. - */ + */ class MapLattice[A, T, +L <: Lattice[T]](val sublattice: L) extends Lattice[Map[A, T]] { val bottom: Map[A, T] = Map().withDefaultValue(sublattice.bottom) def lub(x: Map[A, T], y: Map[A, T]): Map[A, T] = @@ -150,8 +672,8 @@ class MapLattice[A, T, +L <: Lattice[T]](val sublattice: L) extends Lattice[Map[ } /** Constant propagation lattice. - * - */ + * + */ class ConstantPropagationLattice extends FlatLattice[BitVecLiteral] { private def apply(op: (BitVecLiteral, BitVecLiteral) => BitVecLiteral, a: FlatElement[BitVecLiteral], b: FlatElement[BitVecLiteral]): FlatElement[BitVecLiteral] = try { (a, b) match