Skip to content

Commit

Permalink
Restrict allowed trees in annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Nov 13, 2024
1 parent bed0e86 commit da97deb
Show file tree
Hide file tree
Showing 22 changed files with 199 additions and 66 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allTermArguments(fn) ::: args
case TypeApply(fn, args) => allTermArguments(fn)
// TOOD(mbovel): is it really safe to skip all blocks here and in `allArguments`?
case Block(_, expr) => allTermArguments(expr)
case _ => Nil
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ class Definitions {

@tu lazy val DummyImplicitClass: ClassSymbol = requiredClass("scala.DummyImplicit")

@tu lazy val SymbolModule: Symbol = requiredModule("scala.Symbol")
@tu lazy val JSSymbolModule: Symbol = requiredModule("scala.scalajs.js.Symbol")

@tu lazy val ScalaRuntimeModule: Symbol = requiredModule("scala.runtime.ScalaRunTime")
def runtimeMethodRef(name: PreName): TermRef = ScalaRuntimeModule.requiredMethodRef(name)
def ScalaRuntime_drop: Symbol = runtimeMethodRef(nme.drop).symbol
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,17 @@ object TreeChecker {
|${mismatch.message}${mismatch.explanation}
|tree = $tree ${tree.className}""".stripMargin
})
checkWellFormedType(tp1)
checkWellFormedType(tp2)

/** Check that the type `tp` is well-formed. Currently this only means
* checking that annotated types have valid annotation arguments.
*/
private def checkWellFormedType(tp: Type)(using Context): Unit =
tp.foreachPart:
case AnnotatedType(underlying, annot) => checkAnnot(annot.tree)
case _ => ()

}

/** Tree checker that can be applied to a local tree. */
Expand Down
50 changes: 45 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,6 @@ object Checking {
annot
case _ => annot
end checkNamedArgumentForJavaAnnotation

}

trait Checking {
Expand Down Expand Up @@ -1387,12 +1386,21 @@ trait Checking {
if !Inlines.inInlineMethod && !ctx.isInlineContext then
report.error(em"$what can only be used in an inline method", pos)

def checkAnnot(tree: Tree)(using Context): Tree =
tree match
case Ident(tpnme.BOUNDTYPE_ANNOT) =>
// `FirstTransform.toTypeTree` creates `Annotated` nodes whose `annot` are
// `Ident`s, not annotation instances. See `tests/pos/annot-boundtype.scala`.
tree
case _ =>
checkAnnotArgs(checkAnnotClass(tree))

/** Check that the class corresponding to this tree is either a Scala or Java annotation.
*
* @return The original tree or an error tree in case `tree` isn't a valid
* annotation or already an error tree.
*/
def checkAnnotClass(tree: Tree)(using Context): Tree =
private def checkAnnotClass(tree: Tree)(using Context): Tree =
if tree.tpe.isError then
return tree
val cls = Annotations.annotClass(tree)
Expand All @@ -1404,8 +1412,8 @@ trait Checking {
errorTree(tree, em"$cls is not a valid Scala annotation: it does not extend `scala.annotation.Annotation`")
else tree

/** Check arguments of compiler-defined annotations */
def checkAnnotArgs(tree: Tree)(using Context): tree.type =
/** Check arguments of annotations */
private def checkAnnotArgs(tree: Tree)(using Context): Tree =
val cls = Annotations.annotClass(tree)
tree match
case Apply(tycon, arg :: Nil) if cls == defn.TargetNameAnnot =>
Expand All @@ -1416,8 +1424,40 @@ trait Checking {
case _ =>
report.error(em"@${cls.name} needs a string literal as argument", arg.srcPos)
case _ =>
if cls.isRetainsLike then () // Do not check @retain annotations
else if cls == defn.ThrowsAnnot then
// Do not check @throws annotations.
// TODO(mbovel): in tests/run/t6380.scala, an annotation tree is
// `new throws[Exception](throws.<init>[Exception])`. What is this?
()
else
tpd.allTermArguments(tree).foreach(checkAnnotArg)
tree

private def checkAnnotArg(tree: Tree)(using Context): Unit =
def valid(t: Tree): Boolean =
t match
case _ if t.tpe.isEffectivelySingleton => true
case Literal(_) => true
// `_` is used as placeholder for unspecified arguments of Java
// annotations. Example: tests/run/java-ann-super-class
case Ident(nme.WILDCARD) => true
case Apply(fun, args) => valid(fun) && args.forall(valid)
case TypeApply(fun, args) => valid(fun)
case SeqLiteral(elems, _) => elems.forall(valid)
case Typed(expr, _) => valid(expr)
case NamedArg(_, arg) => valid(arg)
case Splice(_) => true
case Hole(_, _, _, _) => true
case _ => false
if !valid(tree) then
report.error(
i"""Implementation restriction: not a valid annotation argument.
|Argument: $tree
|Type: ${tree.tpe}""",
tree.srcPos
)

/** 1. Check that all case classes that extend `scala.reflect.Enum` are `enum` cases
* 2. Check that parameterised `enum` cases do not extend java.lang.Enum.
* 3. Check that only a static `enum` base class can extend java.lang.Enum.
Expand Down Expand Up @@ -1665,7 +1705,7 @@ trait NoChecking extends ReChecking {
override def checkImplicitConversionDefOK(sym: Symbol)(using Context): Unit = ()
override def checkImplicitConversionUseOK(tree: Tree, expected: Type)(using Context): Unit = ()
override def checkFeasibleParent(tp: Type, pos: SrcPos, where: => String = "")(using Context): Type = tp
override def checkAnnotArgs(tree: Tree)(using Context): tree.type = tree
override def checkAnnot(tree: Tree)(using Context): tree.type = tree
override def checkNoTargetNameConflict(stats: List[Tree])(using Context): Unit = ()
override def checkParentCall(call: Tree, caller: ClassSymbol)(using Context): Unit = ()
override def checkSimpleKinded(tpt: Tree)(using Context): Tree = tpt
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2780,7 +2780,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}

def typedAnnotation(annot: untpd.Tree)(using Context): Tree =
checkAnnotClass(checkAnnotArgs(typed(annot)))
checkAnnot(typed(annot))

def registerNowarn(tree: Tree, mdef: untpd.Tree)(using Context): Unit =
val annot = Annotations.Annotation(tree)
Expand Down Expand Up @@ -3310,7 +3310,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
end typedPackageDef

def typedAnnotated(tree: untpd.Annotated, pt: Type)(using Context): Tree = {
val annot1 = checkAnnotClass(typedExpr(tree.annot))
val annot1 = checkAnnot(typedExpr(tree.annot))
val annotCls = Annotations.annotClass(annot1)
if annotCls == defn.NowarnAnnot then
registerNowarn(annot1, tree)
Expand Down
1 change: 1 addition & 0 deletions tests/bench/inductive-implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ package shapeless {
import shapeless.*

object Test extends App {
import Selector.given
val sel = Selector[L, Boolean]

type L =
Expand Down
48 changes: 48 additions & 0 deletions tests/neg/annot-invalid.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
-- Error: tests/neg/annot-invalid.scala:4:21 ---------------------------------------------------------------------------
4 | val x1: Int @annot(new Object {}) = 0 // error
| ^^^^^^^^^^^^^
| Implementation restriction: not a valid annotation argument.
| Argument: {
| final class $anon() extends Object() {}
| new $anon():Object
| }
| Type: Object
-- Error: tests/neg/annot-invalid.scala:5:21 ---------------------------------------------------------------------------
5 | val x2: Int @annot({val x = 1}) = 0 // error
| ^^^^^^^^^^^
| Implementation restriction: not a valid annotation argument.
| Argument: {
| val x: Int = 1
| ()
| }
| Type: Unit
-- Error: tests/neg/annot-invalid.scala:6:21 ---------------------------------------------------------------------------
6 | val x3: Int @annot((x: Int) => x) = 0 // error
| ^^^^^^^^^^^^^
| Implementation restriction: not a valid annotation argument.
| Argument: (x: Int) => x
| Type: Int => Int
-- Error: tests/neg/annot-invalid.scala:8:9 ----------------------------------------------------------------------------
8 | @annot(new Object {}) val y1: Int = 0 // error
| ^^^^^^^^^^^^^
| Implementation restriction: not a valid annotation argument.
| Argument: {
| final class $anon() extends Object() {}
| new $anon():Object
| }
| Type: Object
-- Error: tests/neg/annot-invalid.scala:9:9 ----------------------------------------------------------------------------
9 | @annot({val x = 1}) val y2: Int = 0 // error
| ^^^^^^^^^^^
| Implementation restriction: not a valid annotation argument.
| Argument: {
| val x: Int = 1
| ()
| }
| Type: Unit
-- Error: tests/neg/annot-invalid.scala:10:9 ---------------------------------------------------------------------------
10 | @annot((x: Int) => x) val y3: Int = 0 // error
| ^^^^^^^^^^^^^
| Implementation restriction: not a valid annotation argument.
| Argument: (x: Int) => x
| Type: Int => Int
12 changes: 12 additions & 0 deletions tests/neg/annot-invalid.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class annot[T](arg: T) extends scala.annotation.Annotation

def main =
val x1: Int @annot(new Object {}) = 0 // error
val x2: Int @annot({val x = 1}) = 0 // error
val x3: Int @annot((x: Int) => x) = 0 // error

@annot(new Object {}) val y1: Int = 0 // error
@annot({val x = 1}) val y2: Int = 0 // error
@annot((x: Int) => x) val y3: Int = 0 // error

()
15 changes: 15 additions & 0 deletions tests/neg/i15054.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.annotation.Annotation

class AnAnnotation(function: Int => String) extends Annotation

@AnAnnotation(_.toString) // error: not a valid annotation
val a = 1
@AnAnnotation(_.toString.length.toString) // error: not a valid annotation
val b = 2

def test =
@AnAnnotation(_.toString) // error: not a valid annotation
val a = 1
@AnAnnotation(_.toString.length.toString) // error: not a valid annotation
val b = 2
a + b
2 changes: 2 additions & 0 deletions tests/neg/i7740a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class A(a: Any) extends annotation.StaticAnnotation
@A({val x = 0}) trait B // error: not a valid annotation
2 changes: 2 additions & 0 deletions tests/neg/i7740b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class A(a: Any) extends annotation.StaticAnnotation
@A({def x = 0}) trait B // error: not a valid annotation
2 changes: 1 addition & 1 deletion tests/pos/i9314.scala → tests/neg/i9314.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
final class fooAnnot[T](member: T) extends scala.annotation.StaticAnnotation // must have type parameter

@fooAnnot(new RecAnnotated {}) // must pass instance of anonymous subclass
@fooAnnot(new RecAnnotated {}) // error: not a valid annotation
trait RecAnnotated
3 changes: 3 additions & 0 deletions tests/neg/t7426.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class foo(x: Any) extends annotation.StaticAnnotation

@foo(new AnyRef { }) trait A // error: not a valid annotation
10 changes: 0 additions & 10 deletions tests/pos/annot-17939b.scala

This file was deleted.

16 changes: 16 additions & 0 deletions tests/pos/annot-boundtype.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// `FirstTransform.toTypeTree` creates `Annotated` nodes whose `annot` are
// `Ident`s, not annotation instances. This is relevant for `Checking.checkAnnot`.
//
// See also:
// - tests/run/t2755.scala
// - tests/neg/i13044.scala

def f(a: Array[?]) =
a match
case x: Array[?] => ()

def f2(t: Tuple) =
t match
case _: (t *: ts) => ()
case _ => ()

37 changes: 37 additions & 0 deletions tests/pos/annot-valid.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
class annot[T](arg: T) extends scala.annotation.Annotation

def main =
val n: Int = 0
def f(x: Any): Unit = ()

val x1: Int @annot(42) = 0
val x2: Int @annot("hello") = 0
val x3: Int @annot(classOf[Int]) = 0
val x4: Int @annot(Array(1,2)) = 0
val x5: Int @annot(Array(Array(1,2),Array(3,4))) = 0
val x6: Int @annot((1,2)) = 0
val x7: Int @annot((1,2,3)) = 0
val x8: Int @annot(((1,2),3)) = 0
val x9: Int @annot(((1,2),(3,4))) = 0
val x10: Int @annot(Symbol("hello")) = 0
val x11: Int @annot(n + 1) = 0
val x12: Int @annot(f(2)) = 0
val x13: Int @annot(throw new Error()) = 0
val x14: Int @annot(42: Double) = 0

@annot(42) val y1: Int = 0
@annot("hello") val y2: Int = 0
@annot(classOf[Int]) val y3: Int = 0
@annot(Array(1,2)) val y4: Int = 0
@annot(Array(Array(1,2),Array(3,4))) val y5: Int = 0
@annot((1,2)) val y6: Int = 0
@annot((1,2,3)) val y7: Int = 0
@annot(((1,2),3)) val y8: Int = 0
@annot(((1,2),(3,4))) val y9: Int = 0
@annot(Symbol("hello")) val y10: Int = 0
@annot(n + 1) val y11: Int = 0
@annot(f(2)) val y12: Int = 0
@annot(throw new Error()) val y13: Int = 0
@annot(42: Double) val y14: Int = 0

()
15 changes: 0 additions & 15 deletions tests/pos/i15054.scala

This file was deleted.

2 changes: 0 additions & 2 deletions tests/pos/i7740a.scala

This file was deleted.

2 changes: 0 additions & 2 deletions tests/pos/i7740b.scala

This file was deleted.

3 changes: 0 additions & 3 deletions tests/pos/t7426.scala

This file was deleted.

19 changes: 0 additions & 19 deletions tests/printing/annot-19846b.check

This file was deleted.

7 changes: 0 additions & 7 deletions tests/printing/annot-19846b.scala

This file was deleted.

0 comments on commit da97deb

Please sign in to comment.