Merge branch 'master' into SPARK-50370
panbingkun committed Nov 22, 2024
2 parents bf9c287 + d5da49d commit 9a2026f
Showing 8 changed files with 45 additions and 222 deletions.
@@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer(
*/
private def doStart(startPort: Int): (Server, Int) = {
val threadPool = new QueuedThreadPool(masterConf.get(MASTER_REST_SERVER_MAX_THREADS))
threadPool.setName(getClass().getSimpleName())
if (Utils.isJavaVersionAtLeast21 && masterConf.get(MASTER_REST_SERVER_VIRTUAL_THREADS)) {
val newVirtualThreadPerTaskExecutor =
classOf[Executors].getMethod("newVirtualThreadPerTaskExecutor")
5 changes: 5 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -43,6 +43,7 @@
Dict,
Set,
NoReturn,
Mapping,
cast,
TYPE_CHECKING,
Type,
@@ -1576,6 +1577,10 @@ def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
configs = dict(self.config(op).pairs)
return tuple(configs.get(key) for key in keys)

def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
return dict(self.config(op).pairs)

def get_config_with_defaults(
self, *pairs: Tuple[str, Optional[str]]
) -> Tuple[Optional[str], ...]:
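
For context on the new helper above: `get_config_dict` issues the same single `ConfigRequest` RPC as the existing `get_configs`, but returns the fetched values keyed by config name instead of as a positional tuple, which is what lets callers batch one fetch and look values up by key afterwards. A minimal usage sketch, assuming a reachable Spark Connect server (the URL is illustrative) and going through the session's internal `_client` handle:

    from pyspark.sql import SparkSession

    # Remote URL is an assumption for illustration only.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # One round trip fetches both values; unset keys may come back as None.
    configs = spark._client.get_config_dict(
        "spark.sql.session.timeZone",
        "spark.sql.timestampType",
    )
    print(configs["spark.sql.session.timeZone"])  # e.g. "Etc/UTC"
    print(configs["spark.sql.timestampType"])     # e.g. "TIMESTAMP_LTZ"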
57 changes: 34 additions & 23 deletions python/pyspark/sql/connect/session.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies
from pyspark.sql.utils import is_timestamp_ntz_preferred

check_dependencies(__name__)

@@ -37,6 +36,7 @@
cast,
overload,
Iterable,
Mapping,
TYPE_CHECKING,
ClassVar,
)
@@ -407,7 +407,10 @@ def clearProgressHandlers(self) -> None:
clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__

def _inferSchemaFromList(
self, data: Iterable[Any], names: Optional[List[str]] = None
self,
data: Iterable[Any],
names: Optional[List[str]],
configs: Mapping[str, Optional[str]],
) -> StructType:
"""
Infer schema from list of Row, dict, or tuple.
@@ -422,12 +425,12 @@ def _inferSchemaFromList(
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
) = self._client.get_configs(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
"spark.sql.timestampType",
prefer_timestamp,
) = (
configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
configs["spark.sql.timestampType"],
)
return functools.reduce(
_merge_type,
@@ -438,7 +441,7 @@ def _inferSchemaFromList(
infer_dict_as_struct=(infer_dict_as_struct == "true"),
infer_array_from_first_element=(infer_array_from_first_element == "true"),
infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
)
for row in data
),
@@ -508,8 +511,21 @@ def createDataFrame(
messageParameters={},
)

# Get all related configs in a batch
configs = self._client.get_config_dict(
"spark.sql.timestampType",
"spark.sql.session.timeZone",
"spark.sql.session.localRelationCacheThreshold",
"spark.sql.execution.pandas.convertToArrowArraySafely",
"spark.sql.execution.pandas.inferPandasDictAsMap",
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
)
timezone = configs["spark.sql.session.timeZone"]
prefer_timestamp = configs["spark.sql.timestampType"]

_table: Optional[pa.Table] = None
timezone: Optional[str] = None

if isinstance(data, pd.DataFrame):
# Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -519,8 +535,7 @@
if schema is None:
_cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
infer_pandas_dict_as_map = (
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
== "true"
configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
)
if infer_pandas_dict_as_map:
struct = StructType()
@@ -572,9 +587,7 @@
]
arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]

timezone, safecheck = self._client.get_configs(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
)
safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]

ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")

@@ -596,10 +609,6 @@
).cast(arrow_schema)

elif isinstance(data, pa.Table):
prefer_timestamp_ntz = is_timestamp_ntz_preferred()

(timezone,) = self._client.get_configs("spark.sql.session.timeZone")

# If no schema supplied by user then get the names of columns only
if schema is None:
_cols = data.column_names
@@ -609,7 +618,9 @@
_num_cols = len(_cols)

if not isinstance(schema, StructType):
schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
schema = from_arrow_schema(
data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ"
)

_table = (
_check_arrow_table_timestamps_localize(data, schema, True, timezone)
@@ -671,7 +682,7 @@ def createDataFrame(
if not isinstance(_schema, StructType):
_schema = StructType().add("value", _schema)
else:
_schema = self._inferSchemaFromList(_data, _cols)
_schema = self._inferSchemaFromList(_data, _cols, configs)

if _cols is not None and cast(int, _num_cols) < len(_cols):
_num_cols = len(_cols)
@@ -706,9 +717,9 @@ def createDataFrame(
else:
local_relation = LocalRelation(_table)

cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
plan: LogicalPlan = local_relation
if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
plan = CachedLocalRelation(self._cache_local_relation(local_relation))

df = DataFrame(plan, self)
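
Taken together, the `createDataFrame` changes above replace several scattered `get_configs` round trips with one up-front `get_config_dict` call whose result is passed down to the code that needs it (including `_inferSchemaFromList`, which now takes the config mapping as a parameter instead of fetching its own). A rough sketch of the user-facing path that exercises this batched lookup, again assuming a reachable Spark Connect server (URL illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # Creating a DataFrame from plain Python rows triggers schema inference,
    # which now reads settings such as spark.sql.timestampType and the
    # inferNestedDictAsStruct/inferArrayTypeFromFirstElement flags from the
    # pre-fetched config mapping rather than issuing separate ConfigRequests.
    rows = [
        {"name": "a", "scores": [1, 2, 3]},
        {"name": "b", "scores": [4, 5]},
    ]
    df = spark.createDataFrame(rows)
    df.printSchema()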
@@ -47,7 +47,6 @@ private[sql] trait SqlApiConf {
def stackTracesInDataFrameContext: Int
def dataFrameQueryContextEnabled: Boolean
def legacyAllowUntypedScalaUDFs: Boolean
def allowReadingUnknownCollations: Boolean
}

private[sql] object SqlApiConf {
@@ -60,7 +59,6 @@ private[sql] object SqlApiConf {
SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY
}
val DEFAULT_COLLATION: String = SqlApiConfHelper.DEFAULT_COLLATION
val ALLOW_READING_UNKNOWN_COLLATIONS: String = SqlApiConfHelper.ALLOW_READING_UNKNOWN_COLLATIONS

def get: SqlApiConf = SqlApiConfHelper.getConfGetter.get()()

@@ -89,5 +87,4 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
override def stackTracesInDataFrameContext: Int = 1
override def dataFrameQueryContextEnabled: Boolean = true
override def legacyAllowUntypedScalaUDFs: Boolean = false
override def allowReadingUnknownCollations: Boolean = false
}
@@ -33,8 +33,6 @@ private[sql] object SqlApiConfHelper {
val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone"
val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = "spark.sql.session.localRelationCacheThreshold"
val DEFAULT_COLLATION: String = "spark.sql.session.collation.default"
val ALLOW_READING_UNKNOWN_COLLATIONS: String =
"spark.sql.collation.allowReadingUnknownCollations"

val confGetter: AtomicReference[() => SqlApiConf] = {
new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf)
26 changes: 4 additions & 22 deletions sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -27,7 +27,7 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis
import org.apache.spark.sql.catalyst.parser.DataTypeParser
@@ -340,17 +340,8 @@ object DataType {
fields.collect { case (fieldPath, JString(collation)) =>
collation.split("\\.", 2) match {
case Array(provider: String, collationName: String) =>
try {
CollationFactory.assertValidProvider(provider)
fieldPath -> collationName
} catch {
case e: SparkException
if e.getCondition == "COLLATION_INVALID_PROVIDER" &&
SqlApiConf.get.allowReadingUnknownCollations =>
// If the collation provider is unknown and the config for reading such
// collations is enabled, return the UTF8_BINARY collation.
fieldPath -> "UTF8_BINARY"
}
CollationFactory.assertValidProvider(provider)
fieldPath -> collationName
}
}.toMap

@@ -359,16 +350,7 @@
}

private def stringTypeWithCollation(collationName: String): StringType = {
try {
StringType(CollationFactory.collationNameToId(collationName))
} catch {
case e: SparkException
if e.getCondition == "COLLATION_INVALID_NAME" &&
SqlApiConf.get.allowReadingUnknownCollations =>
// If the collation name is unknown and the config for reading such collations is enabled,
// return the UTF8_BINARY collation.
StringType(CollationFactory.UTF8_BINARY_COLLATION_ID)
}
StringType(CollationFactory.collationNameToId(collationName))
}

protected[types] def buildFormattedString(
@@ -778,15 +778,6 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)

val ALLOW_READING_UNKNOWN_COLLATIONS =
buildConf(SqlApiConfHelper.ALLOW_READING_UNKNOWN_COLLATIONS)
.internal()
.doc("Enables spark to read unknown collation name as UTF8_BINARY. If the config is " +
"not enabled, when spark encounters an unknown collation name, it will throw an error.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val DEFAULT_COLLATION =
buildConf(SqlApiConfHelper.DEFAULT_COLLATION)
.doc("Sets default collation to use for string literals, parameter markers or the string" +
@@ -5582,8 +5573,6 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
}
}

override def allowReadingUnknownCollations: Boolean = getConf(ALLOW_READING_UNKNOWN_COLLATIONS)

def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)

def adaptiveExecutionLogLevel: String = getConf(ADAPTIVE_EXECUTION_LOG_LEVEL)
