diff --git a/build.sbt b/build.sbt index f1e6339ff..bef063b83 100644 --- a/build.sbt +++ b/build.sbt @@ -29,7 +29,7 @@ lazy val root = project libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test ) -scalacOptions ++= Seq("-deprecation", "-feature") +scalacOptions ++= Seq("-deprecation", "-unchecked", "-feature") Compile / PB.targets := Seq( scalapb.gen() -> (Compile / sourceManaged).value / "scalapb" diff --git a/build.sc b/build.sc index 7086685b9..1bd07adc2 100644 --- a/build.sc +++ b/build.sc @@ -22,6 +22,7 @@ object basil extends RootModule with ScalaModule with antlr.AntlrModule with Sca def scalaPBVersion = "0.11.15" + def scalacOptions = Seq("-deprecation", "-unchecked", "-feature") def mainClass = Some("Main") diff --git a/docs/development/interpreter.md b/docs/development/interpreter.md new file mode 100644 index 000000000..a2854bb73 --- /dev/null +++ b/docs/development/interpreter.md @@ -0,0 +1,203 @@ +# BASIL IR Interpreter + +The interpreter is designed for testing, debugging, and validation of static analyses and code transforms. +This page describes first how it can be used for this purpose, and secondly its design. + +## Basic Usage + +The interpreter can be invoked from the command line, via the interpret flag, by default this prints a trace and checks that the interpreter +exited on a non-error stop state. + +```shell +./mill run -i src/test/correct/indirect_call/gcc/indirect_call.adt -r src/test/correct/indirect_call/gcc/indirect_call.relf --interpret +[INFO] Interpreter Trace: + StoreVar(#5,Local,0xfd0:bv64) + StoreMem(mem,HashMap(0xfd0:bv64 -> 0xf0:bv8, 0xfd6:bv64 -> 0x0:bv8, 0xfd2:bv64 -> 0x0:bv8, 0xfd3:bv64 -> 0x0:bv8, 0xfd4:bv64 -> 0x0:bv8, 0xfd7:bv64 -> 0x0:bv8, 0xfd5:bv64 -> 0x0:bv8, 0xfd1 +... +[INFO] Interpreter stopped normally. +``` + +The `--verbose` flag can also be used, which may print interpreter trace events as they are executed, but not this may not correspond to the actual +execution trace, and contain additional events not corresponding to the program. +E.g. this shows the memory intialisation events that precede the program execution. This is mainly useful for debugging the interpreter. + +### Testing with Interpreter + +The interpreter is invoked with `interpret(p: IRContext)` to interpret normally and return an `InterpreterState` object +containing the final state. + +#### Traces + +There is also, `interpretTrace(p: IRContext)` which returns a tuple of `(InterpreterState, Trace(t: List[ExecEffect]))`, +where the second argument contains a list of all the events generated by the interpreter in order. +This is useful for asserting a stronger equivalence between program executions, but in most cases events describing "unobservable" +behaviour, such as register accesses should be filtered out from this list before comparison. + +To see an example of this used to validate the constant prop analysis see [/src/test/scala/DifferentialAnalysis.scala](../../src/test/scala/DifferentialAnalysis.scala). + +#### BreakPoints + +Finally `interpretBreakPoints(p: IRContext, breakpoints: List[BreakPoint])` is used to +run an interpreter and perform additional actions at specified code points. For example, this may be invoked such as: + +```scala +val watch = IRWalk.firstInProc((program.procedures.find(_.name == "main")).get).get +val bp = BreakPoint("entrypoint", BreakPointLoc.CMD(watch), BreakPointAction(saveState=true, stop=true, evalExprs=List(("R0", Register("R0", 64))), log=true)) +val res = interpretBreakPoints(program, List(bp)) +``` + +The structure of a breakpoint is as follows: + +```scala +case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) + +// the place to perform the breakpoint action +enum BreakPointLoc: + case CMD(c: Command) // at a command c + case CMDCond(c: Command, condition: Expr) // at a command c, when condition evaluates to TrueLiteral + +// describes what to do when the breakpoint is triggered +case class BreakPointAction( + saveState: Boolean = true, // stash the state of the interpreter + stop: Boolean = false, // stop the interpreter with an error state + evalExprs: List[(String, Expr)] = List(), // Evaluate the rhs of the list of expressions, and stash them (lhs is an arbitrary human-readable name) + log: Boolean = false // Print a log message about passing the breakpoint describing the results of this action +) +``` + +To see an example of this used to validate the constant prop analysis see [/src/test/scala/InterpretTestConstProp.scala](../../src/test/scala/InterpretTestConstProp.scala). + +### Resource Limit + +This kills the interpreter in an error state once a specified instruction count is reached, to avoid the interpreter running forever on infinite loops. + +It can be used simply with the function `interptretRLimit`, this automatically ignores the initialisation instructions. + +```scala +def interpretRLimit(p: IRContext, instructionLimit: Int) : InterpreterState +``` + +It can also be combined with other interpreters as shown: + +```scala +def interp(p: IRContext, instructionLimit: Int) : (InterpreterState, Trace) = { + val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) + val initialState = InterpFuns.initProgState(NormalInterpreter)(p, InterpreterState()) + BASILInterpreter(interpreter).run((initialState, Trace(List())), 0)._1 +} +``` + +## Implementation / Code Structure + +### Summary + +- [Bitvector.scala](../../src/main/scala/ir/eval/Bitvector.scala) + - Evaluation of bitvector operations, throws `IllegalArgumentException` on violation of contract + (e.g negative divisor, type mismatch) +- [ExprEval.scala](../../src/main/scala/ir/eval/ExprEval.scala) + - Evaluation of expressions, defined in terms of partial evaluation down to a Literal + - This can also be used to evaluate expressions in static analyses, by passing a function to query variable assignments and memory state from the value domain. +- [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala) + - Definition of core `Effects[S, E]` and `Interpreter[S, E]` types describing state transitions in + the interpreter + - Instantiation/definition of `Effects` for concrete state `InterpreterState` +- [InterpreterProduct.scala](../../src/main/scala/ir/eval/InterpreterProduct.scala) + - Definition of product and layering composition of generic `Effects[S, E]`s interpreters +- [InterpretBasilIR.scala](../../src/main/scala/ir/eval/InterpretBasilIR.scala) + - Definition of `Eval` object defining expression evaluation in terms of `Effects[S, InterpreterError]` + - Definition of `Interpreter` instance for BASIL IR, using a generic `Effects` instance and concrete state. + - Definition of ELF initialisation in terms of generic `Effects[S, InterpreterError]` +- [InterpretBreakpoints.scala](../../src/main/scala/ir/eval/InterpretBreakpoints.scala) + - Definition of a generic interpreter with a breakpoint checker layered on top +- [interpretRLimit.scala](../../src/main/scala/ir/eval/InterpretRLimit.scala) + - Definition of layered interpreter which terminates after a specified cycle count +- [InterpretTrace.scala](../../src/main/scala/ir/eval/InterpretTrace.scala) + - Definition of a generic interpreter which records a trace of calls to the `Effects[]` instance. + +### Explanation + +The interpreter is structured for compositionality, at its core is the `Effects[S, E]` type, defined in [Interpreter.scala](../../src/main/scala/ir/eval/Interpreter.scala). +This type defines a small set of functions which describe all the possible state transformations, over a concrete state `S`, and error type `E` (always `InterpreterError` in practice). + +This is implemented using the state Monad, `State[S,V,E]` where `S` is the state, `V` the value, and `E` the error type. +This is a flattened `State[S, Either[E]]`, defined in [util/functional.scala](../../src/main/scala/util/functional.scala). +`Effects` methods return delayed computations, functions from an input state (`S`) to a resulting state and a value (`(S, Either[E, V])`). +These are sequenced using `flatMap` (monad bind), or the `for{} yield()` syntax sugar for flatMap. + +This `Effects[S, E]` is instantiated for a given concrete state, the main example of which is `NormalInterpreter <: Effects[InterpreterState, InterpreterError]`, +also defined in `Interpreter.scala`. The memory component of the state is abstracted further into the `MemoryState` object. + +The actual execution of code is defined on top of this, in the `Interpreter[S, E]` type, which takes an instance of the `Effects` by parameter, +and defines both the small step (`interpretOne`) over on instruction, and the fixed point to termination from some in initial state in `run()`. +The fact that the stepping is defined outside the effects is important, as it allows concrete states, and state transitions over them to be +composed somewhat arbitrarily, and the interpretatation of the language compiled down to calls to resulting instance of `Effects`. + +This is defined in [InterpretBasilIR.scala](../../src/main/scala/ir/eval/InterpretBasilIR.scala). `BASILInterpreter` defines an +`Interpreter` over an arbitrary instance of `Effects[S, InterpreterError]`, encoding BASIL IR commands as effects. +This file also contains definitions of the initial memory state setup of the interpreter, based on the ELF sections and symbol table. + +### Composition of interpreters + +There are two ways to compose `Effects`, product and layer. Both produce an instance of `Effects[(L, R), E]`, +where `L` and `R` are the concrete state types of the two Effects being composed. + +Product runs the two effects, over two different concrete state types, simultaneously without interaction. + +Layer runs the `before` effect first, and passes its state to the `inner` effect whose value is returned. + +```scala +case class ProductInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[T, E]) extends Effects[(L, T), E] { +case class LayerInterpreter[L, T, E](val inner: Effects[L, E], val before: Effects[(L, T), E]) +``` + +Examples of using these are in the `interpretTrace` and `interpretWithBreakPoints` interpreters respectively. + +Note, this only works by the aforementioned requirement that all effect calls come from outside the `Effects[]` +instance itself. In the simple case, the `Interpreter` instance is the only object calling `Effects`. +This means, `Effects` triggered by an inner `Effects[]` instance do not flow back to the `ProductInterpreter`, +but only appear from when `Interpreter` above the `ProductInterpreter` interprets the program via effect calls. +For this reason if, for example, `NormalInterpreter` makes effect calls they will not appear in a trace emitted by `interptretTrace`. + +### Note on memory space initialisation + +Most of the interpret functions are overloaded such that there is a version taking a program `interpret(p: Program)`, +and a version taking `IRContext`. The variant taking IRContext uses the ELF symbol information to initialise the +memory before interpretation. If you are interpreting a real program (i.e. not a synthetic example created through +the DSL), this is most likely required. + +We initialise: + +- The general interpreter state, stack and memory regions, stack pointer, a symbolic mapping from addresses functions +- The initial and readonly memory sections stored in Program +- The `.bss` section to zero +- The relocation table. Each listed offset is stored an address to either a real procedure in the program, or a + location storing a symbolic function pointer to an intrinsic function. + +`.bss` is generally the top of the initialised data, the ELF symbol `__bss_end__` being equal to the symbol `__end__`. +Above this we can somewhat choose arbitrarily where to put things, usually the heap is above, followed by +dynamically linked symbols, then the stack. There is currently no stack overflow checking, or heap implemented in the +interpreter. + +Unfortunately these details are defined by the load-time linker and the system's linker script, and it is hard to find a good description +of their behaviour. Some details are described here https://refspecs.linuxfoundation.org/elf/elf.pdf, and here +https://dl.acm.org/doi/abs/10.1145/2983990.2983996. + +### Missing features + +- There is functionality to implement external function calls via intrinsics written in Scala code, but currently only + basic printf style functions are implemented as no-ops. These can be extended to use a file IO abstraction, where + a memory region is created for each file (e.g. stdout), with a variable to keep track of the current write-point + such that a file write operation stores to the write-point address, and increments it by the size of the store. + Importantly, an implementation of malloc() and free() is needed, which can implement a simple greedy allocation + algorithm. +- Despite the presence of procedure parameters in the current IR, they are not used for by the boogie translation and + are hence similarly ignored in the interpreter. +- The interpreter's immutable state representation is motivated by the ability to easily implement a sound approach + to non-determinism, e.g. to implement GoTos with guessing and rollback rather than look-ahead. This is more + useful for checking specification constructs than executing real programs, so is not yet implemented. +- The trace does not clearly distinguish internal vs external calls, or observable + and non-observable behaviour. +- While the interpreter semantics supports memory regions, we do not initialise the memory regions (or the initial memory state) + based on those present in the program, we simply assume a flat `mem` and `stack` memory partitioning. + + diff --git a/docs/development/readme.md b/docs/development/readme.md index 4e220c162..b92fbeaab 100644 --- a/docs/development/readme.md +++ b/docs/development/readme.md @@ -5,6 +5,7 @@ - [tool-installation](tool-installation.md) Guide to lifter, etc. tool installation - [scala](scala.md) Advice on Scala programming. - [cfg](cfg.md) Explanation of the old CFG datastructure +- [interpreter](interpreter.md) Explanation of IR interpreter ## Scala diff --git a/docs/readme.md b/docs/readme.md index f7b331b23..3bb6905aa 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -12,6 +12,7 @@ To get started on development, see [development](development). - [editor-setup](development/editor-setup.md) Guide to basil development in IDEs - [tool-installation](development/tool-installation.md) Guide to lifter, etc. tool installation - [cfg](development/cfg.md) Explanation of the old CFG datastructure + - [interpreter](development/interpreter.md) Explanation of IR interpreter - [basil-ir](basil-ir.md) explanation of BASIL's intermediate representation - [compiler-explorer](compiler-explorer.md) guide to the compiler explorer basil interface - [il-cfg](il-cfg.md) explanation of the IL cfg iterator design diff --git a/src/main/scala/analysis/Lattice.scala b/src/main/scala/analysis/Lattice.scala index 0ef98020f..cb499f8c3 100644 --- a/src/main/scala/analysis/Lattice.scala +++ b/src/main/scala/analysis/Lattice.scala @@ -1,7 +1,7 @@ package analysis import ir._ -import analysis.BitVectorEval +import ir.eval.BitVectorEval import util.Logger /** Basic lattice @@ -244,4 +244,4 @@ class ConstantPropagationLatticeWithSSA extends PowersetLattice[BitVecLiteral] { apply(BitVectorEval.boogie_extract(high, low, _: BitVecLiteral), a) def concat(a: Set[BitVecLiteral], b: Set[BitVecLiteral]): Set[BitVecLiteral] = apply(BitVectorEval.smt_concat, a, b) -} \ No newline at end of file +} diff --git a/src/main/scala/analysis/MemoryRegionAnalysis.scala b/src/main/scala/analysis/MemoryRegionAnalysis.scala index 5f8e07560..396df7523 100644 --- a/src/main/scala/analysis/MemoryRegionAnalysis.scala +++ b/src/main/scala/analysis/MemoryRegionAnalysis.scala @@ -1,6 +1,6 @@ package analysis -import analysis.BitVectorEval.isNegative +import ir.eval.BitVectorEval.isNegative import analysis.solvers.SimpleWorklistFixpointSolver import ir.* import util.Logger diff --git a/src/main/scala/analysis/RegionInjector.scala b/src/main/scala/analysis/RegionInjector.scala index 3cc414f06..09bf8682e 100644 --- a/src/main/scala/analysis/RegionInjector.scala +++ b/src/main/scala/analysis/RegionInjector.scala @@ -1,6 +1,6 @@ package analysis -import analysis.BitVectorEval.isNegative +import ir.eval.BitVectorEval.isNegative import ir.* import util.Logger diff --git a/src/main/scala/analysis/UtilMethods.scala b/src/main/scala/analysis/UtilMethods.scala index 2f74b65e8..ae9cbdbbe 100644 --- a/src/main/scala/analysis/UtilMethods.scala +++ b/src/main/scala/analysis/UtilMethods.scala @@ -1,6 +1,7 @@ package analysis import ir.* import util.Logger +import ir.eval.BitVectorEval /** Evaluate an expression in a hope of finding a global variable. * @@ -12,62 +13,14 @@ import util.Logger * The evaluated expression (e.g. 0x69632) */ def evaluateExpression(exp: Expr, constantPropResult: Map[Variable, FlatElement[BitVecLiteral]]): Option[BitVecLiteral] = { - exp match { - case binOp: BinaryExpr => - val lhs = evaluateExpression(binOp.arg1, constantPropResult) - val rhs = evaluateExpression(binOp.arg2, constantPropResult) + def value(v: Variable) = constantPropResult(v) match { + case FlatEl(value) => Some(value) + case _ => None + } - (lhs, rhs) match { - case (Some(l: BitVecLiteral), Some(r: BitVecLiteral)) => - val result = binOp.op match { - case BVADD => BitVectorEval.smt_bvadd(l, r) - case BVSUB => BitVectorEval.smt_bvsub(l, r) - case BVMUL => BitVectorEval.smt_bvmul(l, r) - case BVUDIV => BitVectorEval.smt_bvudiv(l, r) - case BVSDIV => BitVectorEval.smt_bvsdiv(l, r) - case BVSREM => BitVectorEval.smt_bvsrem(l, r) - case BVUREM => BitVectorEval.smt_bvurem(l, r) - case BVSMOD => BitVectorEval.smt_bvsmod(l, r) - case BVAND => BitVectorEval.smt_bvand(l, r) - case BVOR => BitVectorEval.smt_bvxor(l, r) - case BVXOR => BitVectorEval.smt_bvxor(l, r) - case BVNAND => BitVectorEval.smt_bvnand(l, r) - case BVNOR => BitVectorEval.smt_bvnor(l, r) - case BVXNOR => BitVectorEval.smt_bvxnor(l, r) - case BVSHL => BitVectorEval.smt_bvshl(l, r) - case BVLSHR => BitVectorEval.smt_bvlshr(l, r) - case BVASHR => BitVectorEval.smt_bvashr(l, r) - case BVCOMP => BitVectorEval.smt_bvcomp(l, r) - case BVCONCAT => BitVectorEval.smt_concat(l, r) - case x => throw RuntimeException("Binary operation support not implemented: " + binOp.op) - } - Some(result) - case _ => None - } - case extend: ZeroExtend => - evaluateExpression(extend.body, constantPropResult) match { - case Some(b: BitVecLiteral) => Some(BitVectorEval.smt_zero_extend(extend.extension, b)) - case None => None - } - case extend: SignExtend => - evaluateExpression(extend.body, constantPropResult) match { - case Some(b: BitVecLiteral) => Some(BitVectorEval.smt_sign_extend(extend.extension, b)) - case None => None - } - case e: Extract => - evaluateExpression(e.body, constantPropResult) match { - case Some(b: BitVecLiteral) => Some(BitVectorEval.boogie_extract(e.end, e.start, b)) - case None => None - } - case variable: Variable => - constantPropResult(variable) match { - case FlatEl(value) => Some(value) - case Top => None - case Bottom => None - } - case b: BitVecLiteral => Some(b) - case _ => //throw new RuntimeException("ERROR: CASE NOT HANDLED: " + exp + "\n") - None + ir.eval.evalBVExpr(exp, value) match { + case Right(v) => Some(v) + case Left(_) => None } } diff --git a/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala b/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala index 9d9dac1f7..5c0450c78 100644 --- a/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala +++ b/src/main/scala/analysis/data_structure_analysis/LocalPhase.scala @@ -1,6 +1,6 @@ package analysis.data_structure_analysis -import analysis.BitVectorEval.{bv2SignedInt, isNegative} +import ir.eval.BitVectorEval.{bv2SignedInt, isNegative} import analysis.* import ir.* import specification.{ExternalFunction, SpecGlobal, SymbolTableEntry} diff --git a/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala b/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala index 52cb23639..773319b32 100644 --- a/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala +++ b/src/main/scala/analysis/data_structure_analysis/SymbolicAddressAnalysis.scala @@ -1,6 +1,6 @@ package analysis.data_structure_analysis -import analysis.BitVectorEval.{bv2SignedInt, isNegative} +import ir.eval.BitVectorEval.{bv2SignedInt, isNegative} import analysis.solvers.ForwardIDESolver import analysis.* import ir.* diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index 03579c2cf..945ffcfb7 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -16,18 +16,22 @@ sealed trait Literal extends Expr { override def acceptVisit(visitor: Visitor): Literal = visitor.visitLiteral(this) } -sealed trait BoolLit extends Literal +sealed trait BoolLit extends Literal { + def value: Boolean +} case object TrueLiteral extends BoolLit { override def toBoogie: BoolBLiteral = TrueBLiteral override def getType: IRType = BoolType override def toString: String = "true" + override def value = true } case object FalseLiteral extends BoolLit { override def toBoogie: BoolBLiteral = FalseBLiteral override def getType: IRType = BoolType override def toString: String = "false" + override def value = false } case class BitVecLiteral(value: BigInt, size: Int) extends Literal { diff --git a/src/main/scala/ir/Interpreter.scala b/src/main/scala/ir/Interpreter.scala deleted file mode 100644 index 53ef40c2d..000000000 --- a/src/main/scala/ir/Interpreter.scala +++ /dev/null @@ -1,352 +0,0 @@ -package ir -import analysis.BitVectorEval.* -import util.Logger - -import scala.collection.mutable -import scala.util.control.Breaks.{break, breakable} - -class Interpreter() { - val regs: mutable.Map[Variable, BitVecLiteral] = mutable.Map() - val mems: mutable.Map[BigInt, BitVecLiteral] = mutable.Map() - private val SP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - private val FP: BitVecLiteral = BitVecLiteral(4096 - 16, 64) - private val LR: BitVecLiteral = BitVecLiteral(BigInt("FF", 16), 64) - private var nextCmd: Option[Command] = None - private val returnCmd: mutable.Stack[Command] = mutable.Stack() - - def eval(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BitVecLiteral = { - exp match { - case id: Variable => - env.get(id) match { - case Some(value) => - Logger.debug(s"\t${id.name} == 0x${value.value.toString(16)}[u${value.size}]") - value - case _ => throw new Exception(s"$id not found in env") - } - - case n: Literal => - n match { - case bv: BitVecLiteral => - Logger.debug(s"\tBitVecLiteral(0x${bv.value.toString(16)}[u${bv.size}])") - bv - case _ => ??? - } - - case ze: ZeroExtend => - Logger.debug(s"\t$ze") - smt_zero_extend(ze.extension, eval(ze.body, env)) - - case se: SignExtend => - Logger.debug(s"\t$se") - smt_sign_extend(se.extension, eval(se.body, env)) - - case e: Extract => - Logger.debug(s"\tExtract($e, ${e.start}, ${e.end})") - boogie_extract(e.end, e.start, eval(e.body, env)) - - case r: Repeat => - Logger.debug(s"\t$r") - ??? // TODO - - case bin: BinaryExpr => - val left = eval(bin.arg1, env) - val right = eval(bin.arg2, env) - Logger.debug( - s"\tBinaryExpr(0x${left.value.toString(16)}[u${left.size}] ${bin.op} 0x${right.value.toString(16)}[u${right.size}])" - ) - bin.op match { - case BVAND => smt_bvand(left, right) - case BVOR => smt_bvor(left, right) - case BVADD => smt_bvadd(left, right) - case BVMUL => smt_bvmul(left, right) - case BVUDIV => smt_bvudiv(left, right) - case BVUREM => smt_bvurem(left, right) - case BVSHL => smt_bvshl(left, right) - case BVLSHR => smt_bvlshr(left, right) - case BVNAND => smt_bvnand(left, right) - case BVNOR => smt_bvnor(left, right) - case BVXOR => smt_bvxor(left, right) - case BVXNOR => smt_bvxnor(left, right) - case BVCOMP => smt_bvcomp(left, right) - case BVSUB => smt_bvsub(left, right) - case BVSDIV => smt_bvsdiv(left, right) - case BVSREM => smt_bvsrem(left, right) - case BVSMOD => smt_bvsmod(left, right) - case BVASHR => smt_bvashr(left, right) - case BVCONCAT => smt_concat(left, right) - case _ => ??? - } - - case un: UnaryExpr => - val arg = eval(un.arg, env) - Logger.debug(s"\tUnaryExpr($un)") - un.op match { - case BVNEG => smt_bvneg(arg) - case BVNOT => smt_bvnot(arg) - } - - case ml: MemoryLoad => - Logger.debug(s"\t$ml") - val index: Int = eval(ml.index, env).value.toInt - getMemory(index, ml.size, ml.endian, mems) - - case u: UninterpretedFunction => - Logger.debug(s"\t$u") - ??? - } - } - - def evalBool(exp: Expr, env: mutable.Map[Variable, BitVecLiteral]): BoolLit = { - exp match { - case n: BoolLit => n - case bin: BinaryExpr => - bin.op match { - case b: BoolBinOp => - val arg1 = evalBool(bin.arg1, env) - val arg2 = evalBool(bin.arg2, env) - b match { - case BoolEQ => - if (arg1 == arg2) { - TrueLiteral - } else { - FalseLiteral - } - case BoolNEQ => - if (arg1 != arg2) { - TrueLiteral - } else { - FalseLiteral - } - case BoolAND => - (arg1, arg2) match { - case (TrueLiteral, TrueLiteral) => TrueLiteral - case _ => FalseLiteral - } - case BoolOR => - (arg1, arg2) match { - case (FalseLiteral, FalseLiteral) => FalseLiteral - case _ => TrueLiteral - } - case BoolIMPLIES => - (arg1, arg2) match { - case (TrueLiteral, FalseLiteral) => FalseLiteral - case _ => TrueLiteral - } - case BoolEQUIV => - if (arg1 == arg2) { - TrueLiteral - } else { - FalseLiteral - } - } - case b: BVBinOp => - val left = eval(bin.arg1, env) - val right = eval(bin.arg2, env) - b match { - case BVULT => smt_bvult(left, right) - case BVULE => smt_bvule(left, right) - case BVUGT => smt_bvugt(left, right) - case BVUGE => smt_bvuge(left, right) - case BVSLT => smt_bvslt(left, right) - case BVSLE => smt_bvsle(left, right) - case BVSGT => smt_bvsgt(left, right) - case BVSGE => smt_bvsge(left, right) - case BVEQ => smt_bveq(left, right) - case BVNEQ => smt_bvneq(left, right) - case _ => ??? - } - case _ => ??? - } - - case un: UnaryExpr => - un.op match { - case BoolNOT => if evalBool(un.arg, env) == TrueLiteral then FalseLiteral else TrueLiteral - case _ => ??? - } - case _ => ??? - } - } - - def getMemory(index: BigInt, size: Int, endian: Endian, env: mutable.Map[BigInt, BitVecLiteral]): BitVecLiteral = { - val end = index + size / 8 - 1 - val memoryChunks = (index to end).map(i => env.getOrElse(i, BitVecLiteral(0, 8))) - - val (newValue, newSize) = memoryChunks.foldLeft(("", 0)) { (acc, current) => - val currentString: String = current.value.toString(2).reverse.padTo(8, '0').reverse - endian match { - case Endian.LittleEndian => (currentString + acc._1, acc._2 + current.size) - case Endian.BigEndian => (acc._1 + currentString, acc._2 + current.size) - } - } - - BitVecLiteral(BigInt(newValue, 2), newSize) - } - - def setMemory( - index: BigInt, - size: Int, - endian: Endian, - value: BitVecLiteral, - env: mutable.Map[BigInt, BitVecLiteral] - ): BitVecLiteral = { - val binaryString: String = value.value.toString(2).reverse.padTo(size, '0').reverse - - val data: List[BitVecLiteral] = endian match { - case Endian.LittleEndian => - binaryString.grouped(8).toList.map(chunk => BitVecLiteral(BigInt(chunk, 2), 8)).reverse - case Endian.BigEndian => - binaryString.grouped(8).toList.map(chunk => BitVecLiteral(BigInt(chunk, 2), 8)) - } - - data.zipWithIndex.foreach { case (bv, i) => - env(index + i) = bv - } - - value - } - - private def interpretProcedure(p: Procedure): Unit = { - Logger.debug(s"Procedure(${p.name}, ${p.address.getOrElse("None")})") - - // Procedure.in - for ((in, index) <- p.in.zipWithIndex) { - Logger.debug(s"\tin[$index]:${in.name} ${in.size} ${in.value}") - } - - // Procedure.out - for ((out, index) <- p.out.zipWithIndex) { - Logger.debug(s"\tout[$index]:${out.name} ${out.size} ${out.value}") - } - - // Procedure.Block - p.entryBlock match { - case Some(block) => nextCmd = Some(block.statements.headOption.getOrElse(block.jump)) - case None => nextCmd = Some(returnCmd.pop()) - } - } - - private def interpretJump(j: Jump) : Unit = { - Logger.debug(s"jump:") - breakable { - j match { - case gt: GoTo => - Logger.debug(s"$gt") - for (g <- gt.targets) { - val condition: Option[Expr] = g.statements.headOption.collect { case a: Assume => a.body } - condition match { - case Some(e) => evalBool(e, regs) match { - case TrueLiteral => - nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) - break - case _ => - } - case None => - nextCmd = Some(g.statements.headOption.getOrElse(g.jump)) - break - } - } - case r: Return => { - nextCmd = Some(returnCmd.pop()) - } - case h: Unreachable => { - Logger.debug("Unreachable") - nextCmd = None - } - } - } - } - - private def interpretStatement(s: Statement): Unit = { - Logger.debug(s"statement[$s]:") - s match { - case assign: Assign => - Logger.debug(s"LocalAssign ${assign.lhs} = ${assign.rhs}") - val evalRight = eval(assign.rhs, regs) - Logger.debug(s"LocalAssign ${assign.lhs} := 0x${evalRight.value.toString(16)}[u${evalRight.size}]\n") - regs += (assign.lhs -> evalRight) - - case assign: MemoryAssign => - Logger.debug(s"MemoryAssign ${assign.mem}[${assign.index}] = ${assign.value}") - - val index: Int = eval(assign.index, regs).value.toInt - val value: BitVecLiteral = eval(assign.value, regs) - Logger.debug(s"\tMemoryStore(mem:${assign.mem}, index:0x${index.toHexString}, value:0x${ - value.value - .toString(16) - }[u${value.size}], size:${assign.size})") - - val evalStore = setMemory(index, assign.size, assign.endian, value, mems) - evalStore match { - case BitVecLiteral(value, size) => - Logger.debug(s"MemoryAssign ${assign.mem} := 0x${value.toString(16)}[u$size]\n") - } - case _ : NOP => () - case assert: Assert => - // TODO - Logger.debug(assert) - evalBool(assert.body, regs) match { - case TrueLiteral => () - case FalseLiteral => throw Exception(s"Assertion failed ${assert}") - } - case assume: Assume => - // TODO, but already taken into effect if it is a branch condition - Logger.debug(assume) - evalBool(assume.body, regs) match { - case TrueLiteral => () - case FalseLiteral => { - nextCmd = None - Logger.debug(s"Assumption not satisfied: $assume") - } - } - case dc: DirectCall => - Logger.debug(s"$dc") - returnCmd.push(dc.successor) - interpretProcedure(dc.target) - break - case ic: IndirectCall => - Logger.debug(s"$ic") - if (ic.target == Register("R30", 64)) { - if (returnCmd.nonEmpty) { - nextCmd = Some(returnCmd.pop()) - } else { - //Exit Interpreter - nextCmd = None - } - break - } else { - ??? - } - } - } - - def interpret(IRProgram: Program): mutable.Map[Variable, BitVecLiteral] = { - // initialize memory array from IRProgram - var currentAddress = BigInt(0) - IRProgram.initialMemory.values.foreach { im => - if (im.address + im.size > currentAddress) { - val start = im.address.max(currentAddress) - val data = if (im.address < currentAddress) im.bytes.slice((currentAddress - im.address).toInt, im.size) else im.bytes - data.zipWithIndex.foreach { (byte, index) => - mems(start + index) = byte - } - currentAddress = im.address + im.size - } - } - - // Initial SP, FP and LR to regs - regs += (Register("R31", 64) -> SP) - regs += (Register("R29", 64) -> FP) - regs += (Register("R30", 64) -> LR) - - // Program.Procedure - interpretProcedure(IRProgram.mainProcedure) - while (nextCmd.isDefined) { - nextCmd.get match { - case c: Statement => interpretStatement(c) - case c: Jump => interpretJump(c) - } - } - - regs - } -} diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index 5ff441799..0d368c15e 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -3,9 +3,10 @@ package ir import scala.collection.mutable.ArrayBuffer import scala.collection.{IterableOnceExtensionMethods, View, immutable, mutable} import boogie.* -import analysis.{BitVectorEval, MergedRegion} +import analysis.MergedRegion import util.intrusive_list.* import translating.serialiseIL +import eval.BitVectorEval class Program(var procedures: ArrayBuffer[Procedure], var mainProcedure: Procedure, diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index a0d87f5f9..03688084e 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -12,12 +12,21 @@ val R4: Register = Register("R4", 64) val R5: Register = Register("R5", 64) val R6: Register = Register("R6", 64) val R7: Register = Register("R7", 64) +val R8: Register = Register("R8", 64) val R29: Register = Register("R29", 64) val R30: Register = Register("R30", 64) val R31: Register = Register("R31", 64) +def exprEq(l: Expr, r: Expr) : Expr = (l, r) match { + case (l, r) if l.getType != r.getType => FalseLiteral + case (l, r) if l.getType == BoolType => BinaryExpr(BoolEQ, l, r) + case (l, r) if l.getType.isInstanceOf[BitVecType] => BinaryExpr(BVEQ, l, r) + case (l, r) if l.getType == IntType => BinaryExpr(IntEQ, l, r) + case _ => FalseLiteral +} + def bv32(i: Int): BitVecLiteral = BitVecLiteral(i, 32) def bv64(i: Int): BitVecLiteral = BitVecLiteral(i, 64) @@ -26,6 +35,8 @@ def bv8(i: Int): BitVecLiteral = BitVecLiteral(i, 8) def bv16(i: Int): BitVecLiteral = BitVecLiteral(i, 16) +def R(i: Int): Register = Register(s"R$i", 64) + case class DelayNameResolve(ident: String) { def resolveProc(prog: Program): Option[Procedure] = prog.collectFirst { case b: Procedure if b.name == ident => b diff --git a/src/main/scala/analysis/BitVectorEval.scala b/src/main/scala/ir/eval/BitVectorEval.scala similarity index 90% rename from src/main/scala/analysis/BitVectorEval.scala rename to src/main/scala/ir/eval/BitVectorEval.scala index ee2d09512..d92ac1817 100644 --- a/src/main/scala/analysis/BitVectorEval.scala +++ b/src/main/scala/ir/eval/BitVectorEval.scala @@ -1,6 +1,6 @@ -package analysis +package ir.eval + import ir.* -import analysis.BitVectorEval.* import scala.annotation.tailrec import scala.math.pow @@ -28,11 +28,12 @@ object BitVectorEval { /** * Converts a bitvector value to its corresponding signed integer */ - def bv2SignedInt(b: BitVecLiteral): BigInt = + def bv2SignedInt(b: BitVecLiteral): BigInt = { if isNegative(b) then b.value - BigInt(2).pow(b.size) else b.value + } /** (bvadd (_ BitVec m) (_ BitVec m) (_ BitVec m)) @@ -168,15 +169,15 @@ object BitVectorEval { /** (bvneq (_ BitVec m) (_ BitVec m)) * - not equal too */ - def smt_bveq(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(s == t) + def smt_bveq(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + s == t } /** (bvneq (_ BitVec m) (_ BitVec m)) * - not equal too */ - def smt_bvneq(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(s != t) + def smt_bvneq(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + s != t } /** (bvshl (_ BitVec m) (_ BitVec m) (_ BitVec m)) @@ -270,55 +271,55 @@ object BitVectorEval { /** (bvult (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned less-than */ - def smt_bvult(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(bv2nat(s) < bv2nat(t)) + def smt_bvult(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + bv2nat(s) < bv2nat(t) } /** (bvule (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned less than or equal */ - def smt_bvule(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { - bool2BoolLit(bv2nat(s) <= bv2nat(t)) + def smt_bvule(s: BitVecLiteral, t: BitVecLiteral): Boolean = { + bv2nat(s) <= bv2nat(t) } /** (bvugt (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned greater than */ - def smt_bvugt(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { + def smt_bvugt(s: BitVecLiteral, t: BitVecLiteral): Boolean = { smt_bvult(t, s) } /** (bvuge (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for unsigned greater than or equal */ - def smt_bvuge(s: BitVecLiteral, t: BitVecLiteral): BoolLit = smt_bvule(t, s) + def smt_bvuge(s: BitVecLiteral, t: BitVecLiteral): Boolean = smt_bvule(t, s) /** (bvslt (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed less than */ - def smt_bvslt(s: BitVecLiteral, t: BitVecLiteral): BoolLit = { + def smt_bvslt(s: BitVecLiteral, t: BitVecLiteral): Boolean = { val sNeg = isNegative(s) val tNeg = isNegative(t) - bool2BoolLit((sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvult(s, t) == TrueLiteral))) + (sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvult(s, t))) } /** (bvsle (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed less than or equal */ - def smt_bvsle(s: BitVecLiteral, t: BitVecLiteral): BoolLit = + def smt_bvsle(s: BitVecLiteral, t: BitVecLiteral): Boolean = val sNeg = isNegative(s) val tNeg = isNegative(t) - bool2BoolLit((sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvule(s, t) == TrueLiteral))) + (sNeg && !tNeg) || ((sNeg == tNeg) && (smt_bvule(s, t))) /** (bvsgt (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed greater than */ - def smt_bvsgt(s: BitVecLiteral, t: BitVecLiteral): BoolLit = smt_bvslt(t, s) + def smt_bvsgt(s: BitVecLiteral, t: BitVecLiteral): Boolean = smt_bvslt(t, s) /** (bvsge (_ BitVec m) (_ BitVec m) Bool) * - binary predicate for signed greater than or equal */ - def smt_bvsge(s: BitVecLiteral, t: BitVecLiteral): BoolLit = smt_bvsle(t, s) + def smt_bvsge(s: BitVecLiteral, t: BitVecLiteral): Boolean = smt_bvsle(t, s) def smt_bvashr(s: BitVecLiteral, t: BitVecLiteral): BitVecLiteral = if (!isNegative(s)) { diff --git a/src/main/scala/ir/eval/ExprEval.scala b/src/main/scala/ir/eval/ExprEval.scala new file mode 100644 index 000000000..6a9fb8829 --- /dev/null +++ b/src/main/scala/ir/eval/ExprEval.scala @@ -0,0 +1,264 @@ +package ir.eval +import ir.eval.BitVectorEval +import util.functional.State +import ir.* + +/** We generalise the expression evaluator to a partial evaluator to simplify evaluating casts. + * + * - Program state is taken via a function from var -> value and for loads a function from (mem,addr,endian,size) -> + * value. + * - For conrete evaluators we prefer low-level representations (bool vs BoolLit) and wrap them at the expression + * eval level + * - Avoid using any default cases so we have some idea of complete coverage + */ + +def evalBVBinExpr(b: BVBinOp, l: BitVecLiteral, r: BitVecLiteral): BitVecLiteral = { + b match { + case BVADD => BitVectorEval.smt_bvadd(l, r) + case BVSUB => BitVectorEval.smt_bvsub(l, r) + case BVMUL => BitVectorEval.smt_bvmul(l, r) + case BVUDIV => BitVectorEval.smt_bvudiv(l, r) + case BVSDIV => BitVectorEval.smt_bvsdiv(l, r) + case BVSREM => BitVectorEval.smt_bvsrem(l, r) + case BVUREM => BitVectorEval.smt_bvurem(l, r) + case BVSMOD => BitVectorEval.smt_bvsmod(l, r) + case BVAND => BitVectorEval.smt_bvand(l, r) + case BVOR => BitVectorEval.smt_bvxor(l, r) + case BVXOR => BitVectorEval.smt_bvxor(l, r) + case BVNAND => BitVectorEval.smt_bvnand(l, r) + case BVNOR => BitVectorEval.smt_bvnor(l, r) + case BVXNOR => BitVectorEval.smt_bvxnor(l, r) + case BVSHL => BitVectorEval.smt_bvshl(l, r) + case BVLSHR => BitVectorEval.smt_bvlshr(l, r) + case BVASHR => BitVectorEval.smt_bvashr(l, r) + case BVCOMP => BitVectorEval.smt_bvcomp(l, r) + case BVCONCAT => BitVectorEval.smt_concat(l, r) + case BVULE | BVULT | BVUGT | BVUGE | BVSLT | BVSLE | BVSGT | BVSGE | BVEQ | BVNEQ => + throw IllegalArgumentException("Did not expect logical op") + } +} + +def evalBVLogBinExpr(b: BVBinOp, l: BitVecLiteral, r: BitVecLiteral): Boolean = b match { + case BVULE => BitVectorEval.smt_bvule(l, r) + case BVUGT => BitVectorEval.smt_bvult(l, r) + case BVUGE => BitVectorEval.smt_bvuge(l, r) + case BVULT => BitVectorEval.smt_bvult(l, r) + case BVSLT => BitVectorEval.smt_bvslt(l, r) + case BVSLE => BitVectorEval.smt_bvsle(l, r) + case BVSGT => BitVectorEval.smt_bvsgt(l, r) + case BVSGE => BitVectorEval.smt_bvsge(l, r) + case BVEQ => BitVectorEval.smt_bveq(l, r) + case BVNEQ => BitVectorEval.smt_bvneq(l, r) + case BVADD | BVSUB | BVMUL | BVUDIV | BVSDIV | BVSREM | BVUREM | BVSMOD | BVAND | BVOR | BVXOR | BVNAND | BVNOR | + BVXNOR | BVSHL | BVLSHR | BVASHR | BVCOMP | BVCONCAT => + throw IllegalArgumentException("Did not expect non-logical op") +} + +def evalIntLogBinExpr(b: IntBinOp, l: BigInt, r: BigInt): Boolean = b match { + case IntEQ => l == r + case IntNEQ => l != r + case IntLT => l < r + case IntLE => l <= r + case IntGT => l > r + case IntGE => l >= r + case IntADD | IntSUB | IntMUL | IntDIV | IntMOD => throw IllegalArgumentException("Did not expect non-logical op") +} + +def evalIntBinExpr(b: IntBinOp, l: BigInt, r: BigInt): BigInt = b match { + case IntADD => l + r + case IntSUB => l - r + case IntMUL => l * r + case IntDIV => l / r + case IntMOD => l % r + case IntEQ | IntNEQ | IntLT | IntLE | IntGT | IntGE => throw IllegalArgumentException("Did not expect logical op") +} + +def evalBoolLogBinExpr(b: BoolBinOp, l: Boolean, r: Boolean): Boolean = b match { + case BoolEQ => l == r + case BoolEQUIV => l == r + case BoolNEQ => l != r + case BoolAND => l && r + case BoolOR => l || r + case BoolIMPLIES => l || (!r) +} + +def evalUnOp(op: UnOp, body: Literal): Expr = { + (body, op) match { + case (b: BitVecLiteral, BVNOT) => BitVectorEval.smt_bvnot(b) + case (b: BitVecLiteral, BVNEG) => BitVectorEval.smt_bvneg(b) + case (i: IntLiteral, IntNEG) => IntLiteral(-i.value) + case (FalseLiteral, BoolNOT) => TrueLiteral + case (TrueLiteral, BoolNOT) => FalseLiteral + case (_, _) => throw Exception(s"Unreachable ${(body, op)}") + } +} + +def evalIntExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Either[Expr, BigInt] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case i: IntLiteral => Right(i.value) + case o => Left(o) + } +} + +def evalBVExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Either[Expr, BitVecLiteral] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case b: BitVecLiteral => Right(b) + case o => Left(o) + } +} + +def evalLogExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Either[Expr, Boolean] = { + partialEvalExpr(exp, variableAssignment, memory) match { + case TrueLiteral => Right(true) + case FalseLiteral => Right(false) + case o => Left(o) + } +} + +def evalExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((d, a, b, c) => None) +): Option[Literal] = { + partialEvalExpr match { + case l: Literal => Some(l) + case _ => None + } +} + +/** typeclass defining variable and memory laoding from state S + */ +trait Loader[S, E] { + def getVariable(v: Variable): State[S, Option[Literal], E] + def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[S, Option[Literal], E] = { + State.pure(None) + } +} + +def statePartialEvalExpr[S](l: Loader[S, InterpreterError])(exp: Expr): State[S, Expr, InterpreterError] = { + val eval = statePartialEvalExpr(l) + val ns = exp match { + case f: UninterpretedFunction => State.pure(f) + case unOp: UnaryExpr => + for { + body <- eval(unOp.arg) + } yield (body match { + case l: Literal => evalUnOp(unOp.op, l) + case o => UnaryExpr(unOp.op, body) + }) + case binOp: BinaryExpr => + for { + lhs <- eval(binOp.arg1) + rhs <- eval(binOp.arg2) + } yield (binOp.getType match { + case m: MapType => binOp + case b: BitVecType => { + (binOp.op, lhs, rhs) match { + case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => evalBVBinExpr(o, l, r) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + case BoolType => { + def bool2lit(b: Boolean) = if b then TrueLiteral else FalseLiteral + (binOp.op, lhs, rhs) match { + case (o: BVBinOp, l: BitVecLiteral, r: BitVecLiteral) => bool2lit(evalBVLogBinExpr(o, l, r)) + case (o: IntBinOp, l: IntLiteral, r: IntLiteral) => bool2lit(evalIntLogBinExpr(o, l.value, r.value)) + case (o: BoolBinOp, l: BoolLit, r: BoolLit) => bool2lit(evalBoolLogBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + case IntType => { + (binOp.op, lhs, rhs) match { + case (o: IntBinOp, l: IntLiteral, r: IntLiteral) => IntLiteral(evalIntBinExpr(o, l.value, r.value)) + case _ => BinaryExpr(binOp.op, lhs, rhs) + } + } + }) + case extend: ZeroExtend => + for { + body <- eval(extend.body) + } yield (body match { + case b: BitVecLiteral => BitVectorEval.smt_zero_extend(extend.extension, b) + case o => extend.copy(body = o) + }) + case extend: SignExtend => + for { + body <- eval(extend.body) + } yield (body match { + case b: BitVecLiteral => BitVectorEval.smt_sign_extend(extend.extension, b) + case o => extend.copy(body = o) + }) + case e: Extract => + for { + body <- eval(e.body) + } yield (body match { + case b: BitVecLiteral => BitVectorEval.boogie_extract(e.end, e.start, b) + case o => e.copy(body = o) + }) + case r: Repeat => + for { + body <- eval(r.body) + } yield (body match { + case b: BitVecLiteral => { + assert(r.repeats > 0) + if (r.repeats == 1) b + else { + (2 to r.repeats).foldLeft(b)((acc, r) => BitVectorEval.smt_concat(acc, b)) + } + } + case o => r.copy(body = o) + }) + case variable: Variable => + for { + v: Option[Literal] <- l.getVariable(variable) + } yield (v.getOrElse(variable)) + case ml: MemoryLoad => + for { + addr <- eval(ml.index) + mem <- l.loadMemory(ml.mem, addr, ml.endian, ml.size) + } yield (mem.getOrElse(ml)) + case b: BitVecLiteral => State.pure(b) + case b: IntLiteral => State.pure(b) + case b: BoolLit => State.pure(b) + } + State.protect( + () => ns, + { case e => + Errored(e.toString) + }: PartialFunction[Exception, InterpreterError] + ) + +} + +class StatelessLoader[E]( + getVar: Variable => Option[Literal], + loadMem: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +) extends Loader[Unit, E] { + def getVariable(v: Variable): State[Unit, Option[Literal], E] = State.pure(getVar(v)) + override def loadMemory(m: Memory, addr: Expr, endian: Endian, size: Int): State[Unit, Option[Literal], E] = + State.pure(loadMem(m, addr, endian, size)) +} + +def partialEvalExpr( + exp: Expr, + variableAssignment: Variable => Option[Literal], + memory: (Memory, Expr, Endian, Int) => Option[Literal] = ((a, b, c, d) => None) +): Expr = { + val l = StatelessLoader[InterpreterError](variableAssignment, memory) + State.evaluate((), statePartialEvalExpr(l)(exp)) match { + case Right(e) => e + case Left(e) => throw Exception("Unable to evaluate expr : " + e.toString) + } +} diff --git a/src/main/scala/ir/eval/InterpretBasilIR.scala b/src/main/scala/ir/eval/InterpretBasilIR.scala new file mode 100644 index 000000000..6554b71c9 --- /dev/null +++ b/src/main/scala/ir/eval/InterpretBasilIR.scala @@ -0,0 +1,546 @@ +package ir.eval +import ir.* +import util.IRContext +import util.Logger +import util.functional.* +import boogie.Scope + +/** Abstraction for memload and variable lookup used by the expression evaluator. + */ +case class StVarLoader[S, F <: Effects[S, InterpreterError]](f: F) extends Loader[S, InterpreterError] { + + def getVariable(v: Variable): State[S, Option[Literal], InterpreterError] = { + for { + v <- f.loadVar(v.name) + } yield { + v match { + case Scalar(l) => Some(l) + case _ => None + } + } + } + + override def loadMemory( + m: Memory, + addr: Expr, + endian: Endian, + size: Int + ): State[S, Option[Literal], InterpreterError] = { + for { + r <- addr match { + case l: Literal if size == 1 => + Eval + .loadSingle(f)(m.name, Scalar(l)) + .map((v: BasilValue) => + v match { + case Scalar(l) => Some(l) + case _ => None + } + ) + case l: Literal => Eval.loadBV(f)(m.name, Scalar(l), endian, size).map(Some(_)) + case _ => State.get((s: S) => None) + } + } yield r + } + +} + +/* + * Helper functions for compiling high level structures to the interpreter effects. + * All are parametric in concrete state S and Effects[S] + */ +case object Eval { + + /*--------------------------------------------------------------------------------*/ + /* Eval functions */ + /*--------------------------------------------------------------------------------*/ + + def evalExpr[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Expr, InterpreterError] = { + val ldr = StVarLoader[S, T](f) + for { + res <- ir.eval.statePartialEvalExpr[S](ldr)(e) + } yield res + } + + def evalBV[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, BitVecLiteral, InterpreterError] = { + for { + res <- evalExpr(f)(e) + r <- State.pureE(res match { + case l: BitVecLiteral => Right(l) + case _ => Left((Errored(s"Eval BV residual $e"))) + }) + } yield r + } + + def evalInt[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, BigInt, InterpreterError] = { + for { + res <- evalExpr(f)(e) + r <- State.pureE(res match { + case l: IntLiteral => Right(l.value) + case _ => Left((Errored(s"Eval Int residual $e"))) + }) + } yield r + } + + def evalBool[S, T <: Effects[S, InterpreterError]](f: T)(e: Expr): State[S, Boolean, InterpreterError] = { + for { + res <- evalExpr(f)(e) + r <- State.pureE(res match { + case l: BoolLit => Right(l == TrueLiteral) + case _ => Left((Errored(s"Eval Bool residual $e"))) + }) + } yield r + } + + /*--------------------------------------------------------------------------------*/ + /* Load functions */ + /*--------------------------------------------------------------------------------*/ + + def load[S, T <: Effects[S, InterpreterError]]( + f: T + )(vname: String, addr: Scalar, endian: Endian, count: Int): State[S, List[BasilValue], InterpreterError] = { + for { + _ <- + if count == 0 then State.setError((Errored(s"Attempted fractional load"))) else State.pure(()) + keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), 0 until count) + values <- f.loadMem(vname, keys.toList) + vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + } yield (vals.toList) + } + + /** Load and concat bitvectors */ + def loadBV[S, T <: Effects[S, InterpreterError]]( + f: T + )(vname: String, addr: Scalar, endian: Endian, size: Int): State[S, BitVecLiteral, InterpreterError] = for { + mem <- f.loadVar(vname) + x <- mem match { + case mapv @ BasilMapValue(_, MapType(_, BitVecType(sz))) => State.pure((sz, mapv)) + case _ => State.setError((Errored("Trued to load-concat non bv"))) + } + (valsize, mapv) = x + + cells = size / valsize + + res <- load(f)(vname, addr, endian, cells) // actual load + bvs: List[BitVecLiteral] <- + State.mapM( + (c: BasilValue) => + c match { + case Scalar(bv @ BitVecLiteral(v, sz)) if sz == valsize => State.pure(bv) + case c => + State.setError( + TypeError(s"Loaded value of type ${c.irType} did not match expected type bv$valsize") + ) + }, + res + ) + } yield { + bvs.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r)) + } + + def loadSingle[S, T <: Effects[S, InterpreterError]]( + f: T + )(vname: String, addr: Scalar): State[S, BasilValue, InterpreterError] = { + for { + m <- load(f)(vname, addr, Endian.LittleEndian, 1) + } yield { + m.head + } + } + + /*--------------------------------------------------------------------------------*/ + /* Store functions */ + /*--------------------------------------------------------------------------------*/ + + /* Expand addr for number of values to store */ + def store[S, T <: Effects[S, InterpreterError]](f: T)( + vname: String, + addr: BasilValue, + values: List[BasilValue], + endian: Endian + ): State[S, Unit, InterpreterError] = for { + mem <- f.loadVar(vname) + x <- mem match { + case m @ BasilMapValue(_, MapType(kt, vt)) + if Some(kt) == addr.irType && values.forall(v => v.irType == Some(vt)) => + State.pure((m, kt, vt)) + case v => State.setError((TypeError(s"Invalid map store operation to $vname : $v"))) + } + (mapval, keytype, valtype) = x + keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until values.size)) + vals = endian match { + case Endian.LittleEndian => values.reverse + case Endian.BigEndian => values + } + x <- f.storeMem(vname, keys.zip(vals).toMap) + } yield x + + /** Extract bitvec to bytes and store bytes */ + def storeBV[S, T <: Effects[S, InterpreterError]](f: T)( + vname: String, + addr: BasilValue, + value: BitVecLiteral, + endian: Endian + ): State[S, Unit, InterpreterError] = for { + mem <- f.loadVar(vname) + mr <- mem match { + case m @ BasilMapValue(_, MapType(kt, BitVecType(size))) if Some(kt) == addr.irType => State.pure((m, size)) + case v => + State.setError( + TypeError( + s"Invalid map store operation to $vname : ${v.irType} (expect [${addr.irType}] <- ${value.getType})" + ) + ) + } + (mapval, vsize) = mr + cells = value.size / vsize + _ = { + if (cells < 1) { + State.setError((MemoryError("Tried to execute fractional store"))) + } else { + State.pure(()) + } + } + + extractVals = (0 until cells).map(i => BitVectorEval.boogie_extract((i + 1) * vsize, i * vsize, value)).toList + vs = endian match { + case Endian.LittleEndian => extractVals.map(Scalar(_)) + case Endian.BigEndian => extractVals.reverse.map(Scalar(_)) + } + + keys <- State.mapM((i: Int) => State.pureE(BasilValue.unsafeAdd(addr, i)), (0 until cells)) + s <- f.storeMem(vname, keys.zip(vs).toMap) + } yield s + + def storeSingle[S, E, T <: Effects[S, E]]( + f: T + )(vname: String, addr: BasilValue, value: BasilValue): State[S, Unit, E] = { + f.storeMem(vname, Map((addr -> value))) + } + + /** Helper functions * */ + + /** + * Load all memory cells from pointer until reaching cell containing 0. + * Ptr -> List[Bitvector] + */ + def getNullTerminatedString[S, T <: Effects[S, InterpreterError]](f: T) + (rgn: String, src: BasilValue, acc: List[BitVecLiteral] = List()): State[S, List[BitVecLiteral], InterpreterError] = + for { + srv: BitVecLiteral <- src match { + case Scalar(b: BitVecLiteral) => State.pure(b) + case _ => State.setError(Errored(s"Not pointer : $src")) + } + c <- f.loadMem(rgn, List(src)) + res <- c.head match { + case Scalar(BitVecLiteral(0, 8)) => State.pure(acc) + case Scalar(b: BitVecLiteral) => { + for { + nsrc <- State.pureE(BasilValue.unsafeAdd(src, 1)) + r <- getNullTerminatedString(f)(rgn, nsrc, acc.appended(b)) + } yield r + } + case _ => State.setError(Errored(s"not byte $c")) + } + } yield res +} + +class BASILInterpreter[S](f: Effects[S, InterpreterError]) extends Interpreter[S, InterpreterError](f) { + + def interpretOne: State[S, Boolean, InterpreterError] = { + val next = for { + next <- f.getNext + _ <- State.pure(Logger.debug(s"$next")) + r: Boolean <- next match { + case Intrinsic(tgt) => LibcIntrinsic.intrinsics(tgt)(f).map(_ => true) + case Run(c: Statement) => interpretStatement(f)(c).map(_ => true) + case Run(c: Jump) => interpretJump(f)(c).map(_ => true) + case Stopped() => State.pure(false) + case ErrorStop(e) => State.pure(false) + } + } yield r + + next.flatMapE { (e: InterpreterError) => + f.setNext(ErrorStop(e)).map(_ => false) + } + } + + def interpretJump[S, T <: Effects[S, InterpreterError]](f: T)(j: Jump): State[S, Unit, InterpreterError] = { + j match { + case gt: GoTo if gt.targets.size == 1 => { + f.setNext(Run(IRWalk.firstInBlock(gt.targets.head))) + } + case gt: GoTo => + val assumes = gt.targets.flatMap(_.statements.headOption).collect { case a: Assume => + a + } + for { + _ <- + if (assumes.size != gt.targets.size) { + State.setError((Errored(s"Some goto target missing guard $gt"))) + } else { + State.pure(()) + } + chosen: List[Assume] <- State.filterM((a: Assume) => Eval.evalBool(f)(a.body), assumes) + + res <- chosen match { + case Nil => State.setError(Errored(s"No jump target satisfied $gt")) + case h :: Nil => f.setNext(Run(h)) + case h :: tl => State.setError(Errored(s"More than one jump guard satisfied $gt")) + } + } yield res + case r: Return => f.doReturn() + case h: Unreachable => State.setError(EscapedControlFlow(h)) + } + } + + def interpretStatement[S, T <: Effects[S, InterpreterError]](f: T)(s: Statement): State[S, Unit, InterpreterError] = { + s match { + case assign: Assign => { + for { + rhs <- Eval.evalBV(f)(assign.rhs) + st <- f.storeVar(assign.lhs.name, assign.lhs.toBoogie.scope, Scalar(rhs)) + n <- f.setNext(Run(s.successor)) + } yield st + } + case assign: MemoryAssign => + for { + index: BitVecLiteral <- Eval.evalBV(f)(assign.index) + value: BitVecLiteral <- Eval.evalBV(f)(assign.value) + _ <- Eval.storeBV(f)(assign.mem.name, Scalar(index), value, assign.endian) + n <- f.setNext(Run(s.successor)) + } yield n + case assert: Assert => + for { + b <- Eval.evalBool(f)(assert.body) + _ <- + if (!b) { + State.setError(FailedAssertion(assert)) + } else { + f.setNext(Run(s.successor)) + } + } yield () + case assume: Assume => + for { + b <- Eval.evalBool(f)(assume.body) + n <- + if (!b) { + State.setError(Errored(s"Assumption not satisfied: $assume")) + } else { + f.setNext(Run(s.successor)) + } + } yield n + case dc: DirectCall => + if (dc.target.entryBlock.isDefined) { + val block = dc.target.entryBlock.get + f.call(dc.target.name, Run(block.statements.headOption.getOrElse(block.jump)), Run(dc.successor)) + } else if (LibcIntrinsic.intrinsics.contains(dc.target.name)) { + f.call(dc.target.name, Intrinsic(dc.target.name), Run(dc.successor)) + } else { + State.setError(EscapedControlFlow(dc)) + } + case ic: IndirectCall => + if (ic.target == Register("R30", 64)) { + f.doReturn() + } else { + for { + addr <- Eval.evalBV(f)(ic.target) + fp <- f.evalAddrToProc(addr.value.toInt) + _ <- fp match { + case Some(fp) => f.call(fp.name, fp.call, Run(ic.successor)) + case None => State.setError(EscapedControlFlow(ic)) + } + } yield () + } + case _: NOP => f.setNext(Run(s.successor)) + } + } +} + +object InterpFuns { + + def initRelocTable[S, T <: Effects[S, InterpreterError]](s: T)(ctx: IRContext): State[S, Unit, InterpreterError] = { + + val p = ctx.program + + val base = ctx.symbols.find(_.name == "__end__").get + var addr = base.value + var done = false + var x = List[(String, FunPointer)]() + + def newAddr(): BigInt = { + addr += 8 + addr + } + + for ((fname, fun) <- LibcIntrinsic.intrinsics) { + val name = fname.takeWhile(c => c != '@') + x = (name, FunPointer(BitVecLiteral(newAddr(), 64), name, Intrinsic(name))) :: x + } + + val intrinsics = x.toMap + + val procs = p.procedures.filter(proc => proc.address.isDefined) + + val fptrs = ctx.externalFunctions.toList + .sortBy(_.name) + .flatMap(f => { + intrinsics + .get(f.name) + .map(fp => (f.offset, fp)) + .orElse( + procs + .find(p => p.name == f.name) + .map(proc => + ( + f.offset, + FunPointer( + BitVecLiteral(proc.address.getOrElse(newAddr().toInt), 64), + proc.name, + Run(DirectCall(proc)) + ) + ) + ) + ) + }) + + // sort for deterministic trace + val stores = fptrs + .sortBy(f => f(0)) + .map { p => + val (offset, fptr) = p + Eval.storeSingle(s)("ghost-funtable", Scalar(fptr.addr), fptr) + >> Eval.storeBV(s)( + "mem", + Scalar(BitVecLiteral(offset, 64)), + fptr.addr, + Endian.LittleEndian + ) + } + + for { + _ <- State.sequence(State.pure(()), stores) + malloc_top = BitVecLiteral(newAddr() + 1024, 64) + _ <- s.storeVar("ghost_malloc_top", Scope.Global, Scalar(malloc_top)) + } yield () + } + + /** Functions which compile BASIL IR down to the minimal interpreter effects. + * + * Each function takes as parameter an implementation of Effects[S] + */ + + def initialState[S, E, T <: Effects[S, E]](s: T): State[S, Unit, E] = { + val SP: BitVecLiteral = BitVecLiteral(0x78000000, 64) + val FP: BitVecLiteral = SP + val LR: BitVecLiteral = BitVecLiteral(BigInt("78000000", 16), 64) + + for { + h <- State.pure(Logger.debug("DEFINE MEMORY REGIONS")) + h <- s.storeVar("mem", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + i <- s.storeVar("stack", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + j <- s.storeVar("R31", Scope.Global, Scalar(SP)) + k <- s.storeVar("R29", Scope.Global, Scalar(FP)) + l <- s.storeVar("R30", Scope.Global, Scalar(LR)) + l <- s.storeVar("R0", Scope.Global, Scalar(BitVecLiteral(0, 64))) + l <- s.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(0, 64))) + _ <- s.storeVar("ghost-funtable", Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(64)))) + _ <- IntrinsicImpl.initFileGhostRegions(s) + } yield l + } + + def initialiseProgram[S, T <: Effects[S, InterpreterError]](f: T)(p: Program): State[S, Unit, InterpreterError] = { + def initMemory(mem: String, mems: Iterable[MemorySection]) = { + for { + m <- State.sequence( + State.pure(()), + mems + .filter(m => m.address != 0 && m.bytes.nonEmpty) + .map { memory => + Eval.store(f)( + mem, + Scalar(BitVecLiteral(memory.address, 64)), + memory.bytes.toList.map(Scalar(_)), + Endian.BigEndian + ) + } + ) + } yield () + } + + for { + d <- initialState(f) + funs <- State.sequence( + State.pure(Logger.debug("INITIALISE FUNCTION ADDRESSES")), + p.procedures + .filter(p => p.blocks.nonEmpty && p.address.isDefined) + .map { (proc: Procedure) => + Eval.storeSingle(f)( + "ghost-funtable", + Scalar(BitVecLiteral(proc.address.get, 64)), + FunPointer(BitVecLiteral(proc.address.get, 64), proc.name, Run(IRWalk.firstInBlock(proc.entryBlock.get))) + ) + } + ) + _ <- State.pure(Logger.debug("INITIALISE MEMORY SECTIONS")) + mem <- initMemory("mem", p.initialMemory.values) + mem <- initMemory("stack", p.initialMemory.values) + mainfun = { + p.mainProcedure + } + r <- f.call(mainfun.name, Run(IRWalk.firstInBlock(mainfun.entryBlock.get)), Stopped()) + } yield r + } + + def initBSS[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext): State[S, Unit, InterpreterError] = { + val bss = for { + first <- p.symbols.find(s => s.name == "__bss_start__").map(_.value) + last <- p.symbols.find(s => s.name == "__bss_end__").map(_.value) + r <- if first == last then None else Some((first, (last - first) * 8)) + (addr, sz) = r + st = { + (rgn => Eval.storeBV(f)(rgn, Scalar(BitVecLiteral(addr, 64)), BitVecLiteral(0, sz.toInt), Endian.LittleEndian)) + } + } yield st + + bss match { + case None => Logger.error("No BSS initialised"); State.pure(()) + case Some(init) => init("mem") >> init("stack") + } + } + + def initProgState[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { + val st = (initialiseProgram(f)(p.program) >> initBSS(f)(p)) + >> InterpFuns.initRelocTable(f)(p) + val (fs, v) = st.f(is) + v match { + case Right(r) => fs + case Left(e) => throw Exception(s"Init failed $e") + } + } + + /* Intialise from ELF and Interpret program */ + def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: IRContext, is: S): S = { + val begin = initProgState(f)(p, is) + val interp = BASILInterpreter(f) + interp.run(begin) + } + + /* Interpret IR program */ + def interpretProg[S, T <: Effects[S, InterpreterError]](f: T)(p: Program, is: S): S = { + val begin = State.execute(is, initialiseProgram(f)(p)) + val interp = BASILInterpreter(f) + interp.run(begin) + } +} + +def interpret(IRProgram: Program): InterpreterState = { + InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) +} + +def interpret(IRProgram: IRContext): InterpreterState = { + InterpFuns.interpretProg(NormalInterpreter)(IRProgram, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/InterpretBreakpoints.scala b/src/main/scala/ir/eval/InterpretBreakpoints.scala new file mode 100644 index 000000000..1620e1b77 --- /dev/null +++ b/src/main/scala/ir/eval/InterpretBreakpoints.scala @@ -0,0 +1,121 @@ +package ir.eval +import ir.* +import util.Logger +import util.IRContext +import util.functional.* + +enum BreakPointLoc: + case CMD(c: Command) + case CMDCond(c: Command, condition: Expr) + +case class BreakPointAction( + saveState: Boolean = true, + stop: Boolean = false, + evalExprs: List[(String, Expr)] = List(), + log: Boolean = false +) + +case class BreakPoint(name: String = "", location: BreakPointLoc, action: BreakPointAction) + +case class RememberBreakpoints[T, I <: Effects[T, InterpreterError]](f: I, breaks: List[BreakPoint]) + extends NopEffects[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), InterpreterError] { + + def findBreaks[R](c: Command): State[(T, R), List[BreakPoint], InterpreterError] = { + State.filterM( + b => + b.location match { + case BreakPointLoc.CMD(bc) if bc == c => State.pure(true) + case BreakPointLoc.CMDCond(bc, e) if bc == c => doLeft(Eval.evalBool(f)(e)) + case _ => State.pure(false) + }, + breaks + ) + } + + override def getNext: State[ + (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), + ExecutionContinuation, + InterpreterError + ] = { + for { + v: ExecutionContinuation <- doLeft(f.getNext) + n <- v match { + case Run(s) => + for { + breaks: List[BreakPoint] <- findBreaks(s) + res <- State + .sequence[(T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])]), Unit, InterpreterError]( + State.pure(()), + breaks.map { + case breakpoint @ BreakPoint(name, stopcond, action) => + for { + saved <- doLeft( + if action.saveState then State.getS[T, InterpreterError].map(s => Some(s)) + else State.pure(None) + ) + evals <- State.mapM( + (e: (String, Expr)) => + for { + ev <- doLeft(Eval.evalExpr(f)(e(1))) + } yield (e(0), e(1), ev), + action.evalExprs + ) + _ <- State.pure({ + if (action.log) { + val bpn = breakpoint.name + val bpcond = breakpoint.location match { + case BreakPointLoc.CMD(c) => s"${c.parent.label}:$c" + case BreakPointLoc.CMDCond(c, e) => s"${c.parent.label}:$c when $e" + } + val saving = if action.saveState then " stashing state, " else "" + val stopping = if action.stop then " stopping. " else "" + val evalstr = evals.map(e => s"\n ${e(0)} : eval(${e(1)}) = ${e(2)}").mkString("") + Logger.warn(s"Breakpoint $bpn@$bpcond.$saving$stopping$evalstr") + } + }) + _ <- + if action.stop then doLeft(f.setNext(ErrorStop(Errored(s"Stopped at breakpoint ${name}")))) + else doLeft(State.pure(())) + _ <- State.modify((istate: (T, List[(BreakPoint, Option[T], List[(String, Expr, Expr)])])) => + (istate(0), ((breakpoint, saved, evals) :: istate(1))) + ) + } yield () + } + ) + } yield () + case _ => State.pure(()) + } + } yield v + } +} + +def interpretWithBreakPoints[I]( + p: IRContext, + breakpoints: List[BreakPoint], + innerInterpreter: Effects[I, InterpreterError], + innerInitialState: I +): (I, List[(BreakPoint, Option[I], List[(String, Expr, Expr)])]) = { + val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) + val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) + res +} + +def interpretWithBreakPoints[I]( + p: Program, + breakpoints: List[BreakPoint], + innerInterpreter: Effects[I, InterpreterError], + innerInitialState: I +): (I, List[(BreakPoint, Option[I], List[(String, Expr, Expr)])]) = { + val interp = LayerInterpreter(innerInterpreter, RememberBreakpoints(innerInterpreter, breakpoints)) + val res = InterpFuns.interpretProg(interp)(p, (innerInitialState, List())) + res +} + +def interpretBreakPoints(p: IRContext, breakpoints: List[BreakPoint]) = { + interpretWithBreakPoints(p, breakpoints, NormalInterpreter, InterpreterState()) +} + + +def interpretBreakPoints(p: Program, breakpoints: List[BreakPoint]) = { + interpretWithBreakPoints(p, breakpoints, NormalInterpreter, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/InterpretRLimit.scala b/src/main/scala/ir/eval/InterpretRLimit.scala new file mode 100644 index 000000000..8bc8e5568 --- /dev/null +++ b/src/main/scala/ir/eval/InterpretRLimit.scala @@ -0,0 +1,41 @@ + +package ir.eval +import ir.* +import util.IRContext +import util.Logger +import util.functional.* + +case class EffectsRLimit[T, E, I <: Effects[T, InterpreterError]](limit: Int) extends NopEffects[(T, Int), InterpreterError] { + + override def getNext: State[(T, Int), ExecutionContinuation, InterpreterError] = { + for { + c: (T, Int) <- State.getS + (is, resources) = c + _ <- if (resources >= limit && limit >= 0) { + State.setError(Errored(s"Resource limit $limit reached")) + } else { + State.modify((s: (T, Int)) => (s(0), s(1) + 1)) + } + } yield Stopped() // thrown away by LayerInterpreter + } +} + +def interpretWithRLimit[I](p: Program, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Int) = { + val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) + InterpFuns.interpretProg(rlimitInterpreter)(p, (innerInitialState, 0)) +} + +def interpretWithRLimit[I](p: IRContext, instructionLimit: Int, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Int) = { + val rlimitInterpreter = LayerInterpreter(innerInterpreter, EffectsRLimit(instructionLimit)) + val (begin, _) = InterpFuns.initProgState(rlimitInterpreter)(p, (innerInitialState, 0)) + // throw away initialisation trace + BASILInterpreter(rlimitInterpreter).run((begin, 0)) +} + +def interpretRLimit(p: Program, instructionLimit: Int): (InterpreterState, Int) = { + interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) +} + +def interpretRLimit(p: IRContext, instructionLimit: Int): (InterpreterState, Int) = { + interpretWithRLimit(p, instructionLimit, NormalInterpreter, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/InterpretTrace.scala b/src/main/scala/ir/eval/InterpretTrace.scala new file mode 100644 index 000000000..7087463c3 --- /dev/null +++ b/src/main/scala/ir/eval/InterpretTrace.scala @@ -0,0 +1,69 @@ +package ir.eval +import ir.* +import util.IRContext +import util.Logger +import util.functional.* +import boogie.Scope + +enum ExecEffect: + case Call(target: String, begin: ExecutionContinuation, returnTo: ExecutionContinuation) + case Return + case StoreVar(v: String, s: Scope, value: BasilValue) + case LoadVar(v: String) + case StoreMem(vname: String, update: Map[BasilValue, BasilValue]) + case LoadMem(vname: String, addrs: List[BasilValue]) + case FindProc(addr: Int) + +case class Trace(t: List[ExecEffect]) + +case object Trace { + def add[E](e: ExecEffect): State[Trace, Unit, E] = { + State.modify((t: Trace) => Trace(t.t.appended(e))) + } +} + +case class TraceGen[E]() extends NopEffects[Trace, E] { + + /** Values are discarded by ProductInterpreter so do not matter */ + override def loadMem(v: String, addrs: List[BasilValue]): State[Trace, List[BasilValue], E] = for { + s <- Trace.add(ExecEffect.LoadMem(v, addrs)) + } yield List() + + override def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[Trace, Unit, E] = for { + s <- Trace.add(ExecEffect.Call(target, beginFrom, returnTo)) + } yield () + + override def doReturn(): State[Trace, Unit, E] = for { + s <- Trace.add(ExecEffect.Return) + } yield () + + override def storeVar(v: String, scope: Scope, value: BasilValue): State[Trace, Unit, E] = for { + s <- if !v.startsWith("ghost") then Trace.add(ExecEffect.StoreVar(v, scope, value)) else State.pure(()) + } yield () + + override def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[Trace, Unit, E] = for { + s <- if !vname.startsWith("ghost") then Trace.add(ExecEffect.StoreMem(vname, update)) else State.pure(()) + } yield () + +} + +def tracingInterpreter[I, E](innerInterpreter: Effects[I, E]) = ProductInterpreter(innerInterpreter, TraceGen()) + +def interpretWithTrace[I](p: Program, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { + InterpFuns.interpretProg(tracingInterpreter(innerInterpreter))(p, (innerInitialState, Trace(List()))) +} + +def interpretWithTrace[I](p: IRContext, innerInterpreter: Effects[I, InterpreterError], innerInitialState: I): (I, Trace) = { + val tracingInterpreter = ProductInterpreter(innerInterpreter, TraceGen()) + val (begin, _) = InterpFuns.initProgState(tracingInterpreter)(p, (innerInitialState, Trace(List()))) + // throw away initialisation trace + BASILInterpreter(tracingInterpreter).run((begin, Trace(List()))) +} + +def interpretTrace(p: Program) = { + interpretWithTrace(p, NormalInterpreter, InterpreterState()) +} + +def interpretTrace(p: IRContext) = { + interpretWithTrace(p, NormalInterpreter, InterpreterState()) +} diff --git a/src/main/scala/ir/eval/Interpreter.scala b/src/main/scala/ir/eval/Interpreter.scala new file mode 100644 index 000000000..6d42b5cc3 --- /dev/null +++ b/src/main/scala/ir/eval/Interpreter.scala @@ -0,0 +1,569 @@ +package ir.eval +import ir.* +import util.Logger +import util.functional.* +import boogie.Scope + +import scala.annotation.tailrec + +/** Interpreter status type, either stopped, run next command or error + */ +sealed trait ExecutionContinuation +case class Stopped() extends ExecutionContinuation /* normal program stop */ +case class ErrorStop(error: InterpreterError) extends ExecutionContinuation /* program stop in error state */ +case class Run(next: Command) extends ExecutionContinuation /* continue by executing next command */ +case class Intrinsic(name: String) extends ExecutionContinuation /* a named intrinsic instruction */ + +sealed trait InterpreterError +case class FailedAssertion(a: Assert) extends InterpreterError +case class EscapedControlFlow(call: Jump | Call) + extends InterpreterError /* controlflow has reached somewhere eunrecoverable */ +case class Errored(message: String = "") extends InterpreterError +case class TypeError(message: String = "") extends InterpreterError /* type mismatch appeared */ +case class EvalError(message: String = "") + extends InterpreterError /* failed to evaluate an expression to a concrete value */ +case class MemoryError(message: String = "") extends InterpreterError /* An error to do with memory */ + +/* Concrete value type of the interpreter. */ +sealed trait BasilValue(val irType: Option[IRType]) +case class Scalar(value: Literal) extends BasilValue(Some(value.getType)) { + override def toString: _root_.java.lang.String = value match { + case b: BitVecLiteral => "0x%x:bv%d".format(b.value, b.size) + case c => c.toString + } +} + +/* Abstract callable function address */ +case class FunPointer(addr: BitVecLiteral, name: String, call: ExecutionContinuation) + extends BasilValue(Some(addr.getType)) + +sealed trait MapValue { + def value: Map[BasilValue, BasilValue] +} + +/* We erase the type of basil values and enforce the invariant that + \exists i . \forall v \in value.keys , v.irType = i and + \exists j . \forall v \in value.values, v.irType = j + */ +case class BasilMapValue(value: Map[BasilValue, BasilValue], mapType: MapType) + extends MapValue + with BasilValue(Some(mapType)) { + override def toString = s"MapValue : $irType" +} + +case class GenMapValue(value: Map[BasilValue, BasilValue]) extends BasilValue(None) with MapValue { + override def toString = s"GenMapValue : $irType" +} + +case class Symbol(value: String) extends BasilValue(None) + +case object BasilValue { + + def size(v: IRType): Int = { + v match { + case BitVecType(sz) => sz + case _ => 1 + } + } + + def toBV[S, E](l: BasilValue): Either[InterpreterError, BitVecLiteral] = { + l match { + case Scalar(b1: BitVecLiteral) => Right(b1) + case _ => Left(TypeError(s"Not a bitvector add $l")) + } + } + + def unsafeAdd[S, E](l: BasilValue, vr: Int): Either[InterpreterError, BasilValue] = { + l match { + case _ if vr == 0 => Right(l) + case Scalar(IntLiteral(vl)) => Right(Scalar(IntLiteral(vl + vr))) + case Scalar(b1: BitVecLiteral) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, BitVecLiteral(vr, b1.size)))) + case _ => Left(TypeError(s"Operation add $vr undefined on $l")) + } + } + + def add[S, E](l: BasilValue, r: BasilValue): Either[InterpreterError, BasilValue] = { + (l, r) match { + case (Scalar(IntLiteral(vl)), Scalar(IntLiteral(vr))) => Right(Scalar(IntLiteral(vl + vr))) + case (Scalar(b1: BitVecLiteral), Scalar(b2: BitVecLiteral)) => Right(Scalar(eval.evalBVBinExpr(BVADD, b1, b2))) + case _ => Left(TypeError(s"Operation add undefined $l + $r")) + } + } + +} + +/** Minimal language defining all state transitions in the interpreter, defined for the interpreter's concrete state T. + */ +trait Effects[T, E] { + /* expression eval */ + + def loadVar(v: String): State[T, BasilValue, E] + + def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue], E] + + def evalAddrToProc(addr: Int): State[T, Option[FunPointer], E] + + def getNext: State[T, ExecutionContinuation, E] + + /** state effects */ + + /* High-level implementation of a program counter that leverages the intrusive CFG. */ + def setNext(c: ExecutionContinuation): State[T, Unit, E] + + /* Perform a call: + * target: arbitrary target name + * beginFrom: ExecutionContinuation which begins executing the procedure + * returnTo: ExecutionContinuation which begins executing after procedure return + */ + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit, E] + + def callIntrinsic(name: String, args: List[BasilValue]): State[T, Option[BasilValue], E] + + def doReturn(): State[T, Unit, E] + + def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit, E] + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit, E] +} + +trait NopEffects[T, E] extends Effects[T, E] { + def loadVar(v: String): State[T, BasilValue, E] = State.pure(Scalar(FalseLiteral)) + def loadMem(v: String, addrs: List[BasilValue]): State[T, List[BasilValue], E] = State.pure(List()) + def evalAddrToProc(addr: Int): State[T, Option[FunPointer], E] = State.pure(None) + def getNext: State[T, ExecutionContinuation, E] = State.pure(Stopped()) + def setNext(c: ExecutionContinuation): State[T, Unit, E] = State.pure(()) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[T, Unit, E] = State.pure(()) + def callIntrinsic(name: String, args: List[BasilValue]): State[T, Option[BasilValue], E] = State.pure(None) + def doReturn(): State[T, Unit, E] = State.pure(()) + + def storeVar(v: String, scope: Scope, value: BasilValue): State[T, Unit, E] = State.pure(()) + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[T, Unit, E] = State.pure(()) +} + +/*-------------------------------------------------------------------------------- + * Definition of concrete state + *--------------------------------------------------------------------------------*/ + +type StackFrameID = String +val globalFrame: StackFrameID = "GLOBAL" + +case class MemoryState( + /* We have a very permissive value reprsentation and store all dynamic state in `stackFrames`. + * - activations is the call stack, the top of which indicates the current stackFrame. + * - activationCount: (procedurename -> int) is used to create uniquely-named stackframes. + */ + stackFrames: Map[StackFrameID, Map[String, BasilValue]] = Map((globalFrame -> Map.empty)), + activations: List[StackFrameID] = List.empty, + activationCount: Map[String, Int] = Map.empty.withDefault(_ => 0) +) { + + /** Debug return useful values * */ + + def getGlobalVals: Map[String, BitVecLiteral] = { + stackFrames(globalFrame).collect { + case (k, Scalar(b: BitVecLiteral)) => k -> b + } + } + + def getMem(name: String): Map[BitVecLiteral, BitVecLiteral] = { + stackFrames(globalFrame)(name) match { + case BasilMapValue(innerMap, MapType(BitVecType(ks), BitVecType(vs))) => + def unwrap(v: BasilValue): BitVecLiteral = v match { + case Scalar(b: BitVecLiteral) => b + case v => throw Exception(s"Failed to convert map value to bitvector: $v (interpreter type error somewhere)") + } + innerMap.map((k, v) => unwrap(k) -> unwrap(v)) + case v => throw Exception(s"$name not a bitvec map variable: ${v.irType}") + } + } + + /** Local Variable Stack * */ + + def pushStackFrame(function: String): MemoryState = { + val counts = activationCount + (function -> (activationCount(function) + 1)) + val frameName: StackFrameID = s"AR_${function}_${activationCount(function)}" + val frames = stackFrames + (frameName -> Map.empty) + MemoryState(frames, frameName :: activations, counts) + } + + def popStackFrame(): Either[InterpreterError, MemoryState] = { + val hv = activations match { + case Nil => Left((Errored("No stack frame to pop"))) + case h :: Nil if h == globalFrame => Left((Errored("tried to pop global scope"))) + case h :: tl => Right((h, tl)) + } + hv.map { hv => + val (frame, remactivs) = hv + val frames = stackFrames.removed(frame) + MemoryState(frames, remactivs, activationCount) + } + } + + /* Variable retrieval and setting */ + + /* Set variable in a given frame */ + def setVar(frame: StackFrameID, varname: String, value: BasilValue): MemoryState = { + val nv = stackFrames + (frame -> (stackFrames(frame) + (varname -> value))) + MemoryState(nv, activations, activationCount) + } + + /* Find variable definition scope and set it in the correct frame */ + def setVar(v: String, value: BasilValue): MemoryState = { + val frame = findVarOpt(v).map(_(0)).getOrElse(activations.head) + setVar(frame, v, value) + } + + /* Define a variable in the scope specified + * ignoring whether it may already be defined + */ + def defVar(name: String, s: Scope, value: BasilValue): MemoryState = { + val frame = s match { + case Scope.Global => globalFrame + case _ => activations.head + } + setVar(frame, name, value) + } + + /* Lookup the value of a variable */ + def findVarOpt(name: String): Option[(StackFrameID, BasilValue)] = { + val searchScopes = globalFrame :: activations.headOption.toList + searchScopes.foldRight(None: Option[(StackFrameID, BasilValue)]) { (r, acc) => + acc match { + case None => stackFrames(r).get(name).map(v => (r, v)) + case s => s + } + } + } + + def findVar(name: String): Either[InterpreterError, (StackFrameID, BasilValue)] = { + findVarOpt(name: String) + .map(Right(_)) + .getOrElse(Left(Errored(s"Access to undefined variable $name"))) + } + + def getVarOpt(name: String): Option[BasilValue] = findVarOpt(name).map(_(1)) + + def getVar(name: String): Either[InterpreterError, BasilValue] = { + getVarOpt(name).map(Right(_)).getOrElse(Left(Errored(s"Access undefined variable $name"))) + } + + def getVar(v: Variable): Either[InterpreterError, BasilValue] = { + val value = getVar(v.name) + value match { + case Right(dv: BasilValue) if Some(v.getType) != dv.irType => + Left(Errored(s"Type mismatch on variable definition and load: defined ${dv.irType}, variable ${v.getType}")) + case Right(o) => Right(o) + case o => o + } + } + + /* Map variable accessing ; load and store operations */ + def doLoad(vname: String, addr: List[BasilValue]): Either[InterpreterError, List[BasilValue]] = for { + v <- findVar(vname) + mapv: MapValue <- v(1) match { + case m: MapValue => Right(m) + case m => Left(TypeError(s"Load from nonmap ${m.irType}")) + } + rs: List[Option[BasilValue]] = addr.map(k => mapv.value.get(k)) + xs <- + if (rs.forall(_.isDefined)) { + Right(rs.map(_.get)) + } else { + Left(MemoryError(s"Read from uninitialised $vname[${addr.head} .. ${addr.last}]")) + } + } yield xs + + /** typecheck and some fields of a map variable */ + def doStore(vname: String, values: Map[BasilValue, BasilValue]): Either[InterpreterError, MemoryState] = for { + _ <- if values.isEmpty then Left(MemoryError("Tried to store size 0")) else Right(()) + v <- findVar(vname) + (frame, mem) = v + mapval <- mem match { + case m @ BasilMapValue(_, MapType(kt, vt)) => + for { + m <- values.find((k, v) => k.irType != Some(kt) || v.irType != Some(vt)) match { + case Some(v) => + Left( + TypeError( + s"Invalid addr or value type (${v(0).irType}, ${v(1).irType}) does not match map type $vname : ($kt, $vt)" + ) + ) + case None => Right(m) + } + nm = BasilMapValue(m.value ++ values, m.mapType) + } yield nm + case m @ GenMapValue(_) => + Right(GenMapValue(m.value ++ values)) + case v => Left(TypeError(s"Invalid map store operation to $vname : ${v.irType}")) + } + + ms <- Right(setVar(frame, vname, mapval)) + } yield ms +} + +object LibcIntrinsic { + + /** + * Part of the intrinsics implementation that lives above the Effects interface + * (i.e. we are defining the observable part of the intrinsics behaviour) + */ + + def singleArg[S, E, T <: Effects[S, E]](name: String)(s: T): State[S, Unit, E] = for { + c <- s.loadVar("R0") + res <- s.callIntrinsic(name, List(c)) + _ <- if res.isDefined then s.storeVar("R0", Scope.Global, res.get) else State.pure(()) + _ <- s.doReturn() + } yield () + + def calloc[S, T <: Effects[S, InterpreterError]](s: T): State[S, Unit, InterpreterError] = for { + size <- s.loadVar("R0") + res <- s.callIntrinsic("malloc", List(size)) + ptr = res.get + isize <- size match { + case Scalar(b: BitVecLiteral) => State.pure(b.value * 8) + case _ => State.setError(Errored("programmer error")) + } + cl <- Eval.storeBV(s)("mem", ptr, BitVecLiteral(0, isize.toInt), Endian.LittleEndian) + _ <- s.doReturn() + } yield () + + def intrinsics[S, T <: Effects[S, InterpreterError]]: Map[String, T => State[S, Unit, InterpreterError]] = + Map[String, T => State[S, Unit, InterpreterError]]( + "putc" -> singleArg("putc"), + "puts" -> singleArg("puts"), + "printf" -> singleArg("print"), + "malloc" -> singleArg("malloc"), + "free" -> singleArg("free"), + "#free" -> singleArg("free"), + "calloc" -> calloc + ) + +} + +object IntrinsicImpl { + + /** state initialisation for file modelling */ + def initFileGhostRegions[S, E, T <: Effects[S, E]](f: T): State[S, Unit, E] = for { + _ <- f.storeVar("ghost-file-bookkeeping", Scope.Global, GenMapValue(Map.empty)) + _ <- f.storeVar("ghost-fd-mapping", Scope.Global, GenMapValue(Map.empty)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> Scalar(BitVecLiteral(0, 64)))) + _ <- f.callIntrinsic("fopen", List(Symbol("stderr"))) + _ <- f.callIntrinsic("fopen", List(Symbol("stdout"))) + } yield () + + /** Intrinsics defined over arbitrary effects + * + * We call these from Effects[T, E] rather than the Interpreter so their implementation does not appear in the trace. + */ + def putc[S, T <: Effects[S, InterpreterError]](f: T)(arg: BasilValue): State[S, Option[BasilValue], InterpreterError] = { + for { + addr <- f.loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) + byte <- State.pureE(BasilValue.toBV(arg)) + c <- Eval.evalBV(f)(Extract(8, 0, byte)) + _ <- f.storeMem("stdout", Map(addr.head -> Scalar(c))) + naddr <- State.pureE(BasilValue.unsafeAdd(addr.head, 1)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) + } yield None + } + + def fopen[S, T <: Effects[S, InterpreterError]](f: T)(file: BasilValue): State[S, Option[BasilValue], InterpreterError] = { + for { + fname <- file match { + case Symbol(name) => State.pure(name) + case _ => State.setError(Errored("Intrinsic fopen open not given filename")) + } + _ <- f.storeVar(fname, Scope.Global, BasilMapValue(Map.empty, MapType(BitVecType(64), BitVecType(8)))) + filecount <- f.loadMem("ghost-file-bookkeeping", List(Symbol("$$filecount"))) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol(fname + "-ptr") -> Scalar(BitVecLiteral(0, 64)))) + _ <- f.storeMem("ghost-fd-mapping", Map(filecount.head -> Symbol(fname + "-ptr"))) + _ <- f.storeVar("R0", Scope.Global, filecount.head) + nfilecount <- State.pureE(BasilValue.unsafeAdd(filecount.head, 1)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("$$filecount") -> nfilecount)) + } yield Some(filecount.head) + } + + def print[S, T <: Effects[S, InterpreterError]](f: T)(strptr: BasilValue): State[S, Option[BasilValue], InterpreterError] = { + for { + str <- Eval.getNullTerminatedString(f)("mem", strptr) + baseptr: List[BasilValue] <- f.loadMem("ghost-file-bookkeeping", List(Symbol("stdout-ptr"))) + offs: List[BasilValue] <- State.mapM( + ((i: Int) => State.pureE(BasilValue.unsafeAdd(baseptr.head, i))), + (0 until (str.size + 1)) + ) + _ <- f.storeMem("stdout", offs.zip(str.map(Scalar(_))).toMap) + naddr <- State.pureE(BasilValue.unsafeAdd(baseptr.head, str.size)) + _ <- f.storeMem("ghost-file-bookkeeping", Map(Symbol("stdout-ptr") -> naddr)) + } yield None + } + + def malloc[S, T <: Effects[S, InterpreterError]](f: T)(size: BasilValue): State[S, Option[BasilValue], InterpreterError] = { + for { + size <- (size match { + case x @ Scalar(_: BitVecLiteral) => State.pure(x) + case Scalar(x: IntLiteral) => State.pure(Scalar(BitVecLiteral(x.value, 64))) + case _ => State.setError(Errored("illegal prim arg")) + }) + x <- f.loadVar("ghost_malloc_top") + x_gap <- State.pureE(BasilValue.unsafeAdd(x, 128)) // put a gap around allocations to catch buffer overflows + x_end <- State.pureE(BasilValue.add(x_gap, size)) + _ <- f.storeVar("ghost_malloc_top", Scope.Global, x_end) + _ <- f.storeVar("R0", Scope.Global, x_gap) + } yield Some(x_gap) + } +} + +case class InterpreterState( + nextCmd: ExecutionContinuation = Stopped(), + callStack: List[ExecutionContinuation] = List.empty, + memoryState: MemoryState = MemoryState() +) + +/** Implementation of Effects for InterpreterState concrete state representation. + */ +object NormalInterpreter extends Effects[InterpreterState, InterpreterError] { + def callIntrinsic( + name: String, + args: List[BasilValue] + ): State[InterpreterState, Option[BasilValue], InterpreterError] = { + name match { + case "free" => State.pure(None) + case "malloc" => IntrinsicImpl.malloc(this)(args.head) + case "fopen" => IntrinsicImpl.fopen(this)(args.head) + case "putc" => IntrinsicImpl.putc(this)(args.head) + case "strlen" => + for { + str <- Eval.getNullTerminatedString(this)("mem", args.head) + r = Scalar(BitVecLiteral(str.length, 64)) + _ <- storeVar("R0", Scope.Global, r) + } yield Some(r) + case "print" => IntrinsicImpl.print(this)(args.head) + case "puts" => IntrinsicImpl.print(this)(args.head) >> IntrinsicImpl.putc(this)(Scalar(BitVecLiteral('\n'.toInt, 64))) + case _ => State.setError(Errored(s"Call undefined intrinsic $name")) + } + } + + def loadVar(v: String): State[InterpreterState, BasilValue, InterpreterError] = { + State.getE((s: InterpreterState) => { + s.memoryState.getVar(v) + }) + } + + def evalAddrToProc(addr: Int): State[InterpreterState, Option[FunPointer], InterpreterError] = + Logger.debug(s" eff : FIND PROC $addr") + for { + res: List[BasilValue] <- State.getE((s: InterpreterState) => + s.memoryState.doLoad("ghost-funtable", List(Scalar(BitVecLiteral(addr, 64)))) + ) + } yield { + res match { + case (f: FunPointer) :: Nil => Some(f) + case _ => None + } + } + + def formatStore(varname: String, update: Map[BasilValue, BasilValue]): String = { + val ks = update.toList.sortWith { (x, y) => + def conv(v: BasilValue): BigInt = v match { + case Scalar(b: BitVecLiteral) => b.value + case Scalar(b: IntLiteral) => b.value + case _ => BigInt(0) + } + conv(x(0)) <= conv(y(0)) + } + + val rs = ks.foldLeft(Some((None, List[BitVecLiteral]())): Option[(Option[BigInt], List[BitVecLiteral])]) { + (acc, v) => + v match { + case (Scalar(bv: BitVecLiteral), Scalar(bv2: BitVecLiteral)) => + acc match { + case None => None + case Some(None, l) => Some(Some(bv.value), bv2 :: l) + case Some(Some(v), l) if bv.value == v + 1 => Some(Some(bv.value), bv2 :: l) + case Some(Some(v), l) => + None + } + case (bv, bv2) => None + } + } + + rs match { + case Some(_, l) => + val vs = Scalar(l.foldLeft(BitVecLiteral(0, 0))((acc, r) => eval.evalBVBinExpr(BVCONCAT, acc, r))).toString + s"$varname[${ks.headOption.map(_(0)).getOrElse("null")}] := $vs" + case None if ks.length < 8 => s"$varname[${ks.map(_(0)).mkString(",")}] := ${ks.map(_(1)).mkString(",")}" + case None => s"$varname[${ks.map(_(0)).take(8).mkString(",")}...] := ${ks.map(_(1)).take(8).mkString(", ")}... " + } + + } + + def loadMem(v: String, addrs: List[BasilValue]): State[InterpreterState, List[BasilValue], InterpreterError] = { + State.getE((s: InterpreterState) => { + val r = s.memoryState.doLoad(v, addrs) + Logger.debug(s" eff : LOAD ${addrs.head} x ${addrs.size}") + r + }) + } + + def getNext: State[InterpreterState, ExecutionContinuation, InterpreterError] = State.get((s: InterpreterState) => s.nextCmd) + + /** effects * */ + def setNext(c: ExecutionContinuation): State[InterpreterState, Unit, InterpreterError] = State.modify((s: InterpreterState) => { + s.copy(nextCmd = c) + }) + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[InterpreterState, Unit, InterpreterError] = + State.modify((s: InterpreterState) => { + Logger.debug(s" eff : CALL $target") + s.copy( + nextCmd = beginFrom, + callStack = returnTo :: s.callStack, + memoryState = s.memoryState.pushStackFrame(target) + ) + }) + + def doReturn(): State[InterpreterState, Unit, InterpreterError] = { + Logger.debug(s" eff : RETURN") + State.modifyE((s: InterpreterState) => { + s.callStack match { + case Nil => Right(s.copy(nextCmd = Stopped())) + case h :: tl => + for { + ms <- s.memoryState.popStackFrame() + } yield (s.copy(nextCmd = h, callStack = tl, memoryState = ms)) + } + }) + } + + def storeVar(v: String, scope: Scope, value: BasilValue): State[InterpreterState, Unit, InterpreterError] = { + Logger.debug(s" eff : SET $v := $value") + State.modify((s: InterpreterState) => s.copy(memoryState = s.memoryState.defVar(v, scope, value))) + } + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[InterpreterState, Unit, InterpreterError] = + State.modifyE((s: InterpreterState) => { + Logger.debug(s" eff : STORE ${formatStore(vname, update)}") + for { + ms <- s.memoryState.doStore(vname, update) + } yield s.copy(memoryState = ms) + }) +} + +trait Interpreter[S, E](val f: Effects[S, E]) { + + /* + * Returns value deciding whether to continue. + */ + def interpretOne: State[S, Boolean, E] + + @tailrec + final def run(begin: S): S = { + val (fs, cont) = interpretOne.f(begin) + + if (cont.contains(true)) { + run(fs) + } else { + fs + } + } +} diff --git a/src/main/scala/ir/eval/InterpreterProduct.scala b/src/main/scala/ir/eval/InterpreterProduct.scala new file mode 100644 index 000000000..18841a9c3 --- /dev/null +++ b/src/main/scala/ir/eval/InterpreterProduct.scala @@ -0,0 +1,130 @@ +package ir.eval +import ir.* +import util.Logger +import util.functional.* +import boogie.Scope + +def doLeft[L, T, V, E](f: State[L, V, E]): State[(L, T), V, E] = for { + n <- State[(L, T), V, E]((s: (L, T)) => { + val r = f.f(s(0)) + ((r(0), s(1)), r(1)) + }) +} yield n + +def doRight[L, T, V, E](f: State[T, V, E]): State[(L, T), V, E] = for { + n <- State[(L, T), V, E]((s: (L, T)) => { + val r = f.f(s(1)) + ((s(0), r(0)), r(1)) + }) +} yield n + +/** Runs two interpreters "inner" and "before" simultaneously, returning the value from inner, and ignoring before + */ +case class ProductInterpreter[L, T, E](inner: Effects[L, E], before: Effects[T, E]) extends Effects[(L, T), E] { + + def loadVar(v: String): State[(L, T), BasilValue, E] = for { + n <- doRight(before.loadVar(v)) + f <- doLeft(inner.loadVar(v)) + } yield f + + def loadMem(v: String, addrs: List[BasilValue]): State[(L, T), List[BasilValue], E] = for { + n <- doRight(before.loadMem(v, addrs)) + f <- doLeft(inner.loadMem(v, addrs)) + } yield f + + def evalAddrToProc(addr: Int): State[(L, T), Option[FunPointer], E] = for { + n <- doRight(before.evalAddrToProc(addr: Int)) + f <- doLeft(inner.evalAddrToProc(addr)) + } yield f + + def getNext: State[(L, T), ExecutionContinuation, E] = for { + n <- doRight(before.getNext) + f <- doLeft(inner.getNext) + } yield f + + /** state effects */ + def setNext(c: ExecutionContinuation): State[(L, T), Unit, E] = for { + n <- doRight(before.setNext(c)) + f <- doLeft(inner.setNext(c)) + } yield f + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[(L, T), Unit, E] = for { + n <- doRight(before.call(target, beginFrom, returnTo)) + f <- doLeft(inner.call(target, beginFrom, returnTo)) + } yield f + + def callIntrinsic(name: String, args: List[BasilValue]): State[(L, T), Option[BasilValue], E] = for { + n <- doRight(before.callIntrinsic(name, args)) + f <- doLeft(inner.callIntrinsic(name, args)) + } yield f + + def doReturn(): State[(L, T), Unit, E] = for { + n <- doRight(before.doReturn()) + f <- doLeft(inner.doReturn()) + } yield f + + def storeVar(v: String, scope: Scope, value: BasilValue): State[(L, T), Unit, E] = for { + n <- doRight(before.storeVar(v, scope, value)) + f <- doLeft(inner.storeVar(v, scope, value)) + } yield f + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[(L, T), Unit, E] = for { + n <- doRight(before.storeMem(vname, update)) + f <- doLeft(inner.storeMem(vname, update)) + } yield f +} + +case class LayerInterpreter[L, T, E](inner: Effects[L, E], before: Effects[(L, T), E]) + extends Effects[(L, T), E] { + + def loadVar(v: String): State[(L, T), BasilValue, E] = for { + n <- before.loadVar(v) + f <- doLeft(inner.loadVar(v)) + } yield f + + def loadMem(v: String, addrs: List[BasilValue]): State[(L, T), List[BasilValue], E] = for { + n <- before.loadMem(v, addrs) + f <- doLeft(inner.loadMem(v, addrs)) + } yield f + + def evalAddrToProc(addr: Int): State[(L, T), Option[FunPointer], E] = for { + n <- before.evalAddrToProc(addr) + f <- doLeft(inner.evalAddrToProc(addr)) + } yield f + + def getNext: State[(L, T), ExecutionContinuation, E] = for { + n <- before.getNext + f <- doLeft(inner.getNext) + } yield f + + /** state effects */ + def setNext(c: ExecutionContinuation): State[(L, T), Unit, E] = for { + n <- before.setNext(c) + f <- doLeft(inner.setNext(c)) + } yield f + + def call(target: String, beginFrom: ExecutionContinuation, returnTo: ExecutionContinuation): State[(L, T), Unit, E] = for { + n <- before.call(target, beginFrom, returnTo) + f <- doLeft(inner.call(target, beginFrom, returnTo)) + } yield f + + def callIntrinsic(name: String, args: List[BasilValue]): State[(L, T), Option[BasilValue], E] = for { + n <- before.callIntrinsic(name, args) + f <- doLeft(inner.callIntrinsic(name, args)) + } yield f + + def doReturn(): State[(L, T), Unit, E] = for { + n <- before.doReturn() + f <- doLeft(inner.doReturn()) + } yield f + + def storeVar(v: String, scope: Scope, value: BasilValue): State[(L, T), Unit, E] = for { + n <- before.storeVar(v, scope, value) + f <- doLeft(inner.storeVar(v, scope, value)) + } yield f + + def storeMem(vname: String, update: Map[BasilValue, BasilValue]): State[(L, T), Unit, E] = for { + n <- before.storeMem(vname, update) + f <- doLeft(inner.storeMem(vname, update)) + } yield f +} diff --git a/src/main/scala/util/PerformanceTimer.scala b/src/main/scala/util/PerformanceTimer.scala index a79b1d3b7..f5fe74151 100644 --- a/src/main/scala/util/PerformanceTimer.scala +++ b/src/main/scala/util/PerformanceTimer.scala @@ -15,7 +15,7 @@ case class PerformanceTimer(timerName: String = "") { Logger.debug(s"PerformanceTimer $timerName [$name]: ${delta}ms") delta } - private def elapsed() : Long = { + def elapsed() : Long = { System.currentTimeMillis() - lastCheckpoint } diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index 0dec78187..6d43c4851 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -5,6 +5,7 @@ import com.grammatech.gtirb.proto.IR.IR import com.grammatech.gtirb.proto.Module.Module import com.grammatech.gtirb.proto.Section.Section import spray.json.* +import ir.eval import gtirb.* import scala.collection.mutable.ListBuffer @@ -540,8 +541,14 @@ object RunUtils { q.loading.dumpIL.foreach(s => writeToFile(serialiseIL(ctx.program), s"$s-after-analysis.il")) if (q.runInterpret) { - val interpreter = Interpreter() - interpreter.interpret(ctx.program) + val fs = eval.interpretTrace(ctx) + Logger.info("Interpreter Trace:\n" + fs._2.t.mkString("\n")) + val stopState = fs._1.nextCmd + if (stopState != eval.Stopped()) { + Logger.error(s"Interpreter exited with $stopState") + } else { + Logger.info("Interpreter stopped normally.") + } } IRTransform.prepareForTranslation(q, ctx) diff --git a/src/main/scala/util/functional/State.scala b/src/main/scala/util/functional/State.scala new file mode 100644 index 000000000..fed5f6329 --- /dev/null +++ b/src/main/scala/util/functional/State.scala @@ -0,0 +1,105 @@ +package util.functional + +/* + * Flattened state monad with error. + */ +case class State[S, A, E](f: S => (S, Either[E, A])) { + + def unit[A](a: A): State[S, A, E] = State(s => (s, Right(a))) + + def >>(o: State[S,A,E]) = for { + _ <- this + x <- o + } yield x + + def flatMap[B](f: A => State[S, B, E]): State[S, B, E] = State(s => { + val (s2, a) = this.f(s) + val r = a match { + case Left(l) => (s2, Left(l)) + case Right(a) => f(a).f(s2) + } + r + }) + + def map[B](f: A => B): State[S, B, E] = { + State(s => { + val (s2, a) = this.f(s) + a match { + case Left(l) => (s2, Left(l)) + case Right(a) => (s2, Right(f(a))) + } + }) + } + + def flatMapE(f: E => State[S, A, E]): State[S, A, E] = { + State(s => { + val (s2, a) = this.f(s) + a match { + case Left(l) => f(l).f(s2) + case Right(_) => (s2, a) + } + }) + } +} + + +object State { + def get[S, A, E](f: S => A) : State[S, A, E] = State(s => (s, Right(f(s)))) + def getE[S, A, E](f: S => Either[E,A]) : State[S, A, E] = State(s => (s, f(s))) + def getS[S,E]: State[S,S,E] = State((s:S) => (s,Right(s))) + def putS[S,E](s: S): State[S,Unit,E] = State(_ => (s,Right(()))) + def modify[S, E](f: S => S): State[S, Unit, E] = State(s => (f(s), Right(()))) + def modifyE[S, E](f: S => Either[E, S]): State[S, Unit, E] = State(s => f(s) match { + case Right(ns) => (ns, Right(())) + case Left(e) => (s, Left(e)) + }) + def execute[S, A, E](s: S, c: State[S,A, E]): S = c.f(s)._1 + def evaluate[S, A, E](s: S, c: State[S,A, E]): Either[E,A] = c.f(s)._2 + + def setError[S,A,E](e: E): State[S,A,E] = State(s => (s, Left(e))) + + def pure[S, A, E](a: A): State[S, A, E] = State((s:S) => (s, Right(a))) + def pureE[S, A, E](a: Either[E, A]): State[S, A, E] = State((s:S) => (s, a)) + + def sequence[S, V, E](ident: State[S,V, E], xs: Iterable[State[S,V, E]]): State[S, V, E] = { + xs.foldRight(ident) { + (l, r) => for { + x <- l + y <- r + } yield y + } + } + + def filterM[A, S, E](m: (A => State[S, Boolean, E]), xs: Iterable[A]): State[S, List[A], E] = { + xs.foldRight(pure(List[A]()))((b,acc) => acc.flatMap(c => m(b).map(v => if v then b::c else c))) + } + + def mapM[A, B, S, E](m: (A => State[S, B, E]), xs: Iterable[A]): State[S, List[B], E] = { + xs.foldRight(pure(List[B]()))((b,acc) => acc.flatMap(c => m(b).map(v => v::c))) + } + + def protect[S, V, E](f: () => State[S, V, E], fnly: PartialFunction[Exception, E]): State[S, V, E] = { + State((s: S) => try { + f().f(s) + } catch { + case e: Exception if fnly.isDefinedAt(e) => (s, Left(fnly(e))) + }) + } + + def protectPure[S,V,E](f: () => V, fnly: PartialFunction[Exception, E]): State[S, V, E] = { + State((s: S) => try { + (s, Right(f())) + } catch { + case e: Exception if fnly.isDefinedAt(e) => (s, Left(fnly(e))) + }) + } + +} + +def protect[T](x: () => T, fnly: PartialFunction[Exception, T]): T = { + try { + x() + } catch { + case e: Exception if fnly.isDefinedAt(e) => fnly(e) + } +} diff --git a/src/test/scala/BitVectorAnalysisTests.scala b/src/test/scala/BitVectorAnalysisTests.scala index 485ba6a83..08710a940 100644 --- a/src/test/scala/BitVectorAnalysisTests.scala +++ b/src/test/scala/BitVectorAnalysisTests.scala @@ -1,4 +1,4 @@ -import analysis.BitVectorEval.* +import ir.eval.BitVectorEval._ import ir.* import org.scalatest.funsuite.AnyFunSuite import util.Logger @@ -181,20 +181,20 @@ class BitVectorAnalysisTests extends AnyFunSuite { // smt_bveq test("BitVector Equal - should return true if two BitVectors are equal") { val result = smt_bveq(BitVecLiteral(255, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test("BitVector Equal - should return false if two BitVectors are not equal") { val result = smt_bveq(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvneq test("BitVector Not Equal - should return false if two BitVectors are equal") { val result = smt_bvneq(BitVecLiteral(255, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } test("BitVector Not Equal - should return true if two BitVectors are not equal") { val result = smt_bvneq(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } // smt_bvshl test("BitVector Shift Left - should shift bits left") { @@ -239,14 +239,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { test("BitVector unsigned less then - should return true if first argument is less than second argument") { val result = smt_bvult(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsigned less then - should return false if first argument is greater than or equal to second argument" ) { val result = smt_bvult(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvule @@ -255,14 +255,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector unsigned less then or equal to - should return true if first argument is less equal to second argument" ) { val result = smt_bvule(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsigned less then or equal to - should return false if first argument is greater than second argument" ) { val result = smt_bvule(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvugt @@ -270,14 +270,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector unsinged greater than - should return true if first argument is greater equal to than second argument" ) { val result = smt_bvugt(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsinged greater than - should return false if first argument is less than or equal to second argument" ) { val result = smt_bvugt(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvuge @@ -285,27 +285,27 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector unsinged greater than or equal to - should return true if first argument is greater equal or equal to second argument" ) { val result = smt_bvuge(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector unsinged greater than or equal to - should return false if first argument is less than second argument" ) { val result = smt_bvuge(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvslt test("BitVector signed less than - should return true if first argument is less than second argument") { val result = smt_bvslt(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector signed less than - should return false if first argument is greater than or equal to second argument" ) { val result = smt_bvslt(BitVecLiteral(254, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvsle @@ -313,25 +313,25 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector signed less than or equal to - should return true if first argument is less than or equal to second argument" ) { val result = smt_bvsle(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector signed less than or equal to - should return false if first argument is greater than second argument" ) { val result = smt_bvsle(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvsgt test("BitVector signed greater than - should return true if first argument is greater than second argument") { val result = smt_bvsgt(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test("BitVector signed greater than - should return false if first argument is less than second argument") { val result = smt_bvsgt(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvsge @@ -339,14 +339,14 @@ class BitVectorAnalysisTests extends AnyFunSuite { "BitVector signed greater than or equal to - should return true if first argument is greater than or equal to second argument" ) { val result = smt_bvsge(BitVecLiteral(255, 8), BitVecLiteral(254, 8)) - assert(result == TrueLiteral) + assert(result) } test( "BitVector signed greater than or equal to - should return false if first argument is less than second argument" ) { val result = smt_bvsge(BitVecLiteral(254, 8), BitVecLiteral(255, 8)) - assert(result == FalseLiteral) + assert(!result) } // smt_bvashr test("BitVector Arithmetic shift right - should return shift right a positive number") { diff --git a/src/test/scala/DifferentialAnalysis.scala b/src/test/scala/DifferentialAnalysis.scala new file mode 100644 index 000000000..6062d2db5 --- /dev/null +++ b/src/test/scala/DifferentialAnalysis.scala @@ -0,0 +1,116 @@ + +import ir.* +import java.io.{BufferedWriter, File, FileWriter} +import ir.Endian.LittleEndian +import org.scalatest.* +import org.scalatest.funsuite.* +import specification.* +import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel, IRTransform} +import ir.eval.* +import test_util.BASILTest.getSubdirectories + +import java.io.IOException +import java.nio.file.* +import java.nio.file.attribute.BasicFileAttributes +import ir.dsl.* +import util.RunUtils.loadAndTranslate + +import scala.collection.mutable + +class DifferentialAnalysis extends AnyFunSuite { + + Logger.setLevel(LogLevel.WARN) + + def diffTest(initial: IRContext, transformed: IRContext): Unit = { + + val instructionLimit = 1000000 + + def interp(p: IRContext) : (InterpreterState, Trace) = { + val interpreter = LayerInterpreter(tracingInterpreter(NormalInterpreter), EffectsRLimit(instructionLimit)) + val initialState = InterpFuns.initProgState(NormalInterpreter)(p, InterpreterState()) + //Logger.setLevel(LogLevel.DEBUG) + val (r, _) = BASILInterpreter(interpreter).run((initialState, Trace(List())), 0) + //Logger.setLevel(LogLevel.WARN) + r + } + + val (initialRes,traceInit) = interp(initial) + val (result,traceRes) = interp(transformed) + + def filterEvents(trace: List[ExecEffect]) = { + trace.collect { + case e @ ExecEffect.Call(_, _, _) => e + case e @ ExecEffect.StoreMem("mem", _) => e + case e @ ExecEffect.LoadMem("mem", _) => e + } + } + + Logger.info(traceInit.t.map(_.toString.take(80)).mkString("\n")) + val initstdout = initialRes.memoryState.getMem("stdout") + val comparstdout = result.memoryState.getMem("stdout") + val text = initstdout.toList.sortBy(_._1.value).map(_._2.value.toChar).mkString("") + info("STDOUT: \"" + text + "\"") + assert(initstdout == comparstdout) + assert(initialRes.nextCmd == Stopped()) + assert(result.nextCmd == Stopped()) + assert(traceInit.t.nonEmpty) + assert(traceRes.t.nonEmpty) + assert(filterEvents(traceInit.t).mkString("\n") == filterEvents(traceRes.t).mkString("\n")) + } + + def testProgram(name: String, variation: String, path: String): Unit = { + val variationPath = path + name + "/" + variation + "/" + name + val loading = ILLoadingConfig( + inputFile = variationPath + ".adt", + relfFile = variationPath + ".relf", + dumpIL = None, + ) + + var ictx = IRLoading.load(loading) + ictx = IRTransform.doCleanup(ictx) + + var comparectx = IRLoading.load(loading) + comparectx = IRTransform.doCleanup(ictx) + val analysisres = RunUtils.staticAnalysis(StaticAnalysisConfig(None, None, None), comparectx) + + diffTest(ictx, comparectx) + } + + test("indirect_calls/indirect_call/gcc_pic:BAP") { + testProgram("indirect_call", "gcc_pic", "./src/test/indirect_calls/") + } + + test("indirect_calls/jumptable2/gcc_pic:BAP") { + testProgram("jumptable2", "gcc_pic", "./src/test/indirect_calls/") + } + + test("indirect_calls/jumptable/gcc:BAP") { + testProgram("jumptable", "gcc", "./src/test/indirect_calls/") + } + + test("functionpointer/gcc_pic:BAP") { + testProgram("functionpointer", "gcc_pic", "./src/test/indirect_calls/") + } + + def runTests(): Unit = { + val path = System.getProperty("user.dir") + s"/src/test/correct/" + val programs = getSubdirectories(path) + + // get all variations of each program + for (p <- programs) { + val programPath = path + "/" + p + val variations = getSubdirectories(programPath) + variations.foreach { t => + val variationPath = programPath + "/" + t + "/" + p + val inputPath = variationPath + ".adt" + if (File(inputPath).exists) { + test("correct" + "/" + p + "/" + t + ":BAP") { + testProgram(p, t, path) + } + } + } + } + } + + runTests() +} diff --git a/src/test/scala/InterpretTestConstProp.scala b/src/test/scala/InterpretTestConstProp.scala new file mode 100644 index 000000000..49f4fcd48 --- /dev/null +++ b/src/test/scala/InterpretTestConstProp.scala @@ -0,0 +1,119 @@ +import ir.* +import ir.eval.* +import analysis.* +import java.io.{BufferedWriter, File, FileWriter} +import ir.Endian.LittleEndian +import org.scalatest.* +import org.scalatest.funsuite.* +import specification.* +import util.{BASILConfig, IRLoading, ILLoadingConfig, IRContext, RunUtils, StaticAnalysis, StaticAnalysisConfig, StaticAnalysisContext, BASILResult, Logger, LogLevel, IRTransform} +import ir.eval.{interpretTrace, interpret, ExecEffect, Stopped} +import ir.dsl + + +import java.io.IOException +import java.nio.file.* +import java.nio.file.attribute.BasicFileAttributes +import ir.dsl.* +import util.RunUtils.loadAndTranslate + +import scala.collection.mutable + +class ConstPropInterpreterValidate extends AnyFunSuite { + + Logger.setLevel(LogLevel.ERROR) + + def testInterpretConstProp(name: String, variation: String, path: String): Unit = { + val variationPath = path + name + "/" + variation + "/" + name + val loading = ILLoadingConfig( + inputFile = variationPath + ".adt", + relfFile = variationPath + ".relf", + dumpIL = None, + ) + + var ictx = IRLoading.load(loading) + ictx = IRTransform.doCleanup(ictx) + val analysisres = RunUtils.staticAnalysis(StaticAnalysisConfig(None, None, None), ictx) + + val breaks: List[BreakPoint] = analysisres.constPropResult.collect { + // convert analysis result to a list of breakpoints, each which evaluates an expression describing + // the invariant inferred by the analysis (the assignment of registers) at a corresponding program point + + case (command: Command, v) => + val expectedPredicates: List[(String, Expr)] = v.toList.map { r => + val (variable, value) = r + val assertion = value match { + case Top => TrueLiteral + case Bottom => FalseLiteral /* unreachable */ + case FlatEl(value) => BinaryExpr(BVEQ, variable, value) + } + (variable.name, assertion) + } + BreakPoint( + location = BreakPointLoc.CMD(command), + BreakPointAction(saveState = false, evalExprs = expectedPredicates) + ) + }.toList + + assert(breaks.nonEmpty) + + // run the interpreter evaluating the analysis result at each command with a breakpoint + val interpretResult = interpretWithBreakPoints(ictx, breaks.toList, NormalInterpreter, InterpreterState()) + val breakres: List[(BreakPoint, _, List[(String, Expr, Expr)])] = interpretResult._2 + assert(interpretResult._1.nextCmd == Stopped()) + assert(breakres.nonEmpty) + + // assert all the collected breakpoint watches have evaluated to true + for (b <- breakres) { + val (_, _, evaluatedexprs) = b + evaluatedexprs.forall { c => + val (n, before, evaled) = c + evaled == TrueLiteral + } + } + } + + test("indirect_call/gcc_pic:BAP") { + testInterpretConstProp("indirect_call", "gcc_pic", "./src/test/indirect_calls/") + } + + test("indirect_call/gcc:BAP") { + testInterpretConstProp("indirect_call", "gcc", "./src/test/indirect_calls/") + } + + test("indirect_call/clang:BAP") { + testInterpretConstProp("indirect_call", "clang", "./src/test/indirect_calls/") + } + + test("jumptable2/gcc_pic:BAP") { + testInterpretConstProp("jumptable2", "gcc_pic", "./src/test/indirect_calls/") + } + + test("jumptable2/gcc:BAP") { + testInterpretConstProp("jumptable2", "gcc", "./src/test/indirect_calls/") + } + + test("jumptable2/clang:BAP") { + testInterpretConstProp("jumptable2", "clang", "./src/test/indirect_calls/") + } + + test("functionpointer/gcc_pic:BAP") { + testInterpretConstProp("functionpointer", "gcc_pic", "./src/test/indirect_calls/") + } + + test("functionpointer/gcc:BAP") { + testInterpretConstProp("functionpointer", "gcc", "./src/test/indirect_calls/") + } + + test("functionpointer/clang:BAP") { + testInterpretConstProp("functionpointer", "clang", "./src/test/indirect_calls/") + } + + test("secret_write/clang:BAP") { + testInterpretConstProp("secret_write", "clang", "./src/test/correct/") + } + + test("secret_write/gcc:BAP") { + testInterpretConstProp("secret_write", "gcc", "./src/test/correct/") + } +} diff --git a/src/test/scala/ir/IRTest.scala b/src/test/scala/ir/IRTest.scala index a1592f9eb..3c9bc74d5 100644 --- a/src/test/scala/ir/IRTest.scala +++ b/src/test/scala/ir/IRTest.scala @@ -131,13 +131,12 @@ class IRTest extends AnyFunSuite { val aftercallGotos = p.collect { case c: Command if isAfterCall(c) => c }.toSet - // assert(aftercallGotos == Set(blocks("l_main_1").fallthrough.get)) assert(1 == aftercallGotos.count(b => IntraProcIRCursor.pred(b).contains(blocks("l_main_1").jump))) assert(1 == aftercallGotos.count(b => IntraProcIRCursor.succ(b).contains(blocks("l_main_1").jump match { case GoTo(targets, _) => targets.head + case _ => throw Exception("unreachable") }))) - } test("addblocks") { @@ -304,6 +303,10 @@ class IRTest extends AnyFunSuite { transforms.addReturnBlocks(p) cilvisitor.visit_prog(transforms.ConvertSingleReturn(), p) + cilvisitor.visit_prog(transforms.ReplaceReturns(), p) + transforms.addReturnBlocks(p) + cilvisitor.visit_prog(transforms.ConvertSingleReturn(), p) + val blocks = p.labelToBlock val procs = p.nameToProcedure diff --git a/src/test/scala/ir/InterpreterTests.scala b/src/test/scala/ir/InterpreterTests.scala index 92f728ee8..6d206bfaf 100644 --- a/src/test/scala/ir/InterpreterTests.scala +++ b/src/test/scala/ir/InterpreterTests.scala @@ -1,45 +1,71 @@ package ir -import analysis.BitVectorEval.* +import util.PerformanceTimer +import util.functional.* +import ir.eval.* +import boogie.Scope +import ir.dsl.* import org.scalatest.funsuite.AnyFunSuite import org.scalatest.BeforeAndAfter import specification.SpecGlobal import translating.BAPToIR import util.{LogLevel, Logger} import util.IRLoading.{loadBAP, loadReadELF} -import util.ILLoadingConfig +import util.{ILLoadingConfig, IRContext, IRLoading, IRTransform} -class InterpreterTests extends AnyFunSuite with BeforeAndAfter { +def load(s: InterpreterState, global: SpecGlobal): Option[BitVecLiteral] = { + val f = NormalInterpreter + + State.evaluate( + s, + Eval.evalBV(f)( + MemoryLoad(SharedMemory("mem", 64, 8), BitVecLiteral(global.address, 64), Endian.LittleEndian, global.size) + ) + ) match { + case Right(e) => Some(e) + case Left(e) => { + None + } + } +} - var i: Interpreter = Interpreter() - Logger.setLevel(LogLevel.DEBUG) +def mems[E, T <: Effects[T, E]](m: MemoryState): Map[BigInt, BitVecLiteral] = { + m.getMem("mem").map((k, v) => k.value -> v) +} - def getProgram(name: String): (Program, Set[SpecGlobal]) = { +class InterpreterTests extends AnyFunSuite with BeforeAndAfter { + Logger.setLevel(LogLevel.WARN) + + def getProgram(name: String, folder: String): IRContext = { + val compiler = "gcc" val loading = ILLoadingConfig( - inputFile = s"src/test/correct/$name/gcc/$name.adt", - relfFile = s"src/test/correct/$name/gcc/$name.relf", + inputFile = s"src/test/$folder/$name/$compiler/$name.adt", + relfFile = s"src/test/$folder/$name/$compiler/$name.relf", specFile = None, dumpIL = None ) - val bapProgram = loadBAP(loading.inputFile) - val (_, externalFunctions, globals, _, _, mainAddress) = loadReadELF(loading.relfFile, loading) - val IRTranslator = BAPToIR(bapProgram, mainAddress) - var IRProgram = IRTranslator.translate - IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) - IRProgram = Renamer(Set("free")).visitProgram(IRProgram) - transforms.stripUnreachableFunctions(IRProgram) - val stackIdentification = StackSubstituter() - stackIdentification.visitProgram(IRProgram) - IRProgram.setModifies(Map()) - - (IRProgram, globals) + val p = IRLoading.load(loading) + val ctx = IRTransform.doCleanup(p) + // val bapProgram = loadBAP(loading.inputFile) + // val (symbols, externalFunctions, globals, _, mainAddress) = loadReadELF(loading.relfFile, loading) + // val IRTranslator = BAPToIR(bapProgram, mainAddress) + // var IRProgram = IRTranslator.translate + // IRProgram = ExternalRemover(externalFunctions.map(e => e.name)).visitProgram(IRProgram) + // IRProgram = Renamer(Set("free")).visitProgram(IRProgram) + //IRProgram.stripUnreachableFunctions() + // val stackIdentification = StackSubstituter() + // stackIdentification.visitProgram(IRProgram) + ctx.program.setModifies(Map()) + ctx } - def testInterpret(name: String, expected: Map[String, Int]): Unit = { - val (program, globals) = getProgram(name) - val regs = i.interpret(program) + def testInterpret(name: String, folder: String, expected: Map[String, Int]): Unit = { + val ctx = getProgram(name, folder) + val fstate = interpret(ctx) + val regs = fstate.memoryState.getGlobalVals + val globals = ctx.globals // Show interpreted result Logger.info("Registers:") @@ -48,112 +74,92 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { } Logger.info("Globals:") + // def loadBV(vname: String, addr: BasilValue, valueSize: Int, endian: Endian, size: Int): List[BitVecLiteral] = { globals.foreach { global => - val mem = i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems) - Logger.info(s"$global := $mem") + val mem = load(fstate, global) + mem.foreach(mem => Logger.info(s"$global := $mem")) } // Test expected value - expected.foreach { (name, expected) => - globals.find(_.name == name) match { - case Some(global) => - val actual = i.getMemory(global.address.toInt, global.size, Endian.LittleEndian, i.mems).value.toInt - assert(actual == expected) - case None => assert("None" == name) - } - } + val actual: Map[String, Int] = expected.flatMap((name, expected) => + globals.find(_.name == name).flatMap(global => load(fstate, global).map(gv => name -> gv.value.toInt)) + ) + assert(fstate.nextCmd == Stopped()) + assert(expected == actual) } - before { - i = Interpreter() - } + test("initialise") { - test("getMemory in LittleEndian") { - i.mems(0) = BitVecLiteral(BigInt("0D", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("0C", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("0B", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("0A", 16), 8) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) - assert(actual == expected) - } + val init = InterpFuns.initialState(NormalInterpreter) + + val s = State.execute(InterpreterState(), init) + assert(s.memoryState.getVarOpt("mem").isDefined) + assert(s.memoryState.getVarOpt("stack").isDefined) + assert(s.memoryState.getVarOpt("R31").isDefined) + assert(s.memoryState.getVarOpt("R29").isDefined) - test("getMemory in BigEndian") { - i.mems(0) = BitVecLiteral(BigInt("0A", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("0B", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("0C", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("0D", 16), 8) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) - assert(actual == expected) } - test("setMemory in LittleEndian") { - i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - i.setMemory(0, 32, Endian.LittleEndian, expected, i.mems) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.LittleEndian, i.mems) - assert(actual == expected) + test("var load store") { + val s = for { + s <- InterpFuns.initialState(NormalInterpreter) + v <- NormalInterpreter.storeVar("R1", Scope.Global, Scalar(BitVecLiteral(1024, 64))) + v <- NormalInterpreter.loadVar("R1") + } yield (v) + val l = State.evaluate(InterpreterState(), s) + + assert(l == Right(Scalar(BitVecLiteral(1024, 64)))) } - test("setMemory in BigEndian") { - i.mems(0) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(1) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(2) = BitVecLiteral(BigInt("FF", 16), 8) - i.mems(3) = BitVecLiteral(BigInt("FF", 16), 8) - val expected: BitVecLiteral = BitVecLiteral(BigInt("0A0B0C0D", 16), 32) - i.setMemory(0, 32, Endian.BigEndian, expected, i.mems) - val actual: BitVecLiteral = i.getMemory(0, 32, Endian.BigEndian, i.mems) - assert(actual == expected) + test("Store = Load LittleEndian") { + val ts = List( + BitVecLiteral(BigInt("0D", 16), 8), + BitVecLiteral(BigInt("0C", 16), 8), + BitVecLiteral(BigInt("0B", 16), 8), + BitVecLiteral(BigInt("0A", 16), 8) + ) + + val loader = StVarLoader(NormalInterpreter) + + val s = for { + _ <- InterpFuns.initialState(NormalInterpreter) + _ <- Eval.store(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), ts.map(Scalar(_)), Endian.LittleEndian) + r <- Eval.loadBV(NormalInterpreter)("mem", Scalar(BitVecLiteral(0, 64)), Endian.LittleEndian, 32) + } yield (r) + val expected: BitVecLiteral = BitVecLiteral(BigInt("0D0C0B0A", 16), 32) + val actual = State.evaluate(InterpreterState(), s) + assert(actual == Right(expected)) + } - /* test("basic_arrays_read") { val expected = Map( "arr" -> 0 ) - testInterpret("basic_arrays_read", expected) + testInterpret("basic_arrays_read", "correct", expected) } test("basic_assign_assign") { val expected = Map( "x" -> 5 ) - testInterpret("basic_assign_assign", expected) + testInterpret("basic_assign_assign", "correct", expected) } test("basic_assign_increment") { val expected = Map( "x" -> 1 ) - testInterpret("basic_assign_increment", expected) + testInterpret("basic_assign_increment", "correct", expected) } - test("basic_loop_loop") { - val expected = Map( - "x" -> 10 - ) - testInterpret("basic_loop_loop", expected) - } - - test("basicassign") { - val expected = Map( - "x" -> 0, - "z" -> 0, - "secret" -> 0 - ) - testInterpret("basicassign", expected) - } test("function") { val expected = Map( "x" -> 1, "y" -> 2 ) - testInterpret("function", expected) + testInterpret("function", "correct", expected) } test("function1") { @@ -161,7 +167,7 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 1, "y" -> 1410065515 // 10000000107 % 2147483648 = 1410065515 ) - testInterpret("function1", expected) + testInterpret("function1", "correct", expected) } test("secret_write") { @@ -170,14 +176,19 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 0, "secret" -> 0 ) - testInterpret("secret_write", expected) + testInterpret("secret_write", "correct", expected) + } + + test("indirect_call") { + val expected = Map[String, Int]() + testInterpret("indirect_call", "indirect_calls", expected) } test("ifglobal") { val expected = Map( "x" -> 1 ) - testInterpret("ifglobal", expected) + testInterpret("ifglobal", "correct", expected) } test("cjump") { @@ -185,21 +196,154 @@ class InterpreterTests extends AnyFunSuite with BeforeAndAfter { "x" -> 1, "y" -> 3 ) - testInterpret("cjump", expected) + testInterpret("cjump", "correct", expected) + } + + test("initialisation") { + + // Logger.setLevel(LogLevel.WARN) + val expected = Map( + "x" -> 6, + "y" -> ('b'.toInt) + ) + + testInterpret("initialisation", "correct", expected) } test("no_interference_update_x") { val expected = Map( "x" -> 1 ) - testInterpret("no_interference_update_x", expected) + testInterpret("no_interference_update_x", "correct", expected) } test("no_interference_update_y") { val expected = Map( "y" -> 1 ) - testInterpret("no_interference_update_y", expected) + testInterpret("no_interference_update_y", "correct", expected) } - */ + + def fib(n: Int): Int = { + n match { + case 0 => 0 + case 1 => 1 + case n => fib(n - 1) + fib(n - 2) + } + } + + def fibonacciProg(n: Int) = { + prog( + proc( + "begin", + block("entry", Assign(R8, Register("R31", 64)), Assign(R0, bv64(n)), directCall("fib"), goto("done")), + block("done", Assert(BinaryExpr(BVEQ, R0, bv64(fib(n)))), ret) + ), + proc( + "fib", + block("base", goto("base1", "base2", "dofib")), + block("base1", Assume(BinaryExpr(BVEQ, R0, bv64(0))), ret), + block("base2", Assume(BinaryExpr(BVEQ, R0, bv64(1))), ret), + block( + "dofib", + Assume(BinaryExpr(BoolAND, BinaryExpr(BVNEQ, R0, bv64(0)), BinaryExpr(BVNEQ, R0, bv64(1)))), + // R8 stack pointer preserved across calls + Assign(R7, BinaryExpr(BVADD, R8, bv64(8))), + MemoryAssign(stack, R7, R8, Endian.LittleEndian, 64), // sp + Assign(R8, R7), + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 8 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 8] = arg0 + Assign(R0, BinaryExpr(BVSUB, R0, bv64(1))), + directCall("fib"), + Assign(R2, R8), // sp + 8 + Assign(R8, BinaryExpr(BVADD, R8, bv64(8))), // sp + 16 + MemoryAssign(stack, R8, R0, Endian.LittleEndian, 64), // [sp + 16] = r1 + Assign(R0, MemoryLoad(stack, R2, Endian.LittleEndian, 64)), // [sp + 8] + Assign(R0, BinaryExpr(BVSUB, R0, bv64(2))), + directCall("fib"), + Assign(R2, MemoryLoad(stack, R8, Endian.LittleEndian, 64)), // [sp + 16] (r1) + Assign(R0, BinaryExpr(BVADD, R0, R2)), + Assign(R8, MemoryLoad(stack, BinaryExpr(BVSUB, R8, bv64(16)), Endian.LittleEndian, 64)), + ret + ) + ) + ) + } + + test("fibonacci") { + + Logger.setLevel(LogLevel.ERROR) + val fib = fibonacciProg(8) + val r = interpret(fib) + assert(r.nextCmd == Stopped()) + // Show interpreted result + // r.regs.foreach { (key, value) => + // Logger.info(s"$key := $value") + // } + + } + + test("fibonaccistress") { + + Logger.setLevel(LogLevel.ERROR) + var res = List[(Int, Double, Double, Int)]() + + for (i <- 0 to 20) { + val prog = fibonacciProg(i) + + val t = PerformanceTimer("native") + val r = fib(i) + val native = t.elapsed() + + val intt = PerformanceTimer("interp") + val ir = interpretRLimit(prog, 100000000) + val it = intt.elapsed() + + res = (i, native.toDouble, it.toDouble, ir._2) :: res + + } + + info(("fibonacci runtime table:\nFibNumber,ScalaRunTime,interpreterRunTime,instructionCycleCount" :: (res.map(x => s"${x._1},${x._2},${x._3},${x._4}"))).mkString("\n")) + + } + + test("fibonacci Trace") { + + val fib = fibonacciProg(8) + + val r = interpretTrace(fib) + + assert(r._1.nextCmd == Stopped()) + // Show interpreted result + // + + } + + test("fib breakpoints") { + + Logger.setLevel(LogLevel.INFO) + val fib = fibonacciProg(8) + val watch = IRWalk.firstInProc((fib.procedures.find(_.name == "fib")).get).get + val bp = BreakPoint( + "Fibentry", + BreakPointLoc.CMDCond(watch, BinaryExpr(BVEQ, BitVecLiteral(5, 64), Register("R0", 64))), + BreakPointAction(true, true, List(("R0", Register("R0", 64))), true) + ) + val bp2 = BreakPoint("Fibentry", BreakPointLoc.CMD(watch), BreakPointAction(true, true , List(("R0", Register("R0", 64))), true)) + val res = interpretWithBreakPoints(fib, List(bp), NormalInterpreter, InterpreterState()) + assert(res._1.nextCmd.isInstanceOf[ErrorStop]) + assert(res._2.nonEmpty) + } + + test("Capture IllegalArg") { + + val tp = prog( + proc("begin", block("shouldfail", Assign(R0, ZeroExtend(-1, BitVecLiteral(0, 64))), ret)) + ) + + val ir = interpret(tp) + assert(ir.nextCmd.isInstanceOf[ErrorStop]) + + } + } diff --git a/src/test/scala/util/StateMonad.scala b/src/test/scala/util/StateMonad.scala new file mode 100644 index 000000000..efcf4e767 --- /dev/null +++ b/src/test/scala/util/StateMonad.scala @@ -0,0 +1,31 @@ +import ir._ +import util.functional._ + + +import ir.eval._ +import ir.dsl._ +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.BeforeAndAfter +import specification.SpecGlobal +import translating.BAPToIR +import util.{LogLevel, Logger} +import util.IRLoading.{loadBAP, loadReadELF} +import util.ILLoadingConfig + + +def add: State[Int, Unit, Unit] = State(s => (s+1, Right(()))) + +class StateMonadTest extends AnyFunSuite { + + test("forcompre") { + val s = for { + _ <- add + _ <- add + _ <- add + } yield () + + + val res = State.execute(0, s) + assert(res == 3) + } +}