diff --git a/obstore/python/obstore/fsspec.py b/obstore/python/obstore/fsspec.py
index 09d7978..3af7afa 100644
--- a/obstore/python/obstore/fsspec.py
+++ b/obstore/python/obstore/fsspec.py
@@ -24,12 +24,35 @@
 import asyncio
 from collections import defaultdict
-from typing import Any, Coroutine, Dict, List, Tuple
+from functools import lru_cache
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Coroutine,
+    Dict,
+    List,
+    Tuple,
+)
+from urllib.parse import urlparse
 
 import fsspec.asyn
 import fsspec.spec
 
 import obstore as obs
+from obstore import Bytes
+from obstore.store import from_url
+
+if TYPE_CHECKING:
+    from obstore.store import (
+        AzureConfig,
+        AzureConfigInput,
+        ClientConfig,
+        GCSConfig,
+        GCSConfigInput,
+        RetryConfig,
+        S3Config,
+        S3ConfigInput,
+    )
 
 
 class AsyncFsspecStore(fsspec.asyn.AsyncFileSystem):
@@ -40,11 +63,32 @@ class AsyncFsspecStore(fsspec.asyn.AsyncFileSystem):
     """
 
     cachable = False
+    config: (
+        S3Config
+        | S3ConfigInput
+        | GCSConfig
+        | GCSConfigInput
+        | AzureConfig
+        | AzureConfigInput
+        | None
+    )
+    client_options: ClientConfig | None
+    retry_config: RetryConfig | None
 
     def __init__(
         self,
-        store: obs.store.ObjectStore,
         *args,
+        config: (
+            S3Config
+            | S3ConfigInput
+            | GCSConfig
+            | GCSConfigInput
+            | AzureConfig
+            | AzureConfigInput
+            | None
+        ) = None,
+        client_options: ClientConfig | None = None,
+        retry_config: RetryConfig | None = None,
         asynchronous: bool = False,
         loop: Any = None,
         batch_size: int | None = None,
@@ -52,7 +96,11 @@ def __init__(
        """Construct a new AsyncFsspecStore
 
         Args:
-            store: a configured instance of one of the store classes in `obstore.store`.
+            config: Configuration for the cloud storage provider, which can be one of
+                S3Config, S3ConfigInput, GCSConfig, GCSConfigInput, AzureConfig,
+                or AzureConfigInput. If None, no cloud storage configuration is applied.
+            client_options: Additional options for configuring the client.
+            retry_config: Configuration for handling request errors.
             asynchronous: Set to `True` if this instance is meant to be be called using
                 the fsspec async API. This should only be set to true when running
                 within a coroutine.
@@ -75,26 +123,88 @@ def __init__(
         ```
         """
-        self.store = store
+        self.config = config
+        self.client_options = client_options
+        self.retry_config = retry_config
+
         super().__init__(
             *args, asynchronous=asynchronous, loop=loop, batch_size=batch_size
         )
 
+    def _split_path(self, path: str) -> Tuple[str, str]:
+        """
+        Split a path into its bucket and file-path components.
+
+        Args:
+            path (str): Input path, like `s3://mybucket/path/to/file`
+
+        Examples:
+            >>> self._split_path("s3://mybucket/path/to/file")
+            ('mybucket', 'path/to/file')
+        """
+
+        protocol_with_bucket = ["s3", "s3a", "gcs", "gs", "abfs", "https", "http"]
+
+        if self.protocol not in protocol_with_bucket:
+            # no bucket name in path
+            return "", path
+
+        res = urlparse(path)
+        if res.scheme:
+            if res.scheme != self.protocol:
+                raise ValueError(
+                    f"Expected protocol to be {self.protocol}. Got {res.scheme}"
+                )
+            path = res.netloc + res.path
+
+        if "/" not in path:
+            return path, ""
+        else:
+            path_li = path.split("/")
+            bucket = path_li[0]
+            file_path = "/".join(path_li[1:])
+            return (bucket, file_path)
+
+    @lru_cache(maxsize=10)
+    def _construct_store(self, bucket: str):
+        return from_url(
+            url=f"{self.protocol}://{bucket}",
+            config=self.config,
+            client_options=self.client_options,
+            retry_config=self.retry_config if self.retry_config else None,
+        )
+
     async def _rm_file(self, path, **kwargs):
-        return await obs.delete_async(self.store, path)
+        bucket, path = self._split_path(path)
+        store = self._construct_store(bucket)
+        return await obs.delete_async(store, path)
 
     async def _cp_file(self, path1, path2, **kwargs):
-        return await obs.copy_async(self.store, path1, path2)
+        bucket1, path1 = self._split_path(path1)
+        bucket2, path2 = self._split_path(path2)
+
+        if bucket1 != bucket2:
+            raise ValueError(
+                f"Bucket mismatch: Source bucket '{bucket1}' and destination bucket '{bucket2}' must be the same."
+            )
 
-    async def _pipe_file(self, path, value, **kwargs):
-        return await obs.put_async(self.store, path, value)
+        store = self._construct_store(bucket1)
+        return await obs.copy_async(store, path1, path2)
+
+    async def _pipe_file(self, path, value, mode="overwrite", **kwargs):
+        bucket, path = self._split_path(path)
+        store = self._construct_store(bucket)
+        return await obs.put_async(store, path, value)
 
     async def _cat_file(self, path, start=None, end=None, **kwargs):
+        bucket, path = self._split_path(path)
+        store = self._construct_store(bucket)
+
         if start is None and end is None:
-            resp = await obs.get_async(self.store, path)
-            return await resp.bytes_async()
+            resp = await obs.get_async(store, path)
+            return (await resp.bytes_async()).to_bytes()
 
-        range_bytes = await obs.get_range_async(self.store, path, start=start, end=end)
+        range_bytes = await obs.get_range_async(store, path, start=start, end=end)
         return range_bytes.to_bytes()
 
     async def _cat_ranges(
@@ -118,11 +228,14 @@ async def _cat_ranges(
         for idx, (path, start, end) in enumerate(zip(paths, starts, ends)):
             per_file_requests[path].append((start, end, idx))
 
-        futs: List[Coroutine[Any, Any, List[bytes]]] = []
+        futs: List[Coroutine[Any, Any, List[Bytes]]] = []
         for path, ranges in per_file_requests.items():
+            bucket, path = self._split_path(path)
+            store = self._construct_store(bucket)
+
             offsets = [r[0] for r in ranges]
             ends = [r[1] for r in ranges]
-            fut = obs.get_ranges_async(self.store, path, starts=offsets, ends=ends)
+            fut = obs.get_ranges_async(store, path, starts=offsets, ends=ends)
             futs.append(fut)
 
         result = await asyncio.gather(*futs)
@@ -137,17 +250,40 @@
         return output_buffers
 
     async def _put_file(self, lpath, rpath, **kwargs):
+        lbucket, lpath = self._split_path(lpath)
+        rbucket, rpath = self._split_path(rpath)
+
+        if lbucket != rbucket:
+            raise ValueError(
+                f"Bucket mismatch: Source bucket '{lbucket}' and destination bucket '{rbucket}' must be the same."
+            )
+
+        store = self._construct_store(lbucket)
+
         with open(lpath, "rb") as f:
-            await obs.put_async(self.store, rpath, f)
+            await obs.put_async(store, rpath, f)
 
     async def _get_file(self, rpath, lpath, **kwargs):
+        lbucket, lpath = self._split_path(lpath)
+        rbucket, rpath = self._split_path(rpath)
+
+        if lbucket != rbucket:
+            raise ValueError(
+                f"Bucket mismatch: Source bucket '{lbucket}' and destination bucket '{rbucket}' must be the same."
+            )
+
+        store = self._construct_store(lbucket)
+
         with open(lpath, "wb") as f:
-            resp = await obs.get_async(self.store, rpath)
+            resp = await obs.get_async(store, rpath)
             async for buffer in resp.stream():
                 f.write(buffer)
 
     async def _info(self, path, **kwargs):
-        head = await obs.head_async(self.store, path)
+        bucket, path = self._split_path(path)
+        store = self._construct_store(bucket)
+
+        head = await obs.head_async(store, path)
         return {
             # Required of `info`: (?)
             "name": head["path"],
@@ -159,25 +295,50 @@
             "version": head["version"],
         }
 
+    def _fill_bucket_name(self, path, bucket):
+        return f"{bucket}/{path}"
+
     async def _ls(self, path, detail=True, **kwargs):
-        result = await obs.list_with_delimiter_async(self.store, path)
+        bucket, path = self._split_path(path)
+        store = self._construct_store(bucket)
+
+        result = await obs.list_with_delimiter_async(store, path)
         objects = result["objects"]
         prefs = result["common_prefixes"]
         if detail:
             return [
                 {
-                    "name": object["path"],
+                    "name": self._fill_bucket_name(object["path"], bucket),
                     "size": object["size"],
                     "type": "file",
                     "e_tag": object["e_tag"],
                 }
                 for object in objects
-            ] + [{"name": object, "size": 0, "type": "directory"} for object in prefs]
+            ] + [
+                {
+                    "name": self._fill_bucket_name(pref, bucket),
+                    "size": 0,
+                    "type": "directory",
+                }
+                for pref in prefs
+            ]
         else:
-            return sorted([object["path"] for object in objects] + prefs)
+            return sorted(
+                [self._fill_bucket_name(object["path"], bucket) for object in objects]
+                + [self._fill_bucket_name(pref, bucket) for pref in prefs]
+            )
 
-    def _open(self, path, mode="rb", **kwargs):
+    def _open(
+        self,
+        path,
+        mode="rb",
+        block_size=None,
+        autocommit=True,
+        cache_options=None,
+        **kwargs,
+    ):
         """Return raw bytes-mode file-like from the file-system"""
+
         return BufferedFileSimple(self, path, mode, **kwargs)
@@ -201,3 +362,65 @@ def read(self, length: int = -1):
         data = self.fs.cat_file(self.path, self.loc, self.loc + length)
         self.loc += length
         return data
+
+
+def register(protocol: str | list[str], asynchronous: bool = False):
+    """
+    Dynamically register a subclass of AsyncFsspecStore for the given protocol(s).
+
+    This function creates a new subclass of AsyncFsspecStore with the specified
+    protocol and registers it with fsspec. If multiple protocols are provided,
+    the function registers each one individually.
+
+    Args:
+        protocol (str | list[str]): A single protocol (e.g., "s3", "gcs", "abfs") or
+            a list of protocols to register AsyncFsspecStore for.
+        asynchronous (bool, optional): If True, the registered store will support
+            asynchronous operations. Defaults to False.
+
+    Example:
+        >>> register("s3")
+        >>> register("s3", asynchronous=True)  # Registers an async store for "s3"
+        >>> register(["gcs", "abfs"])  # Registers both "gcs" and "abfs"
+
+    Notes:
+        - Each protocol gets a dynamically generated subclass named `AsyncFsspecStore_{protocol}`.
+        - This avoids modifying the original AsyncFsspecStore class.
+    """
+
+    # Ensure protocol is of type str or list
+    if not isinstance(protocol, (str, list)):
+        raise TypeError(
+            f"Protocol must be a string or a list of strings, got {type(protocol)}"
+        )
+
+    # Ensure protocol is not None or empty
+    if not protocol:
+        raise ValueError(
+            "Protocol must be a non-empty string or a list of non-empty strings."
+        )
+
+    if isinstance(protocol, list):
+        # Ensure all elements are strings
+        if not all(isinstance(p, str) for p in protocol):
+            raise TypeError("All protocols in the list must be strings.")
+        # Ensure no empty strings in the list
+        if not all(p for p in protocol):
+            raise ValueError("Protocol names in the list must be non-empty strings.")
+
+        for p in protocol:
+            register(p, asynchronous)
+        return
+
+    fsspec.register_implementation(
+        protocol,
+        type(
+            f"AsyncFsspecStore_{protocol}",  # Unique class name
+            (AsyncFsspecStore,),  # Base class
+            {
+                "protocol": protocol,
+                "asynchronous": asynchronous,
+            },  # Assign protocol dynamically
+        ),
+        clobber=True,
+    )
diff --git a/tests/conftest.py b/tests/conftest.py
index 9739e93..897babe 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -51,3 +51,13 @@ def s3_store(s3):
             "AWS_ALLOW_HTTP": "true",
         },
     )
+
+
+@pytest.fixture()
+def s3_store_config(s3):
+    return {
+        "AWS_ENDPOINT_URL": s3,
+        "AWS_REGION": "us-east-1",
+        "AWS_SKIP_SIGNATURE": "True",
+        "AWS_ALLOW_HTTP": "true",
+    }
diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py
index ce9a1bf..cff6b2a 100644
--- a/tests/test_fsspec.py
+++ b/tests/test_fsspec.py
@@ -1,100 +1,174 @@
 import os
 
+import fsspec
 import pyarrow.parquet as pq
 import pytest
 
-import obstore as obs
-from obstore.fsspec import AsyncFsspecStore
+from obstore.fsspec import AsyncFsspecStore, register
+from tests.conftest import TEST_BUCKET_NAME
+
+
+def test_register():
+    """Test that register properly creates and registers a subclass for a given protocol."""
+    register("s3")  # Register the "s3" protocol dynamically
+    fs_class = fsspec.get_filesystem_class("s3")
+
+    assert issubclass(
+        fs_class, AsyncFsspecStore
+    ), "Registered class should be a subclass of AsyncFsspecStore"
+    assert (
+        fs_class.protocol == "s3"
+    ), "Registered class should have the correct protocol"
+
+    # Ensure a new instance of the registered store can be created
+    fs_instance = fs_class()
+    assert isinstance(
+        fs_instance, AsyncFsspecStore
+    ), "Registered class should be instantiable"
+
+    # test register asynchronous
+    register("gcs", asynchronous=True)  # Register the "gcs" protocol as asynchronous
+    fs_class = fsspec.get_filesystem_class("gcs")
+    assert fs_class.asynchronous is True, "Registered class should be asynchronous"
+
+    # test multiple registrations
+    register(["file", "abfs"])
+    assert issubclass(fsspec.get_filesystem_class("file"), AsyncFsspecStore)
+    assert issubclass(fsspec.get_filesystem_class("abfs"), AsyncFsspecStore)
+
+
+def test_register_invalid_types():
+    """Test that register rejects invalid input types."""
+    with pytest.raises(TypeError):
+        register(123)  # Not a string or list
+
+    with pytest.raises(TypeError):
+        register(["s3", 42])  # List contains a non-string
+
+    with pytest.raises(ValueError):
+        register(["s3", ""])  # List contains an empty string
+
+    with pytest.raises(TypeError):
+        register(None)  # None is invalid
+
+    with pytest.raises(ValueError):
+        register([])  # Empty list is invalid
 
 
 @pytest.fixture()
-def fs(s3_store):
-    return AsyncFsspecStore(s3_store)
+def fs(s3_store_config):
+    register("s3")
+    return fsspec.filesystem("s3", config=s3_store_config)
 
 
 def test_list(fs):
-    out = fs.ls("", detail=False)
-    assert out == ["afile"]
-    fs.pipe_file("dir/bfile", b"data")
-    out = fs.ls("", detail=False)
-    assert out == ["afile", "dir"]
-    out = fs.ls("", detail=True)
+    out = fs.ls(f"{TEST_BUCKET_NAME}", detail=False)
+    assert out == [f"{TEST_BUCKET_NAME}/afile"]
+    fs.pipe_file(f"{TEST_BUCKET_NAME}/dir/bfile", b"data")
+    out = fs.ls(f"{TEST_BUCKET_NAME}", detail=False)
fs.ls(f"{TEST_BUCKET_NAME}", detail=False) + assert out == [f"{TEST_BUCKET_NAME}/afile", f"{TEST_BUCKET_NAME}/dir"] + out = fs.ls(f"{TEST_BUCKET_NAME}", detail=True) assert out[0]["type"] == "file" assert out[1]["type"] == "directory" @pytest.mark.asyncio -async def test_list_async(s3_store): - fs = AsyncFsspecStore(s3_store, asynchronous=True) - out = await fs._ls("", detail=False) - assert out == ["afile"] - await fs._pipe_file("dir/bfile", b"data") - out = await fs._ls("", detail=False) - assert out == ["afile", "dir"] - out = await fs._ls("", detail=True) +async def test_list_async(s3_store_config): + register("s3") + fs = fsspec.filesystem("s3", config=s3_store_config, asynchronous=True) + + out = await fs._ls(f"{TEST_BUCKET_NAME}", detail=False) + assert out == [f"{TEST_BUCKET_NAME}/afile"] + await fs._pipe_file(f"{TEST_BUCKET_NAME}/dir/bfile", b"data") + out = await fs._ls(f"{TEST_BUCKET_NAME}", detail=False) + assert out == [f"{TEST_BUCKET_NAME}/afile", f"{TEST_BUCKET_NAME}/dir"] + out = await fs._ls(f"{TEST_BUCKET_NAME}", detail=True) assert out[0]["type"] == "file" assert out[1]["type"] == "directory" @pytest.mark.network def test_remote_parquet(): - store = obs.store.HTTPStore.from_url("https://github.com") - fs = AsyncFsspecStore(store) - url = "opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet" + register("https") + fs = fsspec.filesystem("https") + url = "github.com/opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet" + pq.read_metadata(url, filesystem=fs) + + # also test with full url + url = "https://github.com/opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet" pq.read_metadata(url, filesystem=fs) def test_multi_file_ops(fs): - data = {"dir/test1": b"test data1", "dir/test2": b"test data2"} + data = { + f"{TEST_BUCKET_NAME}/dir/test1": b"test data1", + f"{TEST_BUCKET_NAME}/dir/test2": b"test data2", + } fs.pipe(data) out = fs.cat(list(data)) assert out == data - out = fs.cat("dir", recursive=True) + out = fs.cat(f"{TEST_BUCKET_NAME}/dir", recursive=True) assert out == data - fs.cp("dir", "dir2", recursive=True) - out = fs.find("", detail=False) - assert out == ["afile", "dir/test1", "dir/test2", "dir2/test1", "dir2/test2"] - fs.rm(["dir", "dir2"], recursive=True) - out = fs.find("", detail=False) - assert out == ["afile"] + fs.cp(f"{TEST_BUCKET_NAME}/dir", f"{TEST_BUCKET_NAME}/dir2", recursive=True) + out = fs.find(f"{TEST_BUCKET_NAME}", detail=False) + assert out == [ + f"{TEST_BUCKET_NAME}/afile", + f"{TEST_BUCKET_NAME}/dir/test1", + f"{TEST_BUCKET_NAME}/dir/test2", + f"{TEST_BUCKET_NAME}/dir2/test1", + f"{TEST_BUCKET_NAME}/dir2/test2", + ] + fs.rm([f"{TEST_BUCKET_NAME}/dir", f"{TEST_BUCKET_NAME}/dir2"], recursive=True) + out = fs.find(f"{TEST_BUCKET_NAME}", detail=False) + assert out == [f"{TEST_BUCKET_NAME}/afile"] def test_cat_ranges_one(fs): data1 = os.urandom(10000) - fs.pipe_file("data1", data1) + fs.pipe_file(f"{TEST_BUCKET_NAME}/data1", data1) # single range - out = fs.cat_ranges(["data1"], [10], [20]) + out = fs.cat_ranges([f"{TEST_BUCKET_NAME}/data1"], [10], [20]) assert out == [data1[10:20]] # range oob - out = fs.cat_ranges(["data1"], [0], [11000]) + out = fs.cat_ranges([f"{TEST_BUCKET_NAME}/data1"], [0], [11000]) assert out == [data1] # two disjoint ranges, one file - out = fs.cat_ranges(["data1", "data1"], [10, 40], [20, 60]) + out = fs.cat_ranges( + [f"{TEST_BUCKET_NAME}/data1", f"{TEST_BUCKET_NAME}/data1"], [10, 40], [20, 60] + ) assert out == [data1[10:20], data1[40:60]] # two adjoining ranges, 
one file - out = fs.cat_ranges(["data1", "data1"], [10, 30], [20, 60]) + out = fs.cat_ranges( + [f"{TEST_BUCKET_NAME}/data1", f"{TEST_BUCKET_NAME}/data1"], [10, 30], [20, 60] + ) assert out == [data1[10:20], data1[30:60]] # two overlapping ranges, one file - out = fs.cat_ranges(["data1", "data1"], [10, 15], [20, 60]) + out = fs.cat_ranges( + [f"{TEST_BUCKET_NAME}/data1", f"{TEST_BUCKET_NAME}/data1"], [10, 15], [20, 60] + ) assert out == [data1[10:20], data1[15:60]] # completely overlapping ranges, one file - out = fs.cat_ranges(["data1", "data1"], [10, 0], [20, 60]) + out = fs.cat_ranges( + [f"{TEST_BUCKET_NAME}/data1", f"{TEST_BUCKET_NAME}/data1"], [10, 0], [20, 60] + ) assert out == [data1[10:20], data1[0:60]] def test_cat_ranges_two(fs): data1 = os.urandom(10000) data2 = os.urandom(10000) - fs.pipe({"data1": data1, "data2": data2}) + fs.pipe({f"{TEST_BUCKET_NAME}/data1": data1, f"{TEST_BUCKET_NAME}/data2": data2}) # single range in each file - out = fs.cat_ranges(["data1", "data2"], [10, 10], [20, 20]) + out = fs.cat_ranges( + [f"{TEST_BUCKET_NAME}/data1", f"{TEST_BUCKET_NAME}/data2"], [10, 10], [20, 20] + ) assert out == [data1[10:20], data2[10:20]] @@ -119,4 +193,4 @@ def test_atomic_write(fs): def test_cat_ranges_error(fs): with pytest.raises(ValueError): - fs.cat_ranges(["path"], [], []) + fs.cat_ranges([f"{TEST_BUCKET_NAME}/path"], [], [])
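
Reviewer note: a minimal usage sketch of the API this patch introduces, pieced together from the tests above. `register()` and the `config=` keyword come from this patch; the endpoint URL and bucket name below are placeholders, not values taken from the repository.

```py
import fsspec

from obstore.fsspec import register

# Bind the "s3" protocol to a dynamically generated AsyncFsspecStore_s3 subclass.
register("s3")

# fsspec now resolves "s3" to the obstore-backed filesystem; `config` is
# forwarded to obstore.store.from_url() for each bucket seen in a path.
fs = fsspec.filesystem(
    "s3",
    config={
        "AWS_ENDPOINT_URL": "http://localhost:9000",  # placeholder endpoint
        "AWS_REGION": "us-east-1",
        "AWS_SKIP_SIGNATURE": "True",
    },
)

# Paths now carry the bucket name, e.g. "my-bucket/path/to/file".
print(fs.ls("my-bucket", detail=False))  # "my-bucket" is a placeholder bucket name
```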