[SPARK-44006][CONNECT][PYTHON] Support cache artifacts
### What changes were proposed in this pull request?
In the PR, I propose to extend the Artifact API of the Python connect client with two new methods, similarly to apache#40827 (a brief usage sketch follows the list):
1. `is_cached_artifact()` checks whether an artifact with the given hash has been cached at the server side.
2. `cache_artifact()` caches a blob in memory at the server side.
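
A minimal usage sketch of the two methods, assuming `artifact_manager` is an instance of the client's `ArtifactManager` (as in the tests below):

```python
import hashlib

blob = b"some large serialized payload"
expected_hash = hashlib.sha256(blob).hexdigest()

# cache_artifact() uploads the blob only if it is not cached yet and returns its SHA-256 hex digest.
returned_hash = artifact_manager.cache_artifact(blob)
assert returned_hash == expected_hash

# is_cached_artifact() asks the server whether an artifact with this hash is present in its cache.
assert artifact_manager.is_cached_artifact(expected_hash)
```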

### Why are the changes needed?
To allow creating a dataframe from a large local collection. Without the changes, `spark.createDataFrame(...)` fails with the following error:
```python
pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.RESOURCE_EXHAUSTED
	details = "Sent message larger than max (629146388 vs. 134217728)"
	debug_error_string = "UNKNOWN:Error received from peer localhost:58218 {grpc_message:"Sent message larger than max (629146388 vs. 134217728)", grpc_status:8, created_time:"2023-06-05T18:35:50.912817+03:00"}"
```
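
For illustration, a hypothetical reproduction (the `spark` session and the data shape are assumptions; any local collection whose serialized size exceeds the default 128 MiB gRPC message limit hits the same error):

```python
# Roughly 600 MB of local data, comparable to the 629146388-byte message in the error above.
rows = [(i, "x" * 600) for i in range(1_000_000)]
df = spark.createDataFrame(rows, ["id", "payload"])  # fails with RESOURCE_EXHAUSTED without this change
```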

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By running new tests:
```
$ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.client.test_artifact ArtifactTests'
```

Closes apache#41465 from MaxGekk/streaming-createDataFrame-python-3.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
MaxGekk committed Jun 8, 2023
1 parent 3cae38b commit 958b854
Showing 2 changed files with 61 additions and 1 deletion.
52 changes: 51 additions & 1 deletion python/pyspark/sql/connect/client/artifact.py
@@ -18,7 +18,9 @@

check_dependencies(__name__)

import hashlib
import importlib
import io
import sys
import os
import zlib
@@ -41,6 +43,7 @@
ARCHIVE_PREFIX: str = "archives"
FILE_PREFIX: str = "files"
FORWARD_TO_FS_PREFIX: str = "forward_to_fs"
CACHE_PREFIX: str = "cache"


class LocalData(metaclass=abc.ABCMeta):
@@ -78,11 +81,26 @@ def stream(self) -> BinaryIO:
        return open(self.path, "rb")


class Artifact:
class InMemory(LocalData):
    """
    Payload stored in memory.
    """

    def __init__(self, blob: bytes):
        self.blob = blob
        self._size: int
        self._stream: BinaryIO

    @cached_property
    def size(self) -> int:
        return len(self.blob)

    @cached_property
    def stream(self) -> BinaryIO:
        return io.BytesIO(self.blob)


class Artifact:
    def __init__(self, path: str, storage: LocalData):
        assert not Path(path).is_absolute(), f"Bad path: {path}"
        self.path = path
@@ -113,6 +131,10 @@ def new_file_artifact(file_name: str, storage: LocalData) -> Artifact:
    return _new_artifact(FILE_PREFIX, "", file_name, storage)


def new_cache_artifact(id: str, storage: LocalData) -> Artifact:
    return _new_artifact(CACHE_PREFIX, "", id, storage)


def _new_artifact(
    prefix: str, required_suffix: str, file_name: str, storage: LocalData
) -> Artifact:
@@ -351,3 +373,31 @@ def _add_chunked_artifact(self, artifact: Artifact) -> Iterator[proto.AddArtifactsRequest]:
data=chunk, crc=zlib.crc32(chunk)
),
)

    def is_cached_artifact(self, hash: str) -> bool:
        """
        Ask the server whether an artifact with the given `hash` has been cached at the server side.
        """
        artifact_name = CACHE_PREFIX + "/" + hash
        request = proto.ArtifactStatusesRequest(
            user_context=self._user_context, session_id=self._session_id, names=[artifact_name]
        )
        resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(request)
        status = resp.statuses.get(artifact_name)
        return status.exists if status is not None else False

    def cache_artifact(self, blob: bytes) -> str:
        """
        Cache the given blob at the server side within the current session.
        """
        hash = hashlib.sha256(blob).hexdigest()
        if not self.is_cached_artifact(hash):
            requests = self._add_artifacts([new_cache_artifact(hash, InMemory(blob))])
            response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
            summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []

            for summary in response.artifacts:
                summaries.append(summary)
                # TODO(SPARK-42658): Handle responses containing CRC failures.

        return hash
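
For context, a cached blob is addressed on the server by the `cache` prefix joined with the blob's SHA-256 hex digest; a small sketch of the name that `is_cached_artifact()` builds and queries via `ArtifactStatusesRequest`:

```python
import hashlib

blob = bytearray("Hello, World!", "utf-8")
# The artifact name queried for this blob has the form "cache/<64-char sha256 hexdigest>".
artifact_name = "cache/" + hashlib.sha256(blob).hexdigest()
```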
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import shutil
import tempfile
import unittest
@@ -313,6 +314,15 @@ def test_copy_from_local_to_fs(self):
        with open(dest_path, "r") as f:
            self.assertEqual(f.read(), file_content)

    def test_cache_artifact(self):
        s = "Hello, World!"
        blob = bytearray(s, "utf-8")
        expected_hash = hashlib.sha256(blob).hexdigest()
        self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), False)
        actual_hash = self.artifact_manager.cache_artifact(blob)
        self.assertEqual(actual_hash, expected_hash)
        self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True)


if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401
