diff --git a/examples/notdependent b/examples/notdependent index 091af25..93335f2 100644 --- a/examples/notdependent +++ b/examples/notdependent @@ -9,7 +9,7 @@ _Mode: NoW _var r_secret: _L: FALSE -_P_0: z % 2 == 0 +_P_0: z == 0 _Gamma_0: x -> LOW, r_secret -> HIGH -x = r_secret * 0; // shouldn't fail as not dependent on r_secret +x = r_secret * z; // shouldn't fail as not dependent on r_secret as z == 0 diff --git a/examples/notdependent2 b/examples/notdependent2 new file mode 100644 index 0000000..4290e9a --- /dev/null +++ b/examples/notdependent2 @@ -0,0 +1,15 @@ +_var z: +_L: TRUE +_Mode: NoW + +_var x: +_L: z % 2 == 0 +_Mode: NoW + +_array r_secret[2]: +_L: FALSE + +_P_0: z % 2 == 0 +_Gamma_0: x -> LOW, r_secret -> HIGH + +x = r_secret[1] * r_secret[0] * 0; // shouldn't fail as not dependent on r_secret diff --git a/examples/notdependent3 b/examples/notdependent3 new file mode 100644 index 0000000..d45a695 --- /dev/null +++ b/examples/notdependent3 @@ -0,0 +1,15 @@ +_var z: +_L: TRUE +_Mode: NoW + +_var x: +_L: z % 2 == 0 +_Mode: NoW + +_array r_secret[2]: +_L: FALSE + +_P_0: z == 0 +_Gamma_0: x -> LOW, r_secret -> HIGH + +x = r_secret[z] - r_secret[0]; // shouldn't fail as not dependent on r_secret diff --git a/examples/notdependent4 b/examples/notdependent4 new file mode 100644 index 0000000..d45bede --- /dev/null +++ b/examples/notdependent4 @@ -0,0 +1,19 @@ +_var z: +_L: TRUE +_Mode: NoW + +_var x: +_L: z % 2 == 0 +_Mode: NoW + +_var r_secret: +_L: FALSE + +_array a[2]: +_L: TRUE +_Mode: NoW + +_P_0: z == 0, a[0] == 0 +_Gamma_0: x -> LOW, r_secret -> HIGH, a[*] -> LOW + +x = r_secret * a[0]; // shouldn't fail as not dependent on r_secret as a[0] == 0 diff --git a/todo b/todo index 0b93c16..f98602a 100644 --- a/todo +++ b/todo @@ -1,9 +1,7 @@ --comment everything --z3 simplifier? --atomic block - specify which variable stable, update D? --write up array rules - +-pruning P by removing predicates mostly dependent on otherwise unused variables? -organise expression class - boolean expression subclass & clean up operations? -more powerful array expressions? -make handling dependent variable stuff work with arrays? +options for different features + +optimisations? diff --git a/tool/src/tool/Exec.scala b/tool/src/tool/Exec.scala index 6aad487..d9b86f9 100644 --- a/tool/src/tool/Exec.scala +++ b/tool/src/tool/Exec.scala @@ -846,8 +846,11 @@ object Exec { println("knownW: " + knownw) } + + //val fallingFail = for (y <- falling -- st3.noReadWrite if !knownw.contains(y) || st3.security(y, PRestrict) == High) + // falling can only succeed if y is in gamma and maps to low - val fallingFail = for (y <- falling -- st3.noReadWrite if !knownw.contains(y) || st3.security(y, PRestrict) == High) + val fallingFail = for (y <- falling -- st3.noReadWrite if !st3.gamma.contains(y) || st3.gamma(y) != Low) yield y if (fallingFail.nonEmpty) { @@ -956,9 +959,10 @@ object Exec { println("knownW: " + knownw) } - // falling can only succeed if y is in gamma and maps to low + //val fallingFail = for (y <- falling -- st2.noReadWrite if !knownw.contains(y) || st2.security(y, PRestrict) == High) - val fallingFail = for (y <- falling -- st2.noReadWrite if !knownw.contains(y) || st2.security(y, PRestrict) == High) + // falling can only succeed if y is in gamma and maps to low + val fallingFail = for (y <- falling -- st2.noReadWrite if !st2.gamma.contains(y) || st2.gamma(y) != Low) yield y if (fallingFail.nonEmpty) { @@ -986,11 +990,11 @@ object Exec { // CAS rule if (st0.toLog) println("CAS applying") - // computes rd + // compute rd val (_r1, st1) = eval(r1, st0) val (_r2, st2) = eval(r2, st1) val st3 = st2.updateRead(x) - // computes wr + // compute wr val st4 = st3.updateWritten(x) val st5 = st4.updateWritten(lhs) // at this point the rd and wr sets are complete for the current line @@ -1118,9 +1122,10 @@ object Exec { println("knownW: " + knownw) } - // falling can only succeed if y is in gamma and maps to low + //val fallingFail = for (y <- falling -- st5.noReadWrite if !knownw.contains(y) || st5.security(y, PRestrictAssign) == High) - val fallingFail = for (y <- falling -- st5.noReadWrite if !knownw.contains(y) || st5.security(y, PRestrictAssign) == High) + // falling can only succeed if y is in gamma and maps to low + val fallingFail = for (y <- falling -- st5.noReadWrite if !st5.gamma.contains(y) || st5.gamma(y) != Low) yield y if (fallingFail.nonEmpty) { diff --git a/tool/src/tool/Expression.scala b/tool/src/tool/Expression.scala index a0cfa09..cc157bd 100644 --- a/tool/src/tool/Expression.scala +++ b/tool/src/tool/Expression.scala @@ -74,7 +74,11 @@ case class Access(name: Id, index: Expression) extends Expression { // array access with Var for use in logical predicates case class VarAccess(name: Var, index: Expression) extends Expression { def variables: Set[Id] = index.variables - def subst(su: Subst) = VarAccess(name.subst(su), index.subst(su)) + def subst(su: Subst) = if (su.keySet.contains(this)) { + su.getOrElse(this, this) + } else { + VarAccess(name.subst(su), index.subst(su)) + } // don't substitute in the case that the index expression is an integer but not the one specified override def subst(su: Subst, num: Int) = index match { @@ -83,7 +87,7 @@ case class VarAccess(name: Var, index: Expression) extends Expression { case _ => VarAccess(name.subst(su), index.subst(su)) } - override def toString = name + "[" + index + "]" + //override def toString = name + "[" + index + "]" override def arrays = this.name match { case Var(_, Some(index)) => Set() @@ -157,6 +161,7 @@ object MultiSwitch { } } +/* case class not(arg: BoolExpression) extends BoolExpression { override def toString = "(! " + arg + ")" override def variables: Set[Id] = arg.variables @@ -181,7 +186,6 @@ case class or(arg1: BoolExpression, arg2: BoolExpression) extends BoolExpression override def arrays = arg1.arrays ++ arg2.arrays } -/* case class eq(arg1: Expression, arg2: Expression) extends BoolExpression { override def toString = "(" + arg1 + " && " + arg2 + ")" override def variables: Set[Id] = arg1.variables ++ arg2.variables diff --git a/tool/src/tool/SMT.scala b/tool/src/tool/SMT.scala index 87539f4..ddde5c5 100644 --- a/tool/src/tool/SMT.scala +++ b/tool/src/tool/SMT.scala @@ -22,7 +22,7 @@ object SMT { solver.check } catch { case e: java.lang.UnsatisfiedLinkError if e.getMessage.equals("com.microsoft.z3.Native.INTERNALgetErrorMsgEx(JI)Ljava/lang/String;")=> - // weird unintuitive error z3 can have when input type is incorrect in a way it doesn't check + // weird unintuitive error z3 can have when an input type is incorrect in a way it doesn't check throw error.Z3Error("Z3 failed", cond, given.PStr, "incorrect z3 expression type, probably involving ForAll/Exists") case e: Throwable => throw error.Z3Error("Z3 failed", cond, given.PStr, e) @@ -40,7 +40,6 @@ object SMT { res == z3.Status.UNSATISFIABLE } - def proveSat(cond: Expression, given: List[Expression], debug: Boolean) = { if (debug) println("smt checking " + cond + " given " + given.PStr) @@ -55,7 +54,7 @@ object SMT { solver.check } catch { case e: java.lang.UnsatisfiedLinkError if e.getMessage.equals("com.microsoft.z3.Native.INTERNALgetErrorMsgEx(JI)Ljava/lang/String;")=> - // weird unintuitive error z3 can have when input type is incorrect in a way it doesn't check + // weird unintuitive error z3 can have when an input type is incorrect in a way it doesn't check throw error.Z3Error("Z3 failed", cond, given.PStr, "incorrect z3 expression type, probably involving ForAll/Exists") case e: Throwable => throw error.Z3Error("Z3 failed", cond, given.PStr, e) @@ -83,7 +82,7 @@ object SMT { solver.check } catch { case e: java.lang.UnsatisfiedLinkError if e.getMessage.equals("com.microsoft.z3.Native.INTERNALgetErrorMsgEx(JI)Ljava/lang/String;")=> - // weird unintuitive error z3 can have when input type is incorrect in a way it doesn't check + // weird unintuitive error z3 can have when an input type is incorrect in a way it doesn't check throw error.Z3Error("Z3 failed", given.PStr, "incorrect z3 expression type, probably involving ForAll/Exists") case e: Throwable => throw error.Z3Error("Z3 failed", given.PStr, e) @@ -100,17 +99,6 @@ object SMT { res == z3.Status.SATISFIABLE } - // recursively convert expression list into AND structure - def PToAnd(exprs: List[Expression]): z3.BoolExpr = exprs match { - case Nil => - ctx.mkTrue - - case expr :: rest => - val xs = PToAnd(rest) - val x = ctx.mkAnd(formula(expr), xs) - x - } - def proveImplies(strong: List[Expression], weak: List[Expression], debug: Boolean) = { if (debug) println("smt checking !(" + strong.PStr + newline + " implies " + weak.PStr + ")") @@ -133,6 +121,45 @@ object SMT { res == z3.Status.UNSATISFIABLE } + def proveExpression(cond: Expression, debug: Boolean) = { + if (debug) + println("smt checking (" + cond + ")") + solver.push() + val res = try { + // check that (NOT cond) is unsatisfiable + solver.add(formula(cond)) + solver.check + } catch { + case e: java.lang.UnsatisfiedLinkError if e.getMessage.equals("com.microsoft.z3.Native.INTERNALgetErrorMsgEx(JI)Ljava/lang/String;")=> + // weird unintuitive error z3 can have when an input type is incorrect in a way it doesn't check + throw error.Z3Error("Z3 failed", cond, "incorrect z3 expression type, probably involving ForAll/Exists") + case e: Throwable => + throw error.Z3Error("Z3 failed", cond, e) + } finally { + solver.pop() + } + + if (debug) { + println(res) + if (res == z3.Status.SATISFIABLE) { + val model = solver.getModel + println(model) + } + } + res == z3.Status.SATISFIABLE + } + + // recursively convert expression list into AND structure + def PToAnd(exprs: List[Expression]): z3.BoolExpr = exprs match { + case Nil => + ctx.mkTrue + + case expr :: rest => + val xs = PToAnd(rest) + val x = ctx.mkAnd(formula(expr), xs) + x + } + def formula(prop: Expression): z3.BoolExpr = translate(prop) match { case b: z3.BoolExpr => b case e => @@ -211,7 +238,6 @@ object SMT { // array index case VarAccess(name, index) => ctx.mkSelect(ctx.mkArrayConst(name.toString, ctx.getIntSort, ctx.getIntSort), translate(index)) - case _ => throw error.InvalidProgram("cannot translate to SMT", prop) } diff --git a/tool/src/tool/State.scala b/tool/src/tool/State.scala index 585fb21..454d5e7 100644 --- a/tool/src/tool/State.scala +++ b/tool/src/tool/State.scala @@ -151,38 +151,6 @@ case class State( mergePs(possiblePs) } - - - /* old arrayassign that - // substitute variable in P for fresh variable - val PReplace = index match { - // if this array access is unambiguous, don't quantify array access for assignments to different unambiguous - indices of this array - case Lit(n) => - P map (p => p.subst(toSubst, n)) - case _ => - P map (p => p.subst(toSubst)) - } - - // substitute variable in expression for fresh variable - val argReplace = arg.subst(toSubst, i) - - val indexToSubst: Subst = { - for (i <- index.variables) - yield i -> i.toVar - }.toMap ++ toSubst - - val indexSubst = index.subst(indexToSubst, i) - - // add new assignment statement to P - val PPrime = BinOp("==", VarAccess(v, indexSubst), argReplace) :: PReplace - - // restrict PPrime to stable variables - val POut = State.restrictP(PPrime, stable) - */ - - - if (debug) { println("assigning " + arg + " to " + a + "[" + index + "]" + ":") println("P: " + P.PStr) @@ -522,7 +490,7 @@ case class State( // ((index == 0) && (L(A[0]))) || ((index == 1) && (L(A[1]))) || ... to array.size def arrayAccessCheck(array: IdArray, indices: Seq[Int], index: Expression): Expression = { - val list = {for (i <- indices) + val list: List[Expression] = {for (i <- indices) yield BinOp("&&", BinOp("==", index, Lit(i)), L(array.array(i)))}.toList orPredicates(list) } @@ -537,6 +505,16 @@ case class State( x } + def andPredicates(exprs: List[Expression]): Expression = exprs match { + case Nil => + Const._true + + case expr :: rest => + val xs = andPredicates(rest) + val x = BinOp("&&", expr, xs) + x + } + // x is variable to get security of, p is P value given so can substitute in P_a etc. def security(x: Id, p: List[Expression]): Security = { if (debug) @@ -544,10 +522,8 @@ case class State( var sec: Security = High if (gamma.contains(x)) { sec = gamma(x) - } else { - if (lowP(x, p)) { - sec = Low - } // true/LOW if L(x) holds given P, false/HIGH otherwise + } else if (lowP(x, p)) { + sec = Low // LOW if L(x) holds given P, HIGH otherwise } if (debug) println(x + " security is " + sec) @@ -560,10 +536,7 @@ case class State( println("checking security for " + e) var sec: Security = Low val varE = e.variables - - var secMap: Map[Id, Security] = Map() - for (x <- varE) { val xSec = security(x, p) if (xSec == High) { @@ -573,30 +546,56 @@ case class State( } val arraysE = e.arrays - val it2 = arraysE.toIterator - while (it2.hasNext && sec == Low) { - val a: Access = it2.next() - sec = security(a.name, a.index, p) + var arraySecMap: Map[Access, Security] = Map() + for (a <- arraysE) { + val aSec = security(a.name, a.index, p) + if (aSec == High) { + sec = High + } + arraySecMap += (a -> aSec) } - // dealing with arrays with this check not implemented yet - if (sec == High && arraysE.isEmpty) { + // checking for possibility expression contains a high variable/array access but its evaluation is not dependent + // on the high data, e.g. high - high, high * 0 + if (sec == High) { + val highAccess: Set[Access] = (arraySecMap collect {case a if a._2 == High => a._1}).toSet + val idToVar: Subst = { for (v <- variables) yield v -> v.toVar }.toMap ++ { for (v <- arrays.keySet) yield v -> v.toVar - }.toMap + } + + // replace high array accesses with new variables + + var arrayReplace: Map[Access, Var] = Map() + for (a <- highAccess) { + arrayReplace += (a -> Var.fresh(a.name.name)) + for (i <- arrayReplace.keySet if i != a) { + // check if array indices can be proved to be equivalent, if so replace with same variable + if (i.name == a.name && SMT.prove(BinOp("==", i.index, a.index), p, debug)) { + arrayReplace += (a -> arrayReplace(i)) + } + } + } + val toSubst: Subst = arrayReplace map {case (k, v) => k.subst(idToVar) -> v} + + val eSubst = e.subst(toSubst).subst(idToVar) - val eSubst = e.subst(idToVar) - val high: Set[Var] = (secMap collect {case x if x._2 == High => x._1.toVar}).toSet - val test = if (guard) { + // set of variables to bind - high variables and replacement variables for high array accesses + val high: Set[Var] = (secMap collect {case x if x._2 == High => x._1.toVar}).toSet ++ + (highAccess map {x => toSubst(x.subst(idToVar))}) + + // guards are boolean expressions so must create boolean variable for them + val v = if (guard) { Switch(0) // to make boolean variable } else { Var("_var") // not a valid variable name from parser so won't clash } - if (SMT.prove(Exists(Set(test), ForAll(high, BinOp("==", eSubst, test))), p, debug)) { + + if (SMT.prove(Exists(Set(v), ForAll(high, BinOp("==", eSubst, v))), p, debug)) { sec = Low } }