Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EC-384 : Aded Set support for PrismLang #36

Open
wants to merge 6 commits into
base: release
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions src/main/scala/org/encryfoundation/prismlang/codec/PCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,21 @@ object PCodec {
implicit def dString = dT.bind[PString.type](4)
implicit def dByte = dT.bind[PByte.type](5)
implicit def dColl = dT.bind[PCollection](6)
implicit def dFunc = dT.bind[PFunc](7)
implicit def dTuple = dT.bind[PTuple](8)
implicit def dObj = dT.bind[ArbitraryProduct](9)
implicit def dStructTag = dT.bind[StructTag](10)
implicit def dSig = dT.bind[Signature25519.type](11)
implicit def dMulSig = dT.bind[MultiSig.type](12)
implicit def dEncryBox = dT.bind[EncryBox.type](13)
implicit def dAssetBox = dT.bind[AssetBox.type](14)
implicit def dAiBox = dT.bind[AssetIssuingBox.type](15)
implicit def dDBox = dT.bind[DataBox.type](16)
implicit def dTransact = dT.bind[EncryTransaction.type](17)
implicit def dState = dT.bind[EncryState.type](18)
implicit def dNit = dT.bind[Nit.type](19)
implicit def dSet = dT.bind[PSet](7)
implicit def dFunc = dT.bind[PFunc](8)
implicit def dTuple = dT.bind[PTuple](9)
implicit def dObj = dT.bind[ArbitraryProduct](10)
implicit def dStructTag = dT.bind[StructTag](11)
implicit def dSig = dT.bind[Signature25519.type](12)
implicit def dMulSig = dT.bind[MultiSig.type](13)
implicit def dEncryBox = dT.bind[EncryBox.type](14)
implicit def dAssetBox = dT.bind[AssetBox.type](15)
implicit def dAiBox = dT.bind[AssetIssuingBox.type](16)
implicit def dDBox = dT.bind[DataBox.type](17)
implicit def dTransact = dT.bind[EncryTransaction.type](18)
implicit def dState = dT.bind[EncryState.type](19)
implicit def dNit = dT.bind[Nit.type](20)


implicit def dExp = Discriminated[Expr, Int](uint8)
implicit def dBlc = dExp.bind[Expr.Block](0)
Expand All @@ -52,16 +54,17 @@ object PCodec {
implicit def dByteConst = dExp.bind[Expr.ByteConst](16)
implicit def dStr = dExp.bind[Expr.Str](17)
implicit def dCollConst = dExp.bind[Expr.Collection](18)
implicit def dTupleConst = dExp.bind[Expr.Tuple](19)
implicit def dBase58Str = dExp.bind[Expr.Base58Str](20)
implicit def dBase16Str = dExp.bind[Expr.Base16Str](21)
implicit def dTrue = dExp.bind[Expr.True.type](22)
implicit def dFalse = dExp.bind[Expr.False.type](23)
implicit def dSizeOf = dExp.bind[Expr.SizeOf](24)
implicit def dExists = dExp.bind[Expr.Exists](25)
implicit def dSum = dExp.bind[Expr.Sum](26)
implicit def dMap = dExp.bind[Expr.Map](27)
implicit def dFilt = dExp.bind[Expr.Filter](28)
implicit def dSetConst = dExp.bind[Expr.PrismSet](19)
implicit def dTupleConst = dExp.bind[Expr.Tuple](20)
implicit def dBase58Str = dExp.bind[Expr.Base58Str](21)
implicit def dBase16Str = dExp.bind[Expr.Base16Str](22)
implicit def dTrue = dExp.bind[Expr.True.type](23)
implicit def dFalse = dExp.bind[Expr.False.type](24)
implicit def dSizeOf = dExp.bind[Expr.SizeOf](25)
implicit def dExists = dExp.bind[Expr.Exists](26)
implicit def dSum = dExp.bind[Expr.Sum](27)
implicit def dMap = dExp.bind[Expr.Map](28)
implicit def dFilt = dExp.bind[Expr.Filter](29)

implicit def dNode = Discriminated[Node, Int](uint2)
implicit def dModule = dNode.bind[Module](0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ case class CostEstimator(initialEnv: Map[String, Int]) {
}

def costOfConst: Cost = {
case Expr.PrismSet(elts, _) => CollEltC * elts.length + elts.map(costOf).sum
case Expr.Collection(elts, _) => CollEltC * elts.length + elts.map(costOf).sum
case Expr.Tuple(elts, _) => TupleEltC * elts.length + elts.map(costOf).sum
case Expr.Base58Str(value) => CharC * value.length + DecodingC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ case class StaticAnalyser(initialScope: ScopedSymbolTable, types: TypeSystem) ex
scanSimpleExpr orElse
scanRef orElse
scanCollection orElse
scanSet orElse
scanConstant orElse
scanTransformers orElse
pass
Expand Down Expand Up @@ -196,11 +197,26 @@ case class StaticAnalyser(initialScope: ScopedSymbolTable, types: TypeSystem) ex
eltsS.foreach(elt => matchType(eltsS.head.tpe, elt.tpe, Some(s"Collection is inconsistent, ${elt.tpe} stands out.")))
eltsS.head.tpe match {
case Types.PCollection(inT) if inT.isCollection => error("Illegal level of nesting")
case Types.PSet(inT) if inT.isCollection => error("Illegal level of nesting")
case _ => // Do nothing
}
coll.copy(eltsS, computeType(coll.copy(eltsS)))
}

def scanSet: Scan = {
case set @ Expr.PrismSet(elts, _) =>
if (elts.size > Constants.CollMaxLength) error(s"Set size limit overflow (${elts.size} > ${Constants.CollMaxLength})")
else if (elts.size < 1) error("Empty set")
val eltsS: List[Expr] = elts.map(scan).distinct
eltsS.foreach(elt => matchType(eltsS.head.tpe, elt.tpe, Some(s"Set is inconsistent, ${elt.tpe} stands out.")))
eltsS.head.tpe match {
case Types.PCollection(inT) if inT.isCollection => error("Illegal level of nesting")
case Types.PSet(inT) if inT.isCollection => error("Illegal level of nesting")
case _ =>
}
set.copy(eltsS, computeType(set.copy(eltsS)))
}

def scanConstant: Scan = {
/** Scan each element of tuple ensuring its actual size does not
* overflow `TupleMaxLength`, then check type consistency of all elements. */
Expand Down Expand Up @@ -305,6 +321,10 @@ case class StaticAnalyser(initialScope: ScopedSymbolTable, types: TypeSystem) ex
case _: SliceOp.Index => inT
case _: SliceOp.Slice => coll
}
case set @ Types.PSet(inT) => op match {
case _: SliceOp.Index => inT
case _: SliceOp.Slice => set
}
case otherT => error(s"$otherT does not support subscription")
}
case Expr.Unary(_, operand, _) => computeType(operand)
Expand All @@ -313,6 +333,7 @@ case class StaticAnalyser(initialScope: ScopedSymbolTable, types: TypeSystem) ex
case Expr.Lambda(args, body, _) => Types.PFunc(types.resolveArgs(args), computeType(body))
case Expr.Tuple(elts, _) => Types.PTuple(elts.map(elt => computeType(elt)))
case Expr.Collection(elts, _) => Types.PCollection(computeType(elts.head))
case Expr.PrismSet(elts, _) => Types.PSet(computeType(elts.head))
case Expr.Map(_, func, _) => computeType(func) match {
case Types.PFunc(_, retT) => Types.PCollection(retT)
case otherT => error(s"$otherT is not a function")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ object Transformer {
case sub @ Expr.Subscript(value, _, _) => sub.copy(transform(value))

case coll @ Expr.Collection(elts, _) => coll.copy(elts.map(transform))

case set @ Expr.PrismSet(elts, _) => set.copy(elts.map(transform))

case tuple @ Expr.Tuple(elts, _) => tuple.copy(elts.map(transform))

case other => other
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/org/encryfoundation/prismlang/core/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ object Ast {

case class Collection(elts: List[Expr], override val tpe: PType = Nit) extends Expr

case class PrismSet(elts: List[Expr], override val tpe: PType = Nit) extends Expr

case class Tuple(elts: List[Expr], override val tpe: PType = Nit) extends Expr

case class Base58Str(value: String) extends Expr { override val tpe: PType = PCollection.ofByte }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ case class TypeSystem(additionalTypes: Seq[Types.PType]) {
def resolveType(ident: TypeIdent): Types.PType = {
val typeParams: List[Types.PType] = ident.typeParams.map(resolveType)
typeByIdent(ident.name).map {
case Types.PSet(_) =>
if (typeParams.size == 1) Types.PSet(typeParams.head)
else throw TypeSystemException("'Set[T]' takes exactly one type parameter")
case Types.PCollection(_) =>
if (typeParams.size == 1) Types.PCollection(typeParams.head)
else throw TypeSystemException("'Array[T]' takes exactly one type parameter")
Expand Down
26 changes: 25 additions & 1 deletion src/main/scala/org/encryfoundation/prismlang/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,30 @@ object Types {
val ofString: PCollection = PCollection(PString)
}

case class PSet(valT : PType) extends PType with Parametrized {
override type Underlying = Set[valT.Underlying]
override val ident: String = "Set"
override val isCollection: Boolean = true
override val dataCost: Int = 10 * valT.dataCost

override def isApplicable(func: PFunc): Boolean =
func.args.size == 1 && (valT.isSubtypeOf(func.args.head._2) || valT == func.args.head._2)

override def equals(obj: Any): Boolean = obj match {
case set: PSet =>
set.valT == this.valT || set.valT.isSubtypeOf(this.valT) || set.valT.canBeDerivedTo(this.valT)
case tag: TaggedType if tag.isCollection => tag.underlyingType == this
case _ => false
}
}
object PSet {
val ofByte: PSet = PSet(PByte)
val ofByteArrays: PSet = PSet(PSet(PByte))
val ofInt: PSet = PSet(PInt)
val ofBool: PSet = PSet(PBoolean)
val ofString: PSet = PSet(PString)
}

case class PFunc(args: List[(String, PType)], retT: PType) extends PType {
override type Underlying = PFunction
override val ident: String = "Func"
Expand Down Expand Up @@ -328,7 +352,7 @@ object Types {
/** All types with type parameters including `PTuple` instances
* of all possible dimensions. */
val parametrizedTypes: List[Parametrized] = List(
PCollection(Nit)
PCollection(Nit), PSet(Nit)
) ++ (1 to Constants.TupleMaxDim).map(i => PTuple((1 to i).map(_ => Nit).toList))

val productTypes: List[Product] = List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ object PValue {
new PValue {
override val tpe: Types.PType = t
override val value: tpe.Underlying = (v match {
case set: Set[_] if t.isCollection => set.toList
case arr: Array[_] if t.isCollection => arr.toList
case seq: Seq[_] if t.isCollection => seq.toList
case int: Int if t.isNumeric => int.toLong
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ case class Evaluator(initialEnv: ScopedRuntimeEnvironment, types: TypeSystem) {
}
case Expr.Collection(elts, _) => elts.map(elt => eval[elt.tpe.Underlying](elt))
case Expr.Tuple(elts, _) => elts.map(elt => eval[elt.tpe.Underlying](elt))
case Expr.PrismSet(elts, _) => elts.map(elt => eval[elt.tpe.Underlying](elt)).toSet
case Expr.Base58Str(value) => Base58.decode(value).map(_.toList).getOrElse(error("Base58 string decoding failed"))
case Expr.Base16Str(value) => Base16.decode(value).map(_.toList).getOrElse(error("Base16 string decoding failed"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ object Expressions {
}
}

val setContents: noApi.Parser[Seq[Ast.Expr]] = P( test.rep(1, "," ~ LineBreak.?) ~ ",".? ~ LineBreak.? )
val set: core.Parser[Ast.Expr.PrismSet, Char, String] = P( setContents ).map(exps => Ast.Expr.PrismSet(exps.toList))
val listContents: noApi.Parser[Seq[Ast.Expr]] = P( test.rep(1, "," ~ LineBreak.?) ~ ",".? ~ LineBreak.? )
val list: core.Parser[Ast.Expr.Collection, Char, String] = P( listContents ).map(exps => Ast.Expr.Collection(exps.toList))
val tupleContents: core.Parser[Seq[Ast.Expr], Char, String] = P( test ~ "," ~ listContents.?).map { case (head, rest) => head +: rest.getOrElse(Seq.empty) }
Expand All @@ -81,6 +83,7 @@ object Expressions {
P(
"(" ~ LineBreak.? ~ (tuple | expr) ~ ")" |
"Array(" ~ LineBreak.? ~ list ~ ")" |
"Set(" ~ LineBreak.? ~ set ~ ")" |
BASE58STRING.rep(1).map(_.mkString).map(Ast.Expr.Base58Str) |
BASE16STRING.rep(1).map(_.mkString).map(Ast.Expr.Base16Str) |
STRING.rep(1).map(_.mkString).map(Ast.Expr.Str) |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ class SyntaxConstructionsSpec extends PropSpec with Utils {
evaluationSuccess = Option(true), expectedValue = Option(longMin))
}

property("Set construction") {
val setConstruction =
"""
{
let a = 1
let b = 1
let c = 1
let d = 1
let e = 6
let g = Set(a,b,c,d,e)
g[1] + g[0]
}
""".stripMargin

testCompiledExpressionWithOptionalEvaluation(setConstruction, compilationSuccess = true,
evaluationSuccess = Option(true), expectedValue = Option(7))
}

property("Int Lower Boundary check") {
val longMax = Long.MaxValue
val letLongMaxNumber =
Expand Down