Skip to content

Commit

Permalink
address review: part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Nov 14, 2024
1 parent 77704dd commit be8d0c6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 33 deletions.
52 changes: 22 additions & 30 deletions compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
def computeIndices(annotated: Symbol)(using Context): ComputedIndices =
unrolledDefs.getOrElseUpdate(annotated, {
if annotated.name.is(DefaultGetterName) then
Nil // happens in curried methods where more than one parameter list has @unroll
// happens in curried methods, where default argument occurs in parameter list
// after the unrolled parameter list.
// example:
// `final def foo(@unroll y: String = "")(x: Int = 23) = x`
// yields:
// `def foo$default$2(@unroll y: String): Int @uncheckedVariance = 23`
// Perhaps annotations should be preprocessed before they are copied?
Nil
else
val indices = annotated
.paramSymss
Expand Down Expand Up @@ -114,15 +121,13 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
* @param paramIndex index of the unrolled parameter (in the parameter list) that we stop at
* @param paramCount number of parameters in the annotated parameter list
* @param nextParamIndex index of next unrolled parameter - to fetch default argument
* @param nextSpan span of next forwarder - used to ensure the span is not identical by shifting (TODO remove)
* @param annotatedParamListIndex index of the parameter list that contains unrolled parameters
* @param isCaseApply if `defdef` is a case class apply/constructor - used for selection of default arguments
*/
private def generateSingleForwarder(defdef: DefDef,
paramIndex: Int,
paramCount: Int,
nextParamIndex: Int,
nextSpan: Span,
annotatedParamListIndex: Int,
isCaseApply: Boolean)(using Context): DefDef = {

Expand All @@ -133,7 +138,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
defdef.symbol.flags &~ HasDefaultParams |
Invisible | Synthetic,
NoType, // fill in later
coord = nextSpan.shift(1) // shift by 1 to avoid "secondary constructor must call preceding" error
coord = defdef.span
).entered

val newParamSymMappings = extractParamSymss(copyParamSym(_, forwarderDefSymbol0))
Expand Down Expand Up @@ -196,9 +201,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
newParamSymLists
.take(annotatedParamListIndex)
.map(_.map(ref))
.foldLeft(inner): (lhs, newParams) =>
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
else Apply(lhs, newParams)
.foldLeft(inner)(_.appliedToArgs(_))
)

val forwarderInner: Tree =
Expand All @@ -210,11 +213,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
else ps.map(ref)
}

val forwarderCall0 = forwarderCallArgs.foldLeft[Tree](forwarderInner){
case (lhs: Tree, newParams) =>
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
else Apply(lhs, newParams)
}
val forwarderCall0 = forwarderCallArgs.foldLeft(forwarderInner)(_.appliedToArgs(_))

val forwarderCall =
if (!defdef.symbol.isConstructor) forwarderCall0
Expand All @@ -224,9 +223,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
}

val forwarderDef =
tpd.DefDef(forwarderDefSymbol, rhs = forwarderRhs())
tpd.DefDef(forwarderDefSymbol, rhs = forwarderRhs()).withSpan(defdef.span)

forwarderDef.withSpan(nextSpan.shift(1))
forwarderDef
}

private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
Expand Down Expand Up @@ -261,10 +260,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {

private enum Gen:
case Substitute(origin: Symbol, newDef: DefDef)
case Forwarders(origin: Symbol, forwarders: Seq[DefDef])
case Forwarders(origin: Symbol, forwarders: List[DefDef])

def origin: Symbol
def extras: Seq[DefDef] = this match
def extras: List[DefDef] = this match
case Substitute(_, d) => d :: Nil
case Forwarders(_, ds) => ds

Expand All @@ -288,28 +287,26 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {

compute(annotated) match {
case Nil => None
case Seq((paramClauseIndex, annotationIndices)) =>
case (paramClauseIndex, annotationIndices) :: Nil =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if isCaseFromProduct then
Some(Gen.Substitute(
origin = defdef.symbol,
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
))
else
val (generatedDefs, _) =
val generatedDefs =
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
indices.foldLeft((Seq.empty[DefDef], defdef.symbol.span)):
case ((defdefs, nextSpan), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
indices.foldLeft(List.empty[DefDef]):
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
generateSingleForwarder(
defdef,
paramIndex,
paramCount,
nextParamIndex,
nextSpan,
paramClauseIndex,
isCaseApply
)
(forwarder +: defdefs, forwarder.symbol.span)
) :: defdefs
case _ => unreachable("sliding with at least 2 elements")
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))

Expand All @@ -329,11 +326,6 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))

/** inlined from compiler/src/dotty/tools/dotc/typer/Checking.scala */
def checkClash(decl: Symbol, other: Symbol) =
def staticNonStaticPair = decl.isScalaStatic != other.isScalaStatic
decl.matches(other) && !staticNonStaticPair

if allGenerated.nonEmpty then
val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
for
Expand All @@ -342,7 +334,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
do
val replaced = dcl.symbol
byName.get(dcl.name.toString).foreach { syms =>
val clashes = syms.filter(checkClash(replaced, _))
val clashes = syms.filter(ctx.typer.matchesSameStatic(replaced, _))
for existing <- clashes do
val src = syntheticDefs.origin
report.error(i"""Unrolled $replaced clashes with existing declaration.
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,10 @@ trait Checking {
/** A hook to exclude selected symbols from double declaration check */
def excludeFromDoubleDeclCheck(sym: Symbol)(using Context): Boolean = false

def matchesSameStatic(decl: Symbol, other: Symbol)(using Context): Boolean =
def staticNonStaticPair = decl.isScalaStatic != other.isScalaStatic
decl.matches(other) && !staticNonStaticPair

/** Check that class does not declare same symbol twice */
def checkNoDoubleDeclaration(cls: Symbol)(using Context): Unit = {
val seen = new mutable.HashMap[Name, List[Symbol]].withDefaultValue(Nil)
Expand All @@ -1232,8 +1236,7 @@ trait Checking {
def javaFieldMethodPair =
decl.is(JavaDefined) && other.is(JavaDefined) &&
decl.is(Method) != other.is(Method)
def staticNonStaticPair = decl.isScalaStatic != other.isScalaStatic
if (decl.matches(other) && !javaFieldMethodPair && !staticNonStaticPair) {
if (matchesSameStatic(decl, other) && !javaFieldMethodPair) {
def doubleDefError(decl: Symbol, other: Symbol): Unit =
if (!decl.info.isErroneous && !other.info.isErroneous)
report.error(DoubleDefinition(decl, other, cls), decl.srcPos)
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2930,7 +2930,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def checkThisConstrCall(tree: Tree): Unit = tree match
case app: Apply if untpd.isSelfConstrCall(app) =>
if (sym.span.exists && app.symbol.span.exists && sym.span.start <= app.symbol.span.start)
if !sym.is(Synthetic)
&& sym.span.exists && app.symbol.span.exists && sym.span.start <= app.symbol.span.start
then
report.error("secondary constructor must call a preceding constructor", app.srcPos)
case Block(call :: _, _) => checkThisConstrCall(call)
case _ =>
Expand Down

0 comments on commit be8d0c6

Please sign in to comment.