[FEAT] Create obstore store in fsspec on demand #198

Open · wants to merge 20 commits into base: main
265 changes: 244 additions & 21 deletions obstore/python/obstore/fsspec.py

import asyncio
from collections import defaultdict
from functools import lru_cache
from typing import (
    TYPE_CHECKING,
    Any,
    Coroutine,
    Dict,
    List,
    Tuple,
)
from urllib.parse import urlparse

import fsspec.asyn
import fsspec.spec

import obstore as obs
from obstore import Bytes
from obstore.store import from_url

if TYPE_CHECKING:
    from obstore.store import (
        AzureConfig,
        AzureConfigInput,
        ClientConfig,
        GCSConfig,
        GCSConfigInput,
        RetryConfig,
        S3Config,
        S3ConfigInput,
    )


class AsyncFsspecStore(fsspec.asyn.AsyncFileSystem):
    cachable = False
    config: (
        S3Config
        | S3ConfigInput
        | GCSConfig
        | GCSConfigInput
        | AzureConfig
        | AzureConfigInput
        | None
    )
    client_options: ClientConfig | None
    retry_config: RetryConfig | None
    def __init__(
        self,
        store: obs.store.ObjectStore,
        *args,
        config: (
            S3Config
            | S3ConfigInput
            | GCSConfig
            | GCSConfigInput
            | AzureConfig
            | AzureConfigInput
            | None
        ) = None,
        client_options: ClientConfig | None = None,
        retry_config: RetryConfig | None = None,
        asynchronous: bool = False,
        loop: Any = None,
        batch_size: int | None = None,
    ):
        """Construct a new AsyncFsspecStore.

        Args:
            store: a configured instance of one of the store classes in
                `obstore.store`.
            config: Configuration for the cloud storage provider, which can be one
                of S3Config, S3ConfigInput, GCSConfig, GCSConfigInput, AzureConfig,
                or AzureConfigInput. If None, no cloud storage configuration is
                applied.
            client_options: Additional options for configuring the client.
            retry_config: Configuration for handling request errors.
            asynchronous: Set to `True` if this instance is meant to be called
                using the fsspec async API. This should only be set to `True` when
                running within a coroutine.
        """

        self.store = store
        self.config = config
        self.client_options = client_options
        self.retry_config = retry_config

        super().__init__(
            *args, asynchronous=asynchronous, loop=loop, batch_size=batch_size
        )

    def _split_path(self, path: str) -> Tuple[str, str]:
        """Split a path into its bucket and file-path components.

        Args:
            path (str): Input path, like `s3://mybucket/path/to/file`

        Examples:
            >>> fs._split_path("s3://mybucket/path/to/file")
            ('mybucket', 'path/to/file')
        """

        protocol_with_bucket = ["s3", "s3a", "gcs", "gs", "abfs", "https", "http"]

        if self.protocol not in protocol_with_bucket:
            # no bucket name in path
            return "", path

        res = urlparse(path)
        if res.scheme:
            if res.scheme != self.protocol:
                raise ValueError(
                    f"Expected protocol to be {self.protocol}. Got {res.scheme}"
                )
            path = res.netloc + res.path

        if "/" not in path:
            return path, ""

        bucket, file_path = path.split("/", 1)
        return (bucket, file_path)
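    # Note: "https"/"http" are included above, so for plain-HTTP(S) URLs the
    # host is treated as the "bucket", and _construct_store below receives
    # e.g. "https://<host>" as the URL to build a store from.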

    # Review comment (maintainer): it would be nice if this cache size could be
    # user specified, but we can come back to it.
    @lru_cache(maxsize=10)
    def _construct_store(self, bucket: str):
        return from_url(
            url=f"{self.protocol}://{bucket}",
            config=self.config,
            client_options=self.client_options,
            retry_config=self.retry_config,
        )
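    # Because _construct_store is wrapped in lru_cache, repeated operations
    # against the same bucket reuse a single store instance rather than
    # re-building it from the config on every call.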

    async def _rm_file(self, path, **kwargs):
        bucket, path = self._split_path(path)
        store = self._construct_store(bucket)
        return await obs.delete_async(store, path)

    async def _cp_file(self, path1, path2, **kwargs):
        bucket1, path1 = self._split_path(path1)
        bucket2, path2 = self._split_path(path2)

        if bucket1 != bucket2:
            raise ValueError(
                f"Bucket mismatch: Source bucket '{bucket1}' and destination "
                f"bucket '{bucket2}' must be the same."
            )

        store = self._construct_store(bucket1)
        return await obs.copy_async(store, path1, path2)

    async def _pipe_file(self, path, value, mode="overwrite", **kwargs):
        bucket, path = self._split_path(path)
        store = self._construct_store(bucket)
        return await obs.put_async(store, path, value)

    async def _cat_file(self, path, start=None, end=None, **kwargs):
        bucket, path = self._split_path(path)
        store = self._construct_store(bucket)

        if start is None and end is None:
            resp = await obs.get_async(store, path)
            return (await resp.bytes_async()).to_bytes()

        range_bytes = await obs.get_range_async(store, path, start=start, end=end)
        return range_bytes.to_bytes()

    async def _cat_ranges(
        self,
        paths: List[str],
        starts: List[int],
        ends: List[int],
        max_gap=None,
        batch_size=None,
        on_error="return",
        **kwargs,
    ):
        # Group the requested ranges by file so that each file needs only one
        # get_ranges_async call; keep the original index for reordering later.
        per_file_requests: Dict[str, List[Tuple[int, int, int]]] = defaultdict(list)
        for idx, (path, start, end) in enumerate(zip(paths, starts, ends)):
            per_file_requests[path].append((start, end, idx))

        futs: List[Coroutine[Any, Any, List[Bytes]]] = []
        for path, ranges in per_file_requests.items():
            bucket, path = self._split_path(path)
            store = self._construct_store(bucket)

            offsets = [r[0] for r in ranges]
            ends = [r[1] for r in ranges]
            fut = obs.get_ranges_async(store, path, starts=offsets, ends=ends)
            futs.append(fut)

        result = await asyncio.gather(*futs)

        # Scatter the fetched buffers back into the order of the original request.
        output_buffers: List[bytes] = [b""] * len(paths)
        for (path, ranges), buffers in zip(per_file_requests.items(), result):
            for buffer, (start, end, idx) in zip(buffers, ranges):
                output_buffers[idx] = buffer.to_bytes()

        return output_buffers

    async def _put_file(self, lpath, rpath, **kwargs):
        # lpath is a local file; only the remote rpath carries a bucket.
        bucket, rpath = self._split_path(rpath)
        store = self._construct_store(bucket)

        with open(lpath, "rb") as f:
            await obs.put_async(store, rpath, f)

    async def _get_file(self, rpath, lpath, **kwargs):
        # rpath is remote; lpath is the local destination file.
        bucket, rpath = self._split_path(rpath)
        store = self._construct_store(bucket)

        with open(lpath, "wb") as f:
            resp = await obs.get_async(store, rpath)
            async for buffer in resp.stream():
                f.write(buffer)

    async def _info(self, path, **kwargs):
        bucket, path = self._split_path(path)
        store = self._construct_store(bucket)

        head = await obs.head_async(store, path)
        return {
            # Required of `info`: (?)
            "name": head["path"],
            "size": head["size"],
            "type": "file",
            # Implementation-specific keys
            "e_tag": head["e_tag"],
            "version": head["version"],
        }
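    # Paths handed to _ls include the bucket (e.g. "mybucket/key"), so listing
    # results must be re-prefixed with the bucket name; _fill_bucket_name below
    # does exactly that.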

    def _fill_bucket_name(self, path, bucket):
        return f"{bucket}/{path}"

    async def _ls(self, path, detail=True, **kwargs):
        bucket, path = self._split_path(path)
        store = self._construct_store(bucket)

        result = await obs.list_with_delimiter_async(store, path)
        objects = result["objects"]
        prefs = result["common_prefixes"]
        if detail:
            return [
                {
                    "name": self._fill_bucket_name(object["path"], bucket),
                    "size": object["size"],
                    "type": "file",
                    "e_tag": object["e_tag"],
                }
                for object in objects
            ] + [
                {
                    "name": self._fill_bucket_name(pref, bucket),
                    "size": 0,
                    "type": "directory",
                }
                for pref in prefs
            ]
        else:
            return sorted(
                [self._fill_bucket_name(object["path"], bucket) for object in objects]
                + [self._fill_bucket_name(pref, bucket) for pref in prefs]
            )

def _open(self, path, mode="rb", **kwargs):
def _open(
self,
path,
mode="rb",
block_size=None,
autocommit=True,
cache_options=None,
**kwargs,
):
"""Return raw bytes-mode file-like from the file-system"""

return BufferedFileSimple(self, path, mode, **kwargs)


class BufferedFileSimple(fsspec.spec.AbstractBufferedFile):
    def read(self, length: int = -1):
        if length < 0:
            # Negative length: read all remaining bytes.
            data = self.fs.cat_file(self.path, self.loc, self.size)
            self.loc = self.size
        else:
            data = self.fs.cat_file(self.path, self.loc, self.loc + length)
            self.loc += length
        return data


def register(protocol: str | list[str], asynchronous: bool = False):
    """
    Dynamically register a subclass of AsyncFsspecStore for the given protocol(s).

    This function creates a new subclass of AsyncFsspecStore with the specified
    protocol and registers it with fsspec. If multiple protocols are provided,
    the function registers each one individually.

    Args:
        protocol (str | list[str]): A single protocol (e.g., "s3", "gcs", "abfs")
            or a list of protocols to register AsyncFsspecStore for.
        asynchronous (bool, optional): If True, the registered store will support
            asynchronous operations. Defaults to False.

    Example:
        >>> register("s3")
        >>> register("s3", asynchronous=True)  # Registers an async store for "s3"
        >>> register(["gcs", "abfs"])  # Registers both "gcs" and "abfs"

    Notes:
        - Each protocol gets a dynamically generated subclass named
          `AsyncFsspecStore_<protocol>`.
        - This avoids modifying the original AsyncFsspecStore class.
    """

    # Ensure protocol is of type str or list
    if not isinstance(protocol, (str, list)):
        raise TypeError(
            f"Protocol must be a string or a list of strings, got {type(protocol)}"
        )

    # Ensure protocol is not None or empty
    if not protocol:
        raise ValueError(
            "Protocol must be a non-empty string or a list of non-empty strings."
        )

    if isinstance(protocol, list):
        # Ensure all elements are non-empty strings
        if not all(isinstance(p, str) for p in protocol):
            raise TypeError("All protocols in the list must be strings.")
        if not all(p for p in protocol):
            raise ValueError("Protocol names in the list must be non-empty strings.")

        for p in protocol:
            register(p, asynchronous)
        return

    fsspec.register_implementation(
        protocol,
        type(
            f"AsyncFsspecStore_{protocol}",  # Unique class name
            (AsyncFsspecStore,),  # Base class
            {
                "protocol": protocol,
                "asynchronous": asynchronous,
            },  # Assign protocol dynamically
        ),
        clobber=True,
    )
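For reviewers who want to try the branch, here is a minimal usage sketch (not part of the diff). The `store=None` placeholder and the endpoint values are assumptions: in this revision `__init__` still accepts a `store` argument, but the read/write paths above build per-bucket stores on demand from `config` via `_construct_store`, so no pre-built store should be needed.

```python
import fsspec

from obstore.fsspec import register

# Register an AsyncFsspecStore subclass for the "s3" protocol.
register("s3")

# Hypothetical local endpoint; config keys mirror the s3_store_config fixture below.
fs = fsspec.filesystem(
    "s3",
    store=None,  # assumed placeholder; per-bucket stores are built from `config`
    config={
        "AWS_ENDPOINT_URL": "http://localhost:5000",
        "AWS_REGION": "us-east-1",
        "AWS_SKIP_SIGNATURE": "True",
        "AWS_ALLOW_HTTP": "true",
    },
)

# The bucket is split from the path and a store for "mybucket" is created
# (and cached by lru_cache) on first use.
data = fs.cat_file("s3://mybucket/path/to/file")
```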
10 changes: 10 additions & 0 deletions tests/conftest.py


@pytest.fixture()
def s3_store_config(s3):
    return {
        "AWS_ENDPOINT_URL": s3,
        "AWS_REGION": "us-east-1",
        "AWS_SKIP_SIGNATURE": "True",
        "AWS_ALLOW_HTTP": "true",
    }
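A hypothetical test built on this new fixture (assumes the `register` helper from `obstore.fsspec` above, the same `store=None` placeholder, and a `mybucket` bucket pre-created by the `s3` fixture):

```python
import fsspec

from obstore.fsspec import register


def test_ls_uses_on_demand_store(s3_store_config):
    register("s3")
    fs = fsspec.filesystem("s3", store=None, config=s3_store_config)

    # _ls should return bucket-prefixed names, per _fill_bucket_name above.
    listing = fs.ls("s3://mybucket", detail=False)
    assert all(name.startswith("mybucket/") for name in listing)
```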