Skip to content

Commit

Permalink
robustness and perf fixes (monad overflow)
Browse files Browse the repository at this point in the history
  • Loading branch information
ailrst committed Dec 5, 2024
1 parent 8dd5f0e commit 21efc16
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/ir/eval/ExprEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ trait Loader[S, E] {
}

def evaluateExpr(exp: Expr): Option[Literal] = {
val (e, _) = SimpExpr(fastPartialEvalExpr)(exp)
val (e, _) = simpFixedPoint(SimpExpr(fastPartialEvalExpr).apply)(exp)
e match {
case l: Literal => Some(l)
case _ => None
Expand Down
173 changes: 113 additions & 60 deletions src/main/scala/ir/eval/InterpretBasilIR.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package ir.eval
import ir.transforms.Substitute
import ir._
import ir.eval.BitVectorEval.*
import ir.*
Expand Down Expand Up @@ -63,6 +64,24 @@ case object Eval {
/* Eval functions */
/*--------------------------------------------------------------------------------*/

// def evalExpr[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Expr, InterpreterError] = {
// val ldr = StVarLoader[S, T](f)

// val varbs = e.variables.toList
// val vars = varbs.map(v => ldr.getVariable(v))

// for {
// vs <- State.mapM(ldr.getVariable, varbs)
// vvs = varbs.zip(vs).filter(_._2.isDefined).map(l => (l._1, l._2.get)).toMap
// substed = Substitute(vvs.get)(e).flatMap(evaluateExpr)
// res <- substed match {
// case Some(r) => State.pure(r)
// case None => State.pure(e)
// }
// } yield (res)
// }
//

def evalExpr[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Expr, InterpreterError] = {
val ldr = StVarLoader[S, T](f)
for {
Expand Down Expand Up @@ -176,30 +195,39 @@ case object Eval {
addr: BasilValue,
values: List[BasilValue],
endian: Endian
): State[S, Unit, InterpreterError] = for {
mem <- f.loadVar(vname)
x <- mem match {
case m @ BasilMapValue(_, MapType(kt, vt))
if Some(kt) == addr.irType && values.forall(v => v.irType == Some(vt)) =>
State.pure((m, kt, vt))
case v => State.setError((TypeError(s"Invalid map store operation to $vname : $v")))
}
(mapval, keytype, valtype) = x
keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until values.size))
vals = endian match {
case Endian.LittleEndian => values.reverse
case Endian.BigEndian => values
}
x <- f.storeMem(vname, keys.zip(vals).toMap)
} yield (x)
)(implicit
line: sourcecode.Line,
file: sourcecode.FileName,
name: sourcecode.Name
): State[S, Unit, InterpreterError] =
monlog.debug(s"store ${vname} ${values.size} bytes ${file.value}:${line.value}")
for {
mem <- f.loadVar(vname)
x <- mem match {
case m @ BasilMapValue(_, MapType(kt, vt))
if Some(kt) == addr.irType && values.forall(v => v.irType == Some(vt)) =>
State.pure((m, kt, vt))
case v => State.setError((TypeError(s"Invalid map store operation to $vname : $v")))
}
(mapval, keytype, valtype) = x
keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until values.size))
vals = endian match {
case Endian.LittleEndian => values.reverse
case Endian.BigEndian => values
}
x <- f.storeMem(vname, keys.zip(vals).toMap)
} yield (x)

/** Extract bitvec to bytes and store bytes */
def storeBV[S, T <: Effects[S, InterpreterError]](f: T)(
vname: String,
addr: BasilValue,
value: BitVecLiteral,
endian: Endian
): State[S, Unit, InterpreterError] = for {
)(implicit line: sourcecode.Line, file: sourcecode.FileName, name: sourcecode.Name): State[S, Unit, InterpreterError] =

monlog.debug(s"storeBV ${vname} ${size(value).get / 8} bytes")(line, file, name)
for {
mem <- f.loadVar(vname)
mr <- mem match {
case m @ BasilMapValue(_, MapType(kt, BitVecType(size))) if Some(kt) == addr.irType => State.pure((m, size))
Expand Down Expand Up @@ -395,11 +423,11 @@ class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S
dc.actualParams
)
_ <- {
if (dc.target.entryBlock.isDefined) {
if (LibcIntrinsic.intrinsics.contains(dc.target.procName)) {
f.call(dc.target.name, Intrinsic(dc.target.procName), ReturnTo(dc))
} else if (dc.target.entryBlock.isDefined) {
val block = dc.target.entryBlock.get
f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), ReturnTo(dc))
} else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) {
f.call(dc.target.name, Intrinsic(dc.target.name), ReturnTo(dc))
} else {
State.setError(EscapedControlFlow(dc))
}
Expand Down Expand Up @@ -519,75 +547,100 @@ object InterpFuns {
l <- s.storeVar("R30_in", Scope.Global, Scalar(LR))
l <- s.storeVar("R0", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(0, 64)))
/** callee saved **/
l <- s.storeVar("R19", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R20", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R21", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R22", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R23", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R24", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R25", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R26", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R27", Scope.Global, Scalar(BitVecLiteral(0, 64)))
l <- s.storeVar("R28", Scope.Global, Scalar(BitVecLiteral(0, 64)))
/** end callee saved **/
_ <- s.storeVar("ghost-funtable", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(64))))
_ <- IntrinsicImpl.initFileGhostRegions(s)
} yield (l)
}

