
fix
pan3793 committed Feb 24, 2025
1 parent 67782cb commit a028cc2
Showing 3 changed files with 44 additions and 9 deletions.

SparkPlanHelper.scala
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.SparkSession

import org.apache.kyuubi.util.reflect.DynMethods

object SparkPlanHelper {

  private val sparkSessionMethod = DynMethods.builder("session")
    .impl(classOf[SparkPlan])
    .buildChecked()

  def sparkSession(sparkPlan: SparkPlan): SparkSession = {
    sparkSessionMethod.invokeChecked[SparkSession](sparkPlan)
  }
}
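
For context, a brief usage sketch of the new helper (the object and method names below, ArrowConfSketch and arrowConfOf, are hypothetical and not part of this commit): callers resolve the SparkSession through SparkPlanHelper.sparkSession(plan) instead of reading plan.session directly, which is the same substitution the two modified files below make.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.execution.{SparkPlan, SparkPlanHelper}

  object ArrowConfSketch {
    // Hypothetical helper: read the Arrow-related session confs of a physical plan
    // by resolving its SparkSession reflectively via SparkPlanHelper.
    def arrowConfOf(plan: SparkPlan): (String, Int) = {
      val spark: SparkSession = SparkPlanHelper.sparkSession(plan)
      val conf = spark.sessionState.conf
      (conf.sessionLocalTimeZone, conf.arrowMaxRecordsPerBatch)
    }
  }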

KyuubiArrowConverters.scala
@@ -32,7 +32,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.CollectLimitExec
+import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlanHelper}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
@@ -157,7 +157,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
val partsToScan =
partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts))

-val sc = collectLimitExec.session.sparkContext
+val sc = SparkPlanHelper.sparkSession(collectLimitExec).sparkContext
val res = sc.runJob(
childRDD,
(it: Iterator[InternalRow]) => {

SparkDatasetHelper.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, HiveResult, LocalTableScanExec, QueryExecution, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, HiveResult, LocalTableScanExec, QueryExecution, SparkPlan, SparkPlanHelper, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -83,8 +83,9 @@ object SparkDatasetHelper extends Logging {
*/
def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
val schemaCaptured = plan.schema
-val maxRecordsPerBatch = plan.session.sessionState.conf.arrowMaxRecordsPerBatch
-val timeZoneId = plan.session.sessionState.conf.sessionLocalTimeZone
+val spark = SparkPlanHelper.sparkSession(plan)
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
// note that, we can't pass the lazy variable `maxBatchSize` directly, this is because input
// arguments are serialized and sent to the executor side for execution.
val maxBatchSizePerBatch = maxBatchSize
@@ -169,8 +170,9 @@ }
}

private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = {
-val timeZoneId = collectLimit.session.sessionState.conf.sessionLocalTimeZone
-val maxRecordsPerBatch = collectLimit.session.sessionState.conf.arrowMaxRecordsPerBatch
+val spark = SparkPlanHelper.sparkSession(collectLimit)
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch

val batches = KyuubiArrowConverters.takeAsArrowBatches(
collectLimit,
@@ -199,7 +201,7 @@ }
}

private def doCommandResultExec(commandResult: CommandResultExec): Array[Array[Byte]] = {
-val spark = commandResult.session
+val spark = SparkPlanHelper.sparkSession(commandResult)
commandResult.longMetric("numOutputRows").add(commandResult.rows.size)
sendDriverMetrics(spark.sparkContext, commandResult.metrics)
KyuubiArrowConverters.toBatchIterator(
@@ -212,7 +214,7 @@ }
}

private def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
-val spark = localTableScan.session
+val spark = SparkPlanHelper.sparkSession(localTableScan)
localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
sendDriverMetrics(spark.sparkContext, localTableScan.metrics)
KyuubiArrowConverters.toBatchIterator(
