diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index 68d3e70e91e88..cad030e0d5bd7 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -18,7 +18,9 @@
 
 check_dependencies(__name__)
 
+import hashlib
 import importlib
+import io
 import sys
 import os
 import zlib
@@ -41,6 +43,7 @@
 ARCHIVE_PREFIX: str = "archives"
 FILE_PREFIX: str = "files"
 FORWARD_TO_FS_PREFIX: str = "forward_to_fs"
+CACHE_PREFIX: str = "cache"
 
 
 class LocalData(metaclass=abc.ABCMeta):
@@ -78,11 +81,26 @@ def stream(self) -> BinaryIO:
         return open(self.path, "rb")
 
 
-class Artifact:
+class InMemory(LocalData):
     """
     Payload stored in memory.
     """
 
+    def __init__(self, blob: bytes):
+        self.blob = blob
+        self._size: int
+        self._stream: BinaryIO
+
+    @cached_property
+    def size(self) -> int:
+        return len(self.blob)
+
+    @cached_property
+    def stream(self) -> BinaryIO:
+        return io.BytesIO(self.blob)
+
+
+class Artifact:
     def __init__(self, path: str, storage: LocalData):
         assert not Path(path).is_absolute(), f"Bad path: {path}"
         self.path = path
@@ -113,6 +131,10 @@ def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
     return _new_artifact(FILE_PREFIX, "", file_name, storage)
 
 
+def new_cache_artifact(id: str, storage: LocalData) -> Artifact:
+    return _new_artifact(CACHE_PREFIX, "", id, storage)
+
+
 def _new_artifact(
     prefix: str, required_suffix: str, file_name: str, storage: LocalData
 ) -> Artifact:
@@ -351,3 +373,31 @@ def _add_chunked_artifact(self, artifact: Artifact) -> Iterator[proto.AddArtifac
                             data=chunk, crc=zlib.crc32(chunk)
                         ),
                     )
+
+    def is_cached_artifact(self, hash: str) -> bool:
+        """
+        Ask the server whether an artifact with `hash` has already been cached on the server side.
+        """
+        artifact_name = CACHE_PREFIX + "/" + hash
+        request = proto.ArtifactStatusesRequest(
+            user_context=self._user_context, session_id=self._session_id, names=[artifact_name]
+        )
+        resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(request)
+        status = resp.statuses.get(artifact_name)
+        return status.exists if status is not None else False
+
+    def cache_artifact(self, blob: bytes) -> str:
+        """
+        Cache the given blob at the current session and return its SHA-256 hash.
+        """
+        hash = hashlib.sha256(blob).hexdigest()
+        if not self.is_cached_artifact(hash):
+            requests = self._add_artifacts([new_cache_artifact(hash, InMemory(blob))])
+            response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
+            summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
+
+            for summary in response.artifacts:
+                summaries.append(summary)
+                # TODO(SPARK-42658): Handle responses containing CRC failures.
+
+        return hash
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 9bc0822898f1e..ab285a2b862e7 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import hashlib
 import shutil
 import tempfile
 import unittest
@@ -313,6 +314,15 @@ def test_copy_from_local_to_fs(self):
             with open(dest_path, "r") as f:
                 self.assertEqual(f.read(), file_content)
 
+    def test_cache_artifact(self):
+        s = "Hello, World!"
+        blob = bytearray(s, "utf-8")
+        expected_hash = hashlib.sha256(blob).hexdigest()
+        self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), False)
+        actual_hash = self.artifact_manager.cache_artifact(blob)
+        self.assertEqual(actual_hash, expected_hash)
+        self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_artifact import *  # noqa: F401
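
Usage note (not part of the patch): the new client-side API uploads a binary blob to the server under the artifact name "cache/<sha256-of-blob>" and returns that hash, while is_cached_artifact consults the ArtifactStatus RPC so an already-cached payload is not re-uploaded. The sketch below shows how the API might be exercised; the sc://localhost:15002 address and the spark._client._artifact_manager attribute path are assumptions for illustration (the test suite above reaches the same object via self.artifact_manager), not something this patch defines.

    import hashlib

    from pyspark.sql import SparkSession

    # Assumed: a Spark Connect server reachable at this address.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # Assumed: the session's ArtifactManager is reachable via this internal
    # attribute path on the Spark Connect client.
    artifact_manager = spark._client._artifact_manager

    blob = b"some serialized payload"
    expected_hash = hashlib.sha256(blob).hexdigest()

    # First call uploads the blob under cache/<hash> and returns the hash.
    assert artifact_manager.cache_artifact(blob) == expected_hash

    # Later checks (and repeated cache_artifact calls) see the blob as cached.
    assert artifact_manager.is_cached_artifact(expected_hash)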