From 958b85418034d3bdce56a7520c0728d666c79480 Mon Sep 17 00:00:00 2001
From: Max Gekk
Date: Thu, 8 Jun 2023 15:05:41 +0300
Subject: [PATCH] [SPARK-44006][CONNECT][PYTHON] Support cache artifacts

### What changes were proposed in this pull request?
In this PR, I propose to extend the Artifact API of the Python Connect client with two new methods, similarly to https://github.com/apache/spark/pull/40827:
1. `is_cached_artifact()` checks whether an artifact with the given hash is already cached at the server side.
2. `cache_artifact()` caches a blob in memory at the server side.

A short usage sketch based on the new test is included after the patch.

### Why are the changes needed?
To allow creating a DataFrame from a large local collection. Without the changes, `spark.createDataFrame(...)` fails with the following error:
```python
pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
    status = StatusCode.RESOURCE_EXHAUSTED
    details = "Sent message larger than max (629146388 vs. 134217728)"
    debug_error_string = "UNKNOWN:Error received from peer localhost:58218 {grpc_message:"Sent message larger than max (629146388 vs. 134217728)", grpc_status:8, created_time:"2023-06-05T18:35:50.912817+03:00"}"
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By running the new tests:
```
$ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.client.test_artifact ArtifactTests'
```

Closes #41465 from MaxGekk/streaming-createDataFrame-python-3.

Authored-by: Max Gekk
Signed-off-by: Max Gekk
---
 python/pyspark/sql/connect/client/artifact.py | 52 ++++++++++++++++++-
 .../sql/tests/connect/client/test_artifact.py | 10 ++++
 2 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index 68d3e70e91e88..cad030e0d5bd7 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -18,7 +18,9 @@
 
 check_dependencies(__name__)
 
+import hashlib
 import importlib
+import io
 import sys
 import os
 import zlib
@@ -41,6 +43,7 @@
 ARCHIVE_PREFIX: str = "archives"
 FILE_PREFIX: str = "files"
 FORWARD_TO_FS_PREFIX: str = "forward_to_fs"
+CACHE_PREFIX: str = "cache"
 
 
 class LocalData(metaclass=abc.ABCMeta):
@@ -78,11 +81,26 @@ def stream(self) -> BinaryIO:
         return open(self.path, "rb")
 
 
-class Artifact:
+class InMemory(LocalData):
     """
     Payload stored in memory.
     """
 
+    def __init__(self, blob: bytes):
+        self.blob = blob
+        self._size: int
+        self._stream: int
+
+    @cached_property
+    def size(self) -> int:
+        return len(self.blob)
+
+    @cached_property
+    def stream(self) -> BinaryIO:
+        return io.BytesIO(self.blob)
+
+
+class Artifact:
     def __init__(self, path: str, storage: LocalData):
         assert not Path(path).is_absolute(), f"Bad path: {path}"
         self.path = path
@@ -113,6 +131,10 @@ def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
     return _new_artifact(FILE_PREFIX, "", file_name, storage)
 
 
+def new_cache_artifact(id: str, storage: LocalData) -> Artifact:
+    return _new_artifact(CACHE_PREFIX, "", id, storage)
+
+
 def _new_artifact(
     prefix: str, required_suffix: str, file_name: str, storage: LocalData
 ) -> Artifact:
@@ -351,3 +373,31 @@ def _add_chunked_artifact(self, artifact: Artifact) -> Iterator[proto.AddArtifac
                     data=chunk, crc=zlib.crc32(chunk)
                 ),
             )
+
+    def is_cached_artifact(self, hash: str) -> bool:
+        """
+        Ask the server whether an artifact with the given `hash` is cached at the server side.
+ """ + artifactName = CACHE_PREFIX + "/" + hash + request = proto.ArtifactStatusesRequest( + user_context=self._user_context, session_id=self._session_id, names=[artifactName] + ) + resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(request) + status = resp.statuses.get(artifactName) + return status.exists if status is not None else False + + def cache_artifact(self, blob: bytes) -> str: + """ + Cache the give blob at the session. + """ + hash = hashlib.sha256(blob).hexdigest() + if not self.is_cached_artifact(hash): + requests = self._add_artifacts([new_cache_artifact(hash, InMemory(blob))]) + response: proto.AddArtifactsResponse = self._retrieve_responses(requests) + summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = [] + + for summary in response.artifacts: + summaries.append(summary) + # TODO(SPARK-42658): Handle responses containing CRC failures. + + return hash diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index 9bc0822898f1e..ab285a2b862e7 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import hashlib import shutil import tempfile import unittest @@ -313,6 +314,15 @@ def test_copy_from_local_to_fs(self): with open(dest_path, "r") as f: self.assertEqual(f.read(), file_content) + def test_cache_artifact(self): + s = "Hello, World!" + blob = bytearray(s, "utf-8") + expected_hash = hashlib.sha256(blob).hexdigest() + self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), False) + actualHash = self.artifact_manager.cache_artifact(blob) + self.assertEqual(actualHash, expected_hash) + self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True) + if __name__ == "__main__": from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401