From 5be0722abc913deea1a0bc3e433f1bf4c29f4e09 Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Mon, 22 Jun 2015 15:08:32 +0200 Subject: [PATCH] Rewrite closure invocations to the lambda body method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When an indylambda closure is allocated and invoked within the same method, rewrite the invocation to the implementation method. This works for any indylambda / SAM type, not only Scala functions. However, the Scala compiler (under -Xexperimental) currently desugars function literals for non-FunctionN types to an anonymous class during typer. No testing yet, waiting for FunctionN to become SAMs first. The feature requires scala-java8-compat to be on the classpath and a number of compiler flags: -Ydelambdafy:method -Ybackend:GenBCode -Yopt:closure-elimination -target:jvm-1.8 ➜ scala git:(opt/closureInlining) ant -Dscala-java8-compat.package=1 -Dlocker.skip=1 ➜ scala git:(opt/closureInlining) cd sandbox ➜ sandbox git:(opt/closureInlining) cat Fun.java public interface Fun<T> { T apply(T x); } ➜ sandbox git:(opt/closureInlining) javac Fun.java ➜ sandbox git:(opt/closureInlining) cat Test.scala class C { val z = "too" def f = { val kap = "me! me!" val f: Tuple2[String, String] => String = (o => z + kap + o.toString) f(("a", "b")) } def g = { val f: Int => String = x => x.toString f(10) } def h = { val f: Fun[Int] = x => x + 100 // Java SAM, requires -Xexperimental, will create an anonymous class in typer f(10) } def i = { val l = 10l val f: (Long, String) => String = (x, s) => s + l + z + x f(20l, "n") } def j = { val f: Int => Int = x => x + 101 // specialized f(33) } } ➜ sandbox git:(opt/closureInlining) ../build/quick/bin/scalac -target:jvm-1.8 -Yopt:closure-elimination -Ydelambdafy:method -Ybackend:GenBCode -Xexperimental -cp ../build/quick/scala-java8-compat:. Test.scala ➜ sandbox git:(opt/closureInlining) asm -a C.class ➜ sandbox git:(opt/closureInlining) cat C.asm [...] 
public g()Ljava/lang/String; L0 INVOKEDYNAMIC apply()Lscala/compat/java8/JFunction1; [ // handle kind 0x6 : INVOKESTATIC java/lang/invoke/LambdaMetafactory.altMetafactory(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite; // arguments: (Ljava/lang/Object;)Ljava/lang/Object;, // handle kind 0x6 : INVOKESTATIC C.C$$$anonfun$2$adapted(Ljava/lang/Object;)Ljava/lang/String;, (Ljava/lang/Object;)Ljava/lang/String;, 3, 1, Lscala/Serializable;.class, 0 ] CHECKCAST scala/Function1 L1 ASTORE 1 L2 ALOAD 1 BIPUSH 10 INVOKESTATIC scala/runtime/BoxesRunTime.boxToInteger (I)Ljava/lang/Integer; ASTORE 2 POP ALOAD 2 INVOKESTATIC C.C$$$anonfun$2$adapted (Ljava/lang/Object;)Ljava/lang/String; CHECKCAST java/lang/String L3 ARETURN [...] --- .../nsc/backend/jvm/BCodeSkelBuilder.scala | 2 +- .../scala/tools/nsc/backend/jvm/BTypes.scala | 2 + .../nsc/backend/jvm/BTypesFromSymbols.scala | 4 +- .../nsc/backend/jvm/BackendReporting.scala | 22 ++ .../tools/nsc/backend/jvm/GenBCode.scala | 17 +- .../backend/jvm/opt/ByteCodeRepository.scala | 2 +- .../tools/nsc/backend/jvm/opt/CallGraph.scala | 38 ++- .../backend/jvm/opt/ClosureOptimizer.scala | 314 ++++++++++++++++++ .../tools/nsc/backend/jvm/opt/Inliner.scala | 189 ++++++----- .../tools/nsc/settings/ScalaSettings.scala | 7 +- 10 files changed, 489 insertions(+), 108 deletions(-) create mode 100644 src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index 0f6785280407..d2d510e8a9ab 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -140,7 +140,7 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { if (AsmUtils.traceClassEnabled && cnode.name.contains(AsmUtils.traceClassPattern)) 
AsmUtils.traceClass(cnode) - if (settings.YoptInlinerEnabled) { + if (settings.YoptAddToBytecodeRepository) { // The inliner needs to find all classes in the code repo, also those being compiled byteCodeRepository.add(cnode, ByteCodeRepository.CompilationUnit) } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala index e61190bf3a0a..ec0017270ecc 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala @@ -44,6 +44,8 @@ abstract class BTypes { val inliner: Inliner[this.type] + val closureOptimizer: ClosureOptimizer[this.type] + val callGraph: CallGraph[this.type] val backendReporting: BackendReporting diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypesFromSymbols.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypesFromSymbols.scala index 356af36455f7..8740193b58d6 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BTypesFromSymbols.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypesFromSymbols.scala @@ -7,7 +7,7 @@ package scala.tools.nsc package backend.jvm import scala.tools.asm -import scala.tools.nsc.backend.jvm.opt.{LocalOpt, CallGraph, Inliner, ByteCodeRepository} +import scala.tools.nsc.backend.jvm.opt._ import scala.tools.nsc.backend.jvm.BTypes.{InlineInfo, MethodInlineInfo, InternalName} import BackendReporting._ import scala.tools.nsc.settings.ScalaSettings @@ -42,6 +42,8 @@ class BTypesFromSymbols[G <: Global](val global: G) extends BTypes { val inliner: Inliner[this.type] = new Inliner(this) + val closureOptimizer: ClosureOptimizer[this.type] = new ClosureOptimizer(this) + val callGraph: CallGraph[this.type] = new CallGraph(this) val backendReporting: BackendReporting = new BackendReportingImpl(global) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BackendReporting.scala b/src/compiler/scala/tools/nsc/backend/jvm/BackendReporting.scala index d641f708d218..4fc05cafdc93 100644 --- 
a/src/compiler/scala/tools/nsc/backend/jvm/BackendReporting.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BackendReporting.scala @@ -246,6 +246,28 @@ object BackendReporting { case class ResultingMethodTooLarge(calleeDeclarationClass: InternalName, name: String, descriptor: String, callsiteClass: InternalName, callsiteName: String, callsiteDesc: String) extends CannotInlineWarning + /** + * Used in `rewriteClosureApplyInvocations` when a closure apply callsite cannot be rewritten + * to the closure body method. + */ + trait RewriteClosureApplyToClosureBodyFailed extends OptimizerWarning { + def pos: Position + + override def emitWarning(settings: ScalaSettings): Boolean = this match { + case RewriteClosureAccessCheckFailed(_, cause) => cause.emitWarning(settings) + case RewriteClosureIllegalAccess(_, _) => settings.YoptWarningEmitAtInlineFailed + } + + override def toString: String = this match { + case RewriteClosureAccessCheckFailed(_, cause) => + s"Failed to rewrite the closure invocation to its implementation method:\n" + cause + case RewriteClosureIllegalAccess(_, callsiteClass) => + s"The closure body invocation cannot be rewritten because the target method is not accessible in class $callsiteClass." + } + } + case class RewriteClosureAccessCheckFailed(pos: Position, cause: OptimizerWarning) extends RewriteClosureApplyToClosureBodyFailed + case class RewriteClosureIllegalAccess(pos: Position, callsiteClass: InternalName) extends RewriteClosureApplyToClosureBodyFailed + /** * Used in the InlineInfo of a ClassBType, when some issue occurred obtaining the inline information. 
*/ diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala index c6ee36d7b212..455117d837fa 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala @@ -216,12 +216,17 @@ abstract class GenBCode extends BCodeSyncAndTry { class Worker2 { def runGlobalOptimizations(): Unit = { import scala.collection.convert.decorateAsScala._ - q2.asScala foreach { - case Item2(_, _, plain, _, _) => - // skip mirror / bean: wd don't inline into tem, and they are not used in the plain class - if (plain != null) callGraph.addClass(plain) + if (settings.YoptBuildCallGraph) { + q2.asScala foreach { + case Item2(_, _, plain, _, _) => + // skip mirror / bean: wd don't inline into tem, and they are not used in the plain class + if (plain != null) callGraph.addClass(plain) + } } - bTypes.inliner.runInliner() + if (settings.YoptInlinerEnabled) + bTypes.inliner.runInliner() + if (settings.YoptClosureElimination) + closureOptimizer.rewriteClosureApplyInvocations() } def localOptimizations(classNode: ClassNode): Unit = { @@ -229,7 +234,7 @@ abstract class GenBCode extends BCodeSyncAndTry { } def run() { - if (settings.YoptInlinerEnabled) runGlobalOptimizations() + runGlobalOptimizations() while (true) { val item = q2.poll diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/ByteCodeRepository.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/ByteCodeRepository.scala index dbf19744fabb..a5b85e54e790 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/ByteCodeRepository.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/ByteCodeRepository.scala @@ -102,7 +102,7 @@ class ByteCodeRepository(val classPath: ClassFileLookup[AbstractFile], val isJav } /** - * The method node for a method matching `name` and `descriptor`, accessed in class `classInternalName`. 
+ * The method node for a method matching `name` and `descriptor`, accessed in class `ownerInternalNameOrArrayDescriptor`. * The declaration of the method may be in one of the parents. * * @return The [[MethodNode]] of the requested method and the [[InternalName]] of its declaring diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala index 0932564b1f6e..8abecdb26121 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala @@ -11,6 +11,7 @@ import scala.reflect.internal.util.{NoPosition, Position} import scala.tools.asm.tree.analysis.{Value, Analyzer, BasicInterpreter} import scala.tools.asm.{Opcodes, Type} import scala.tools.asm.tree._ +import scala.collection.concurrent import scala.collection.convert.decorateAsScala._ import scala.tools.nsc.backend.jvm.BTypes.InternalName import scala.tools.nsc.backend.jvm.BackendReporting._ @@ -21,14 +22,25 @@ import BytecodeUtils._ class CallGraph[BT <: BTypes](val btypes: BT) { import btypes._ - val callsites: collection.concurrent.Map[MethodInsnNode, Callsite] = recordPerRunCache(collection.concurrent.TrieMap.empty[MethodInsnNode, Callsite]) + val callsites: concurrent.Map[MethodInsnNode, Callsite] = recordPerRunCache(concurrent.TrieMap.empty) + + val closureInstantiations: concurrent.Map[InvokeDynamicInsnNode, (MethodNode, ClassBType)] = recordPerRunCache(concurrent.TrieMap.empty) def addClass(classNode: ClassNode): Unit = { - for (m <- classNode.methods.asScala; callsite <- analyzeCallsites(m, classBTypeFromClassNode(classNode))) - callsites(callsite.callsiteInstruction) = callsite + val classType = classBTypeFromClassNode(classNode) + for { + m <- classNode.methods.asScala + (calls, closureInits) = analyzeCallsites(m, classType) + } { + calls foreach (callsite => callsites(callsite.callsiteInstruction) = callsite) + closureInits foreach (indy => 
closureInstantiations(indy) = (m, classType)) + } } - def analyzeCallsites(methodNode: MethodNode, definingClass: ClassBType): List[Callsite] = { + /** + * Returns a list of callsites in the method, plus a list of closure instantiation indy instructions. + */ + def analyzeCallsites(methodNode: MethodNode, definingClass: ClassBType): (List[Callsite], List[InvokeDynamicInsnNode]) = { case class CallsiteInfo(safeToInline: Boolean, safeToRewrite: Boolean, annotatedInline: Boolean, annotatedNoInline: Boolean, @@ -116,7 +128,10 @@ class CallGraph[BT <: BTypes](val btypes: BT) { case _ => false } - methodNode.instructions.iterator.asScala.collect({ + val callsites = new collection.mutable.ListBuffer[Callsite] + val closureInstantiations = new collection.mutable.ListBuffer[InvokeDynamicInsnNode] + + methodNode.instructions.iterator.asScala foreach { case call: MethodInsnNode => val callee: Either[OptimizerWarning, Callee] = for { (method, declarationClass) <- byteCodeRepository.methodNode(call.owner, call.name, call.desc): Either[OptimizerWarning, (MethodNode, InternalName)] @@ -147,7 +162,7 @@ class CallGraph[BT <: BTypes](val btypes: BT) { receiverNotNullByAnalysis(call, numArgs) } - Callsite( + callsites += Callsite( callsiteInstruction = call, callsiteMethod = methodNode, callsiteClass = definingClass, @@ -157,7 +172,14 @@ class CallGraph[BT <: BTypes](val btypes: BT) { receiverKnownNotNull = receiverNotNull, callsitePosition = callsitePositions.getOrElse(call, NoPosition) ) - }).toList + + case indy: InvokeDynamicInsnNode => + if (closureOptimizer.isClosureInstantiation(indy)) closureInstantiations += indy + + case _ => + } + + (callsites.toList, closureInstantiations.toList) } /** @@ -201,7 +223,7 @@ class CallGraph[BT <: BTypes](val btypes: BT) { * @param calleeDeclarationClass The class in which the callee is declared * @param safeToInline True if the callee can be safely inlined: it cannot be overridden, * and the inliner settings (project / global) allow inlining 
it. - * @param safeToRewrite True if the callee the interface method of a concrete trait method + * @param safeToRewrite True if the callee is the interface method of a concrete trait method * that can be safely re-written to the static implementation method. * @param annotatedInline True if the callee is annotated @inline * @param annotatedNoInline True if the callee is annotated @noinline diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala new file mode 100644 index 000000000000..29896b5ffd8e --- /dev/null +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala @@ -0,0 +1,314 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2015 LAMP/EPFL + * @author Martin Odersky + */ + +package scala.tools.nsc +package backend.jvm +package opt + +import scala.annotation.switch +import scala.reflect.internal.util.NoPosition +import scala.tools.asm.{Handle, Type, Opcodes} +import scala.tools.asm.tree._ +import scala.tools.nsc.backend.jvm.BTypes.InternalName +import scala.tools.nsc.backend.jvm.analysis.ProdConsAnalyzer +import BytecodeUtils._ +import BackendReporting._ +import Opcodes._ +import scala.tools.nsc.backend.jvm.opt.ByteCodeRepository.CompilationUnit +import scala.collection.convert.decorateAsScala._ + +class ClosureOptimizer[BT <: BTypes](val btypes: BT) { + import btypes._ + import callGraph._ + + def rewriteClosureApplyInvocations(): Unit = { + closureInstantiations foreach { + case (indy, (methodNode, ownerClass)) => + val warnings = rewriteClosureApplyInvocations(indy, methodNode, ownerClass) + warnings.foreach(w => backendReporting.inlinerWarning(w.pos, w.toString)) + } + } + + private val lambdaMetaFactoryInternalName: InternalName = "java/lang/invoke/LambdaMetafactory" + + private val metafactoryHandle = { + val metafactoryMethodName: String = "metafactory" + val metafactoryDesc: String = 
"(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;" + new Handle(H_INVOKESTATIC, lambdaMetaFactoryInternalName, metafactoryMethodName, metafactoryDesc) + } + + private val altMetafactoryHandle = { + val altMetafactoryMethodName: String = "altMetafactory" + val altMetafactoryDesc: String = "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;" + new Handle(H_INVOKESTATIC, lambdaMetaFactoryInternalName, altMetafactoryMethodName, altMetafactoryDesc) + } + + def isClosureInstantiation(indy: InvokeDynamicInsnNode): Boolean = { + (indy.bsm == metafactoryHandle || indy.bsm == altMetafactoryHandle) && + { + indy.bsmArgs match { + case Array(samMethodType: Type, implMethod: Handle, instantiatedMethodType: Type, xs @ _*) => + // LambdaMetaFactory performs a number of automatic adaptations when invoking the lambda + // implementation method (casting, boxing, unboxing, and primitive widening, see Javadoc). + // The closure optimizer (rewriteClosureApplyInvocations) does not currently support these + // adaptations, so we don't consider indy calls that need adaptations for rewriting. + // Indy calls emitted by scalac never rely on adaptation, they are implemented explicitly + // in the implMethod. + // + // Note that we don't check all the invariants requried for a metafactory indy call, only + // those required not to crash the compiler. 
+ + val implMethodType = Type.getType(implMethod.getDesc) + val numCaptures = implMethodType.getArgumentTypes.length - instantiatedMethodType.getArgumentTypes.length + val implMethodTypeWithoutCaputres = Type.getMethodType(implMethodType.getReturnType, implMethodType.getArgumentTypes.drop(numCaptures): _*) + implMethodTypeWithoutCaputres == instantiatedMethodType + + case _ => + false + } + } + } + + def isSamInvocation(invocation: MethodInsnNode, indy: InvokeDynamicInsnNode, prodCons: => ProdConsAnalyzer): Boolean = { + if (invocation.getOpcode == INVOKESTATIC) false + else { + def closureIsReceiver = { + val invocationFrame = prodCons.frameAt(invocation) + val receiverSlot = { + val numArgs = Type.getArgumentTypes(invocation.desc).length + invocationFrame.stackTop - numArgs + } + val receiverProducers = prodCons.initialProducersForValueAt(invocation, receiverSlot) + receiverProducers.size == 1 && receiverProducers.head == indy + } + + invocation.name == indy.name && { + val indySamMethodDesc = indy.bsmArgs(0).asInstanceOf[Type].getDescriptor // safe, checked in isClosureInstantiation + indySamMethodDesc == invocation.desc + } && + closureIsReceiver // most expensive check last + } + } + + /** + * Stores the values captured by a closure creation into fresh local variables. + * Returns the list of locals holding the captured values. + */ + private def storeCaptures(indy: InvokeDynamicInsnNode, methodNode: MethodNode): LocalsList = { + val capturedTypes = Type.getArgumentTypes(indy.desc) + val firstCaptureLocal = methodNode.maxLocals + + // This could be optimized: in many cases the captured values are produced by LOAD instructions. + // If the variable is not modified within the method, we could avoid introducing yet another + // local. On the other hand, further optimizations (copy propagation, remove unused locals) will + // clean it up. 
+ val localsForCaptures = LocalsList.fromTypes(firstCaptureLocal, capturedTypes) + methodNode.maxLocals = firstCaptureLocal + localsForCaptures.size + + insertStoreOps(indy, methodNode, localsForCaptures) + insertLoadOps(indy, methodNode, localsForCaptures) + + localsForCaptures + } + + /** + * Insert store operations in front of the `before` instruction to copy stack values into the + * locals denoted by `localsList`. + * + * The lowest stack value is stored in the head of the locals list, so the last local is stored first. + */ + private def insertStoreOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) = + insertLocalValueOps(before, methodNode, localsList, store = true) + + /** + * Insert load operations in front of the `before` instruction to copy the local values denoted + * by `localsList` onto the stack. + * + * The head of the locals list will be the lowest value on the stack, so the first local is loaded first. + */ + private def insertLoadOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) = + insertLocalValueOps(before, methodNode, localsList, store = false) + + private def insertLocalValueOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList, store: Boolean): Unit = { + // If `store` is true, the first instruction needs to store into the last local of the `localsList`. + // Load instructions on the other hand are emitted in the order of the list. + // To avoid reversing the list, we use `insert(previousInstr)` for stores and `insertBefore(before)` for loads. 
+ lazy val previous = before.getPrevious + for (l <- localsList.locals) { + val varOp = new VarInsnNode(if (store) l.storeOpcode else l.loadOpcode, l.local) + if (store) methodNode.instructions.insert(previous, varOp) + else methodNode.instructions.insertBefore(before, varOp) + } + } + + def rewriteClosureApplyInvocations(indy: InvokeDynamicInsnNode, methodNode: MethodNode, ownerClass: ClassBType): List[RewriteClosureApplyToClosureBodyFailed] = { + val lambdaBodyHandle = indy.bsmArgs(1).asInstanceOf[Handle] // safe, checked in isClosureInstantiation + + // Kept as a lazy val to make sure the analysis is only computed if it's actually needed. + // ProdCons is used to identify closure body invocations (see isSamInvocation), but only if the + // callsite has the right name and signature. If the method has no invocation instruction with + // the right name and signature, the analysis is not executed. + lazy val prodCons = new ProdConsAnalyzer(methodNode, ownerClass.internalName) + + // First collect all callsites without modifying the instructions list yet. + // Once we start modifying the instruction list, prodCons becomes unusable. + + // A list of callsites and stack heights. If the invocation cannot be rewritten, a warning + // message is stored in the stack height value. 
+ val invocationsToRewrite: List[(MethodInsnNode, Either[RewriteClosureApplyToClosureBodyFailed, Int])] = methodNode.instructions.iterator.asScala.collect({ + case invocation: MethodInsnNode if isSamInvocation(invocation, indy, prodCons) => + val bodyAccessible: Either[OptimizerWarning, Boolean] = for { + (bodyMethodNode, declClass) <- byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc): Either[OptimizerWarning, (MethodNode, InternalName)] + isAccessible <- inliner.memberIsAccessible(bodyMethodNode.access, classBTypeFromParsedClassfile(declClass), classBTypeFromParsedClassfile(lambdaBodyHandle.getOwner), ownerClass) + } yield { + isAccessible + } + + def pos = callGraph.callsites.get(invocation).map(_.callsitePosition).getOrElse(NoPosition) + val stackSize: Either[RewriteClosureApplyToClosureBodyFailed, Int] = bodyAccessible match { + case Left(w) => Left(RewriteClosureAccessCheckFailed(pos, w)) + case Right(false) => Left(RewriteClosureIllegalAccess(pos, ownerClass.internalName)) + case _ => Right(prodCons.frameAt(invocation).getStackSize) + } + + (invocation, stackSize) + }).toList + + if (invocationsToRewrite.isEmpty) Nil + else { + // lazy val to make sure locals for captures and arguments are only allocated if there's + // effectively a callsite to rewrite. + lazy val (localsForCapturedValues, argumentLocalsList) = { + val captureLocals = storeCaptures(indy, methodNode) + + // allocate locals for storing the arguments of the closure apply callsites. + // if there are multiple callsites, the same locals are re-used. 
+ val argTypes = indy.bsmArgs(0).asInstanceOf[Type].getArgumentTypes // safe, checked in isClosureInstantiation + val firstArgLocal = methodNode.maxLocals + val argLocals = LocalsList.fromTypes(firstArgLocal, argTypes) + methodNode.maxLocals = firstArgLocal + argLocals.size + + (captureLocals, argLocals) + } + + val warnings = invocationsToRewrite flatMap { + case (invocation, Left(warning)) => Some(warning) + + case (invocation, Right(stackHeight)) => + // store arguments + insertStoreOps(invocation, methodNode, argumentLocalsList) + + // drop the closure from the stack + methodNode.instructions.insertBefore(invocation, new InsnNode(POP)) + + // load captured values and arguments + insertLoadOps(invocation, methodNode, localsForCapturedValues) + insertLoadOps(invocation, methodNode, argumentLocalsList) + + // update maxStack + val capturesStackSize = localsForCapturedValues.size + val invocationStackHeight = stackHeight + capturesStackSize - 1 // -1 because the closure is gone + if (invocationStackHeight > methodNode.maxStack) + methodNode.maxStack = invocationStackHeight + + // replace the callsite with a new call to the body method + val bodyOpcode = (lambdaBodyHandle.getTag: @switch) match { + case H_INVOKEVIRTUAL => INVOKEVIRTUAL + case H_INVOKESTATIC => INVOKESTATIC + case H_INVOKESPECIAL => INVOKESPECIAL + case H_INVOKEINTERFACE => INVOKEINTERFACE + case H_NEWINVOKESPECIAL => + val insns = methodNode.instructions + insns.insertBefore(invocation, new TypeInsnNode(NEW, lambdaBodyHandle.getOwner)) + insns.insertBefore(invocation, new InsnNode(DUP)) + INVOKESPECIAL + } + val isInterface = bodyOpcode == INVOKEINTERFACE + val bodyInvocation = new MethodInsnNode(bodyOpcode, lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc, isInterface) + methodNode.instructions.insertBefore(invocation, bodyInvocation) + methodNode.instructions.remove(invocation) + + // update the call graph + val originalCallsite = callGraph.callsites.remove(invocation) 
+ + // the method node is needed for building the call graph entry + val bodyMethod = byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc) + def bodyMethodIsBeingCompiled = byteCodeRepository.classNodeAndSource(lambdaBodyHandle.getOwner).map(_._2 == CompilationUnit).getOrElse(false) + val bodyMethodCallsite = Callsite( + callsiteInstruction = bodyInvocation, + callsiteMethod = methodNode, + callsiteClass = ownerClass, + callee = bodyMethod.map({ + case (bodyMethodNode, bodyMethodDeclClass) => Callee( + callee = bodyMethodNode, + calleeDeclarationClass = classBTypeFromParsedClassfile(bodyMethodDeclClass), + safeToInline = compilerSettings.YoptInlineGlobal || bodyMethodIsBeingCompiled, + safeToRewrite = false, // the lambda body method is not a trait interface method + annotatedInline = false, + annotatedNoInline = false, + calleeInfoWarning = None) + }), + argInfos = Nil, + callsiteStackHeight = invocationStackHeight, + receiverKnownNotNull = true, // see below (*) + callsitePosition = originalCallsite.map(_.callsitePosition).getOrElse(NoPosition) + ) + // (*) The documentation in class LambdaMetafactory says: + // "if implMethod corresponds to an instance method, the first capture argument + // (corresponding to the receiver) must be non-null" + // Explanation: If the lambda body method is non-static, the receiver is a captured + // value. It can only be captured within some instance method, so we know it's non-null. + callGraph.callsites(bodyInvocation) = bodyMethodCallsite + None + } + + warnings.toList + } + } + + /** + * A list of local variables. Each local stores information about its type, see class [[Local]]. + */ + case class LocalsList(locals: List[Local]) { + val size = locals.iterator.map(_.size).sum + } + + object LocalsList { + /** + * A list of local variables starting at `firstLocal` that can hold values of the types in the + * `types` parameter. 
+ * + * For example, `fromTypes(3, Array(Int, Long, String))` returns + * Local(3, intOpOffset) :: + * Local(4, longOpOffset) :: // note that this local occupies two slots, the next is at 6 + * Local(6, refOpOffset) :: + * Nil + */ + def fromTypes(firstLocal: Int, types: Array[Type]): LocalsList = { + var sizeTwoOffset = 0 + val locals: List[Local] = types.indices.map(i => { + // The ASM method `type.getOpcode` returns the opcode for operating on a value of `type`. + val offset = types(i).getOpcode(ILOAD) - ILOAD + val local = Local(firstLocal + i + sizeTwoOffset, offset) + if (local.size == 2) sizeTwoOffset += 1 + local + })(collection.breakOut) + LocalsList(locals) + } + } + + /** + * Stores a local variable index and the opcode offset required for operating on that variable. + * + * The xLOAD / xSTORE opcodes are in the following sequence: I, L, F, D, A, so the offset for + * a local variable holding a reference (`A`) is 4. See also method `getOpcode` in [[scala.tools.asm.Type]]. + */ + case class Local(local: Int, opcodeOffset: Int) { + def size = if (loadOpcode == LLOAD || loadOpcode == DLOAD) 2 else 1 + + def loadOpcode = ILOAD + opcodeOffset + def storeOpcode = ISTORE + opcodeOffset + } +} diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala index b4f091b37fdf..e8e848161ca2 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala @@ -26,7 +26,8 @@ class Inliner[BT <: BTypes](val btypes: BT) { def eliminateUnreachableCodeAndUpdateCallGraph(methodNode: MethodNode, definingClass: InternalName): Unit = { localOpt.minimalRemoveUnreachableCode(methodNode, definingClass) foreach { - case invocation: MethodInsnNode => callGraph.callsites.remove(invocation) + case invocation: MethodInsnNode => callGraph.callsites.remove(invocation) + case indy: InvokeDynamicInsnNode => callGraph.closureInstantiations.remove(indy) 
case _ => } } @@ -432,7 +433,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { callsiteMethod.localVariables.addAll(cloneLocalVariableNodes(callee, labelsMap, callee.name + "_").asJava) callsiteMethod.tryCatchBlocks.addAll(cloneTryCatchBlockNodes(callee, labelsMap).asJava) - // Add all invocation instructions that were inlined to the call graph + // Add all invocation instructions and closure instantiations that were inlined to the call graph callee.instructions.iterator().asScala foreach { case originalCallsiteIns: MethodInsnNode => callGraph.callsites.get(originalCallsiteIns) match { @@ -452,6 +453,15 @@ class Inliner[BT <: BTypes](val btypes: BT) { case None => } + case indy: InvokeDynamicInsnNode => + callGraph.closureInstantiations.get(indy) match { + case Some((methodNode, ownerClass)) => + val newIndy = instructionMap(indy).asInstanceOf[InvokeDynamicInsnNode] + callGraph.closureInstantiations(newIndy) = (callsiteMethod, callsiteClass) + + case None => + } + case _ => } // Remove the elided invocation from the call graph @@ -529,98 +539,97 @@ class Inliner[BT <: BTypes](val btypes: BT) { } /** - * Returns the first instruction in the `instructions` list that would cause a - * [[java.lang.IllegalAccessError]] when inlined into the `destinationClass`. - * - * If validity of some instruction could not be checked because an error occurred, the instruction - * is returned together with a warning message that describes the problem. + * Check if a type is accessible to some class, as defined in JVMS 5.4.4. + * (A1) C is public + * (A2) C and D are members of the same run-time package */ - def findIllegalAccess(instructions: InsnList, calleeDeclarationClass: ClassBType, destinationClass: ClassBType): Option[(AbstractInsnNode, Option[OptimizerWarning])] = { - - /** - * Check if a type is accessible to some class, as defined in JVMS 5.4.4. 
- * (A1) C is public - * (A2) C and D are members of the same run-time package - */ - def classIsAccessible(accessed: BType, from: ClassBType = destinationClass): Either[OptimizerWarning, Boolean] = (accessed: @unchecked) match { - // TODO: A2 requires "same run-time package", which seems to be package + classloader (JMVS 5.3.). is the below ok? - case c: ClassBType => c.isPublic.map(_ || c.packageInternalName == from.packageInternalName) - case a: ArrayBType => classIsAccessible(a.elementType, from) - case _: PrimitiveBType => Right(true) - } + def classIsAccessible(accessed: BType, from: ClassBType): Either[OptimizerWarning, Boolean] = (accessed: @unchecked) match { + // TODO: A2 requires "same run-time package", which seems to be package + classloader (JMVS 5.3.). is the below ok? + case c: ClassBType => c.isPublic.map(_ || c.packageInternalName == from.packageInternalName) + case a: ArrayBType => classIsAccessible(a.elementType, from) + case _: PrimitiveBType => Right(true) + } - /** - * Check if a member reference is accessible from the [[destinationClass]], as defined in the - * JVMS 5.4.4. Note that the class name in a field / method reference is not necessarily the - * class in which the member is declared: - * - * class A { def f = 0 }; class B extends A { f } - * - * The INVOKEVIRTUAL instruction uses a method reference "B.f ()I". Therefore this method has - * two parameters: - * - * @param memberDeclClass The class in which the member is declared (A) - * @param memberRefClass The class used in the member reference (B) - * - * (B0) JVMS 5.4.3.2 / 5.4.3.3: when resolving a member of class C in D, the class C is resolved - * first. According to 5.4.3.1, this requires C to be accessible in D. - * - * JVMS 5.4.4 summary: A field or method R is accessible to a class D (destinationClass) iff - * (B1) R is public - * (B2) R is protected, declared in C (memberDeclClass) and D is a subclass of C. 
- * If R is not static, R must contain a symbolic reference to a class T (memberRefClass), - * such that T is either a subclass of D, a superclass of D, or D itself. - * Also (P) needs to be satisfied. - * (B3) R is either protected or has default access and declared by a class in the same - * run-time package as D. - * If R is protected, also (P) needs to be satisfied. - * (B4) R is private and is declared in D. - * - * (P) When accessing a protected instance member, the target object on the stack (the receiver) - * has to be a subtype of D (destinationClass). This is enforced by classfile verification - * (https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.10.1.8). - * - * TODO: we cannot currently implement (P) because we don't have the necessary information - * available. Once we have a type propagation analysis implemented, we can extract the receiver - * type from there (https://github.com/scala-opt/scala/issues/13). - */ - def memberIsAccessible(memberFlags: Int, memberDeclClass: ClassBType, memberRefClass: ClassBType): Either[OptimizerWarning, Boolean] = { - // TODO: B3 requires "same run-time package", which seems to be package + classloader (JMVS 5.3.). is the below ok? 
- def samePackageAsDestination = memberDeclClass.packageInternalName == destinationClass.packageInternalName - def targetObjectConformsToDestinationClass = false // needs type propagation analysis, see above - - def memberIsAccessibleImpl = { - val key = (ACC_PUBLIC | ACC_PROTECTED | ACC_PRIVATE) & memberFlags - key match { - case ACC_PUBLIC => // B1 - Right(true) - - case ACC_PROTECTED => // B2 - val isStatic = (ACC_STATIC & memberFlags) != 0 - tryEither { - val condB2 = destinationClass.isSubtypeOf(memberDeclClass).orThrow && { - isStatic || memberRefClass.isSubtypeOf(destinationClass).orThrow || destinationClass.isSubtypeOf(memberRefClass).orThrow - } - Right( - (condB2 || samePackageAsDestination /* B3 (protected) */) && - (isStatic || targetObjectConformsToDestinationClass) // (P) - ) + /** + * Check if a member reference is accessible from the [[destinationClass]], as defined in the + * JVMS 5.4.4. Note that the class name in a field / method reference is not necessarily the + * class in which the member is declared: + * + * class A { def f = 0 }; class B extends A { f } + * + * The INVOKEVIRTUAL instruction uses a method reference "B.f ()I". Therefore this method has + * two parameters: + * + * @param memberDeclClass The class in which the member is declared (A) + * @param memberRefClass The class used in the member reference (B) + * + * (B0) JVMS 5.4.3.2 / 5.4.3.3: when resolving a member of class C in D, the class C is resolved + * first. According to 5.4.3.1, this requires C to be accessible in D. + * + * JVMS 5.4.4 summary: A field or method R is accessible to a class D (destinationClass) iff + * (B1) R is public + * (B2) R is protected, declared in C (memberDeclClass) and D is a subclass of C. + * If R is not static, R must contain a symbolic reference to a class T (memberRefClass), + * such that T is either a subclass of D, a superclass of D, or D itself. + * Also (P) needs to be satisfied. 
+ * (B3) R is either protected or has default access and declared by a class in the same + * run-time package as D. + * If R is protected, also (P) needs to be satisfied. + * (B4) R is private and is declared in D. + * + * (P) When accessing a protected instance member, the target object on the stack (the receiver) + * has to be a subtype of D (destinationClass). This is enforced by classfile verification + * (https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.10.1.8). + * + * TODO: we cannot currently implement (P) because we don't have the necessary information + * available. Once we have a type propagation analysis implemented, we can extract the receiver + * type from there (https://github.com/scala-opt/scala/issues/13). + */ + def memberIsAccessible(memberFlags: Int, memberDeclClass: ClassBType, memberRefClass: ClassBType, from: ClassBType): Either[OptimizerWarning, Boolean] = { + // TODO: B3 requires "same run-time package", which seems to be package + classloader (JVMS 5.3.). is the below ok?
+ def samePackageAsDestination = memberDeclClass.packageInternalName == from.packageInternalName + def targetObjectConformsToDestinationClass = false // needs type propagation analysis, see above + + def memberIsAccessibleImpl = { + val key = (ACC_PUBLIC | ACC_PROTECTED | ACC_PRIVATE) & memberFlags + key match { + case ACC_PUBLIC => // B1 + Right(true) + + case ACC_PROTECTED => // B2 + val isStatic = (ACC_STATIC & memberFlags) != 0 + tryEither { + val condB2 = from.isSubtypeOf(memberDeclClass).orThrow && { + isStatic || memberRefClass.isSubtypeOf(from).orThrow || from.isSubtypeOf(memberRefClass).orThrow } + Right( + (condB2 || samePackageAsDestination /* B3 (protected) */) && + (isStatic || targetObjectConformsToDestinationClass) // (P) + ) + } - case 0 => // B3 (default access) - Right(samePackageAsDestination) + case 0 => // B3 (default access) + Right(samePackageAsDestination) - case ACC_PRIVATE => // B4 - Right(memberDeclClass == destinationClass) - } + case ACC_PRIVATE => // B4 + Right(memberDeclClass == from) } + } - classIsAccessible(memberDeclClass) match { // B0 - case Right(true) => memberIsAccessibleImpl - case r => r - } + classIsAccessible(memberDeclClass, from) match { // B0 + case Right(true) => memberIsAccessibleImpl + case r => r } + } + /** + * Returns the first instruction in the `instructions` list that would cause a + * [[java.lang.IllegalAccessError]] when inlined into the `destinationClass`. + * + * If validity of some instruction could not be checked because an error occurred, the instruction + * is returned together with a warning message that describes the problem. + */ + def findIllegalAccess(instructions: InsnList, calleeDeclarationClass: ClassBType, destinationClass: ClassBType): Option[(AbstractInsnNode, Option[OptimizerWarning])] = { /** * Check if `instruction` can be transplanted to `destinationClass`. * @@ -637,18 +646,18 @@ class Inliner[BT <: BTypes](val btypes: BT) { // NEW, ANEWARRAY, CHECKCAST or INSTANCEOF. 
For these instructions, the reference // "must be a symbolic reference to a class, array, or interface type" (JVMS 6), so // it can be an internal name, or a full array descriptor. - classIsAccessible(bTypeForDescriptorOrInternalNameFromClassfile(ti.desc)) + classIsAccessible(bTypeForDescriptorOrInternalNameFromClassfile(ti.desc), destinationClass) case ma: MultiANewArrayInsnNode => // "a symbolic reference to a class, array, or interface type" - classIsAccessible(bTypeForDescriptorOrInternalNameFromClassfile(ma.desc)) + classIsAccessible(bTypeForDescriptorOrInternalNameFromClassfile(ma.desc), destinationClass) case fi: FieldInsnNode => val fieldRefClass = classBTypeFromParsedClassfile(fi.owner) for { (fieldNode, fieldDeclClassNode) <- byteCodeRepository.fieldNode(fieldRefClass.internalName, fi.name, fi.desc): Either[OptimizerWarning, (FieldNode, InternalName)] fieldDeclClass = classBTypeFromParsedClassfile(fieldDeclClassNode) - res <- memberIsAccessible(fieldNode.access, fieldDeclClass, fieldRefClass) + res <- memberIsAccessible(fieldNode.access, fieldDeclClass, fieldRefClass, destinationClass) } yield { res } @@ -664,7 +673,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { Right(destinationClass == calleeDeclarationClass) case _ => // INVOKEVIRTUAL, INVOKESTATIC, INVOKEINTERFACE and INVOKESPECIAL of constructors - memberIsAccessible(methodFlags, methodDeclClass, methodRefClass) + memberIsAccessible(methodFlags, methodDeclClass, methodRefClass, destinationClass) } } @@ -683,7 +692,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { Right(false) case ci: LdcInsnNode => ci.cst match { - case t: asm.Type => classIsAccessible(bTypeForDescriptorOrInternalNameFromClassfile(t.getInternalName)) + case t: asm.Type => classIsAccessible(bTypeForDescriptorOrInternalNameFromClassfile(t.getInternalName), destinationClass) case _ => Right(true) } diff --git a/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala b/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala 
index d3cdf69d3095..0cdece59e155 100644 --- a/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala +++ b/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala @@ -235,6 +235,7 @@ trait ScalaSettings extends AbsScalaSettings val emptyLabels = Choice("empty-labels", "Eliminate and collapse redundant labels in the bytecode.") val compactLocals = Choice("compact-locals", "Eliminate empty slots in the sequence of local variables.") val nullnessTracking = Choice("nullness-tracking", "Track nullness / non-nullness of local variables and apply optimizations.") + val closureElimination = Choice("closure-elimination" , "Rewrite closure invocations to the implementation method and eliminate closures.") val inlineProject = Choice("inline-project", "Inline only methods defined in the files being compiled.") val inlineGlobal = Choice("inline-global", "Inline methods from any source, including classfiles on the compile classpath.") @@ -243,7 +244,7 @@ trait ScalaSettings extends AbsScalaSettings private val defaultChoices = List(unreachableCode) val lDefault = Choice("l:default", "Enable default optimizations: "+ defaultChoices.mkString(","), expandsTo = defaultChoices) - private val methodChoices = List(unreachableCode, simplifyJumps, emptyLineNumbers, emptyLabels, compactLocals, nullnessTracking) + private val methodChoices = List(unreachableCode, simplifyJumps, emptyLineNumbers, emptyLabels, compactLocals, nullnessTracking, closureElimination) val lMethod = Choice("l:method", "Enable intra-method optimizations: "+ methodChoices.mkString(","), expandsTo = methodChoices) private val projectChoices = List(lMethod, inlineProject) @@ -266,11 +267,15 @@ trait ScalaSettings extends AbsScalaSettings def YoptEmptyLabels = Yopt.contains(YoptChoices.emptyLabels) def YoptCompactLocals = Yopt.contains(YoptChoices.compactLocals) def YoptNullnessTracking = Yopt.contains(YoptChoices.nullnessTracking) + def YoptClosureElimination = Yopt.contains(YoptChoices.closureElimination) def 
YoptInlineProject = Yopt.contains(YoptChoices.inlineProject) def YoptInlineGlobal = Yopt.contains(YoptChoices.inlineGlobal) def YoptInlinerEnabled = YoptInlineProject || YoptInlineGlobal + def YoptBuildCallGraph = YoptInlinerEnabled || YoptClosureElimination + def YoptAddToBytecodeRepository = YoptInlinerEnabled || YoptClosureElimination + val YoptInlineHeuristics = ChoiceSetting( name = "-Yopt-inline-heuristics", helpArg = "strategy",