Allow blobs and file systems to pickle (#479)
ghidalgo3 authored Jul 8, 2024
1 parent 126ffc0 commit 3bd3d09
Showing 4 changed files with 71 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
 Unreleased
 ----------
 
+- `AzureBlobFileSystem` and `AzureBlobFile` support pickling.
 - Handle mixed casing for `hdi_isfolder` metadata when determining whether a blob should be treated as a folder.
 - `_put_file`: `overwrite` now defaults to `True`.
 
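In effect, the new entry means a filesystem instance can now round-trip through `pickle`. A minimal sketch, with placeholder credentials that are illustrative rather than from this commit:

```python
import pickle

from adlfs import AzureBlobFileSystem

# Placeholder configuration for illustration; any valid setup works the same way.
fs = AzureBlobFileSystem(account_name="devstoreaccount1", anon=True)

# Round-trip through pickle; the restored instance behaves like the
# original, as exercised by the new tests further down.
fs2 = pickle.loads(pickle.dumps(fs))
assert isinstance(fs2, AzureBlobFileSystem)
```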
52 changes: 36 additions & 16 deletions adlfs/spec.py
@@ -265,6 +265,7 @@ def __init__(
         account_host: str = None,
         **kwargs,
     ):
+        self.kwargs = kwargs.copy()
         super_kwargs = {
             k: kwargs.pop(k)
             for k in ["use_listings_cache", "listings_expiry_time", "max_paths"]
@@ -1923,22 +1924,8 @@ def __init__(
             else None
         )
 
-        try:
-            # Need to confirm there is an event loop running in
-            # the thread. If not, create the fsspec loop
-            # and set it. This is to handle issues with
-            # Async Credentials from the Azure SDK
-            loop = get_running_loop()
-
-        except RuntimeError:
-            loop = get_loop()
-            asyncio.set_event_loop(loop)
-
-        self.loop = self.fs.loop or get_loop()
-        self.container_client = (
-            fs.service_client.get_container_client(self.container_name)
-            or self.connect_client()
-        )
+        self.loop = self._get_loop()
+        self.container_client = self._get_container_client()
         self.blocksize = (
             self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size
         )
@@ -1977,6 +1964,8 @@ def __init__(
                 self.path, version_id=self.version_id, refresh=True
             )
             self.size = self.details["size"]
+            self.cache_type = cache_type
+            self.cache_options = cache_options
             self.cache = caches[cache_type](
                 blocksize=self.blocksize,
                 fetcher=self._fetch_range,
@@ -1998,6 +1987,26 @@ def __init__(
         self.forced = False
         self.location = None
 
+    def _get_loop(self):
+        try:
+            # Need to confirm there is an event loop running in
+            # the thread. If not, create the fsspec loop
+            # and set it. This is to handle issues with
+            # Async Credentials from the Azure SDK
+            loop = get_running_loop()
+
+        except RuntimeError:
+            loop = get_loop()
+            asyncio.set_event_loop(loop)
+
+        return self.fs.loop or get_loop()
+
+    def _get_container_client(self):
+        return (
+            self.fs.service_client.get_container_client(self.container_name)
+            or self.connect_client()
+        )
+
     def close(self):
         """Close file and azure client."""
         asyncio.run_coroutine_threadsafe(close_container_client(self), loop=self.loop)
@@ -2187,3 +2196,14 @@ def __del__(self):
             self.close()
         except TypeError:
             pass
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["container_client"]
+        del state["loop"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.loop = self._get_loop()
+        self.container_client = self._get_container_client()
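The `AzureBlobFile` additions follow the standard pattern for pickling objects that hold unpicklable resources: `__getstate__` drops the live handles (the event loop and the container client) and `__setstate__` rebuilds them from the state that survives. A self-contained sketch of the same pattern, using a hypothetical class rather than anything from the library:

```python
import os
import pickle


class ResourceHolder:
    """Hypothetical class showing the pattern AzureBlobFile uses."""

    def __init__(self, name: str):
        self.name = name
        self.handle = self._open_handle()

    def _open_handle(self):
        # Stand-in for an unpicklable resource (event loop, network client, ...).
        return open(os.devnull, "rb")

    def __getstate__(self):
        # Copy the instance dict and drop the unpicklable handle.
        state = self.__dict__.copy()
        del state["handle"]
        return state

    def __setstate__(self, state):
        # Restore the picklable attributes, then rebuild the handle.
        self.__dict__.update(state)
        self.handle = self._open_handle()


restored = pickle.loads(pickle.dumps(ResourceHolder("example")))
assert restored.name == "example" and restored.handle.readable()
```

Recording `cache_type` and `cache_options` as plain attributes likewise keeps the cache configuration in simple, picklable form.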
3 changes: 2 additions & 1 deletion adlfs/tests/conftest.py
@@ -37,7 +37,8 @@ def storage(host):
     conn_str = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};"  # NOQA
 
     bbs = BlobServiceClient.from_connection_string(conn_str=conn_str)
-    bbs.create_container("data")
+    if "data" not in [c["name"] for c in bbs.list_containers()]:
+        bbs.create_container("data")
     container_client = bbs.get_container_client(container="data")
     bbs.insert_time = datetime.datetime.utcnow().replace(
         microsecond=0, tzinfo=datetime.timezone.utc
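The `conftest.py` change makes container creation idempotent when the Azurite test account persists between runs. An equivalent alternative, sketched here but not what the commit uses, would catch the service error instead of listing first:

```python
from azure.core.exceptions import ResourceExistsError

try:
    bbs.create_container("data")
except ResourceExistsError:
    # The container is left over from an earlier run; reuse it as-is.
    pass
```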
32 changes: 32 additions & 0 deletions adlfs/tests/test_pickling.py
@@ -0,0 +1,32 @@
+import pickle
+
+from adlfs import AzureBlobFileSystem
+
+URL = "http://127.0.0.1:10000"
+ACCOUNT_NAME = "devstoreaccount1"
+KEY = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="  # NOQA
+CONN_STR = f"DefaultEndpointsProtocol=http;AccountName={ACCOUNT_NAME};AccountKey={KEY};BlobEndpoint={URL}/{ACCOUNT_NAME};"  # NOQA
+
+
+def test_fs_pickling(storage):
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name,
+        connection_string=CONN_STR,
+        kwarg1="some_value",
+    )
+    fs2: AzureBlobFileSystem = pickle.loads(pickle.dumps(fs))
+    assert "data" in fs.ls("")
+    assert "data" in fs2.ls("")
+    assert fs2.kwargs["kwarg1"] == "some_value"
+
+
+def test_blob_pickling(storage):
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name, connection_string=CONN_STR
+    )
+    fs2: AzureBlobFileSystem = pickle.loads(pickle.dumps(fs))
+    blob = fs2.open("data/root/a/file.txt")
+    assert blob.read() == b"0123456789"
+    blob2 = pickle.loads(pickle.dumps(blob))
+    blob2.seek(0)
+    assert blob2.read() == b"0123456789"
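One practical payoff of pickle support is that filesystem and file objects can be shipped to worker processes. A minimal sketch (the connection string is a placeholder, and the multiprocessing setup is an illustration, not something from the commit):

```python
from multiprocessing import get_context

from adlfs import AzureBlobFileSystem

# Placeholder; substitute a real connection string such as CONN_STR above.
CONN_STR = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;..."


def read_blob(args):
    # The pool pickles (fs, path) to hand it to the worker, which is
    # what this commit's pickle support makes possible.
    fs, path = args
    with fs.open(path) as f:
        return f.read()


if __name__ == "__main__":
    fs = AzureBlobFileSystem(
        account_name="devstoreaccount1", connection_string=CONN_STR
    )
    with get_context("spawn").Pool(2) as pool:
        print(pool.map(read_blob, [(fs, "data/root/a/file.txt")] * 2))
```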