def initialiseProgram[S, T <: Effects[S, InterpreterError]](f: T)(p: Program): State[S, Unit, InterpreterError] = {
def initMemory(mem: String, mems: Iterable[MemorySection]) = {
for {
m <- State.sequence(
State.pure(()),
mems
.filter(m => m.address != 0 && m.bytes.size != 0)
.map(memory =>
Eval.store(f)(
mem,
Scalar(BitVecLiteral(memory.address, 64)),
memory.bytes.toList.map(Scalar(_)),
Endian.BigEndian
)
def initialiseProgram[S, T <: Effects[S, InterpreterError]](
f: T
)(is: S, p: Program): S = {

def initMemory(is: S, mem: String, mems: Iterable[MemorySection]) : S = {
var s = is
for (memory <- mems.filter(m => m.address != 0 && m.bytes.size != 0)) {
val bytes = memory.bytes.toList.map(Scalar(_))
val addrs = memory.address until memory.address + bytes.size
for (store <- addrs.zip(bytes)) {
val (addr,value) = store
s = State.execute(
s,
Eval.store(f)(
mem,
Scalar(BitVecLiteral(addr, 64)),
List(value),
Endian.BigEndian
)
)
} yield ()
)
}
}
s
}

for {
d <- initialState(f)
funs <- State.sequence(
State.pure(Logger.debug("INITIALISE FUNCTION ADDRESSES")),
p.procedures
.filter(p => p.blocks.nonEmpty && p.address.isDefined)
.map((proc: Procedure) =>
Eval.storeSingle(f)(
var s = State.execute(is, initialState(f))

for (proc <- p.procedures.filter(p => p.blocks.nonEmpty && p.address.isDefined)) {
s = State.execute(s, Eval.storeSingle(f)(
"ghost-funtable",
Scalar(BitVecLiteral(proc.address.get, 64)),
FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get)))
)
)
)
_ <- State.pure(Logger.debug("INITIALISE MEMORY SECTIONS"))
mem <- initMemory("mem", p.initialMemory.values)
mem <- initMemory("stack", p.initialMemory.values)
mainfun = p.mainProcedure
r <- f.call("init_activation", Stopped(), Stopped()) // frame for main to return to
r <- f.call(mainfun.name, Run(IRWalk.firstInBlock(mainfun.entryBlock.get)), Stopped())
l <- State.sequence(State.pure(()), mainfun.formalInParam.toList.map(i => f.storeVar(i.name, i.toBoogie.scope, Scalar(BitVecLiteral(0, size(i).get)))))
} yield (r)
}

val mainfun = p.mainProcedure
s = State.execute(s, f.call("init_activation", Stopped(), Stopped()))
s = initMemory(s, "mem", p.initialMemory.values)
s = initMemory(s, "stack", p.initialMemory.values)
s = State.execute(s, f.call(mainfun.name, Run(IRWalk.firstInBlock(mainfun.entryBlock.get)), Stopped()))
// l <- State.sequence(State.pure(()), mainfun.formalInParam.toList.map(i => f.storeVar(i.name, i.toBoogie.scope, Scalar(BitVecLiteral(0, size(i).get)))))
s
}

def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext): State[S, Unit, InterpreterError] = {
def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(is: S, p: IRContext): S = {
val bss = for {
first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value)
last <- p.symbols.find(s => s.name == "__bss_end__").map(_.value)
r <- (if (first == last) then None else Some((first, (last - first) * 8)))
(addr, sz) = r
st = {
(rgn => Eval.storeBV(f)(rgn, Scalar(BitVecLiteral(addr, 64)), BitVecLiteral(0, sz.toInt), Endian.LittleEndian))
var s = is
for (rgn <- Seq("mem", "stack")) {
for (addr <- (first until last)) {
s = State.execute(s,
Eval.storeBV(f)(rgn, Scalar(BitVecLiteral(addr, 64)), BitVecLiteral(0, 8), Endian.LittleEndian)
)
}
}
s
}

} yield (st)


bss match {
case None => Logger.error("No BSS initialised"); State.pure(())
case Some(init) => init("mem") >> init("stack")
case None => Logger.error("No BSS initialised"); is
case Some(init) => init
}
}

def initProgState[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = {
val st = (initialiseProgram(f)(p.program) >> initBSS(f)(p))
>> InterpFuns.initRelocTable(f)(p)
var ist = initialiseProgram(f)(is, p.program)
ist = initBSS(f)(ist, p)
ist = State.execute(ist, InterpFuns.initRelocTable(f)(p))
val st = State.putS(ist)
val (fs, v) = st.f(is)
v match {
case Right(r) => fs
Expand All @@ -604,7 +657,7 @@ object InterpFuns {

/* Interpret IR program */
def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = {
val begin = State.execute(is, initialiseProgram(f)(p))
val begin = initialiseProgram(f)(is, p)
val interp = BASILInterpreter(f)
interp.run(begin)
}
Expand Down
33 changes: 33 additions & 0 deletions src/main/scala/ir/eval/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,14 @@ object LibcIntrinsic {
_ <- s.doReturn()
} yield (())

def twoArg[S, E, T <: Effects[S, E]](name: String)(s: T): State[S, Unit, E] = for {
c1 <- s.loadVar("R0")
c2 <- s.loadVar("R1")
res <- s.callIntrinsic(name, List(c1, c2))
_ <- if res.isDefined then s.storeVar("R0", Scope.Global, res.get) else State.pure(())
_ <- s.doReturn()
} yield (())

def calloc[S, T <: Effects[S, InterpreterError]](s: T): State[S, Unit, InterpreterError] = for {
size <- s.loadVar("R0")
res <- s.callIntrinsic("malloc", List(size))
Expand All @@ -340,9 +348,12 @@ object LibcIntrinsic {
def intrinsics[S, T <: Effects[S, InterpreterError]] =
Map[String, T => State[S, Unit, InterpreterError]](
"putc" -> singleArg("putc"),
"putchar" -> singleArg("putc"),
"puts" -> singleArg("puts"),
"printf" -> singleArg("print"),
"write" -> twoArg("write"),
"malloc" -> singleArg("malloc"),
"__libc_malloc_impl" -> singleArg("malloc"),
"free" -> singleArg("free"),
"#free" -> singleArg("free"),
"calloc" -> calloc
Expand Down Expand Up @@ -392,6 +403,27 @@ object IntrinsicImpl {
} yield (Some(filecount.head))
}


def write[S, T <: Effects[S, InterpreterError]](f: T)(fd: BasilValue, strptr: BasilValue): State[S, Option[BasilValue], InterpreterError] = {
for {
str <- Eval.getNullTerminatedString(f)("mem", strptr)
// TODO: fd mapping in state
file = fd match {
case Scalar(BitVecLiteral(1, 64)) => "stdout"
case Scalar(BitVecLiteral(2, 64)) => "stderr"
case _ => "unknown"
}
baseptr: List[BasilValue] <- f.loadMem("ghost-file-bookkeeping", List(Symbol(s"${file}-ptr")))
offs: List[BasilValue] <- State.mapM(
((i: Int) => State.pureE(BasilValue.unsafeAdd(baseptr.head, i))),
(0 until (str.size + 1))
)
_ <- f.storeMem(file, offs.zip(str.map(Scalar(_))).toMap)
naddr <- State.pureE(BasilValue.unsafeAdd(baseptr.head, str.size))
_ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol(s"${file}-ptr") -> naddr))
} yield (None)
}

def print[S, T <: Effects[S, InterpreterError]](f: T)(strptr: BasilValue): State[S, Option[BasilValue], InterpreterError] = {
for {
str <- Eval.getNullTerminatedString(f)("mem", strptr)
Expand Down Expand Up @@ -450,6 +482,7 @@ object NormalInterpreter extends Effects[InterpreterState, InterpreterError] {
} yield (Some(r))
case "print" => IntrinsicImpl.print(this)(args.head)
case "puts" => IntrinsicImpl.print(this)(args.head) >> IntrinsicImpl.putc(this)(Scalar(BitVecLiteral('\n'.toInt, 64)))
case "write" => IntrinsicImpl.write(this)(args(1), args(2))
case _ => State.setError(Errored(s"Call undefined intrinsic $name"))
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/main/scala/util/Logging.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class GenericLogger(
if (level.id < LogLevel.OFF.id) {
val l = deriveLogger(file.getName(), file)
l.print(content)
close()
l.close()
children.remove(l)
}
}
Expand Down Expand Up @@ -159,7 +159,6 @@ val StaticAnalysisLogger = Logger.deriveLogger("analysis", System.out)
val SimplifyLogger = Logger.deriveLogger("simplify", System.out)
val DebugDumpIRLogger = {
val l = Logger.deriveLogger("debugdumpir")
l.setLevel(LogLevel.OFF)
l
}
val VSALogger = StaticAnalysisLogger.deriveLogger("vsa")
Expand Down
Loading

0 comments on commit 21efc16

Please sign in to comment.