Skip to content

Commit

Permalink
Implement ls API
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Aug 16, 2024
1 parent c3ef905 commit e46fcdc
Show file tree
Hide file tree
Showing 5 changed files with 500 additions and 2 deletions.
283 changes: 283 additions & 0 deletions tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@
"""
import logging
import os
from typing import Optional, Tuple

import tos
from fsspec import AbstractFileSystem
from fsspec.utils import setup_logging as setup_logger
from tos.models import CommonPrefixInfo
from tos.models2 import ListedObjectVersion

from tosfs.utils import find_bucket_key

# environment variable names
ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
Expand All @@ -41,3 +48,279 @@ def setup_logging():
logger.warning(
"The tosfs's log level is set to be %s", logging.getLevelName(logger.level)
)


class TosFileSystem(AbstractFileSystem):
"""
Tos file system. An implementation of AbstractFileSystem which is an
abstract super-class for pythonic file-systems.
"""

def __init__(
    self,
    endpoint_url=None,
    key="",
    secret="",
    region=None,
    version_aware=False,
    credentials_provider=None,
    **kwargs,
):
    """
    Create the file system and the TOS client it delegates to.
    :param endpoint_url: Endpoint handed to the TOS client.
    :param key: Access key handed to the TOS client.
    :param secret: Secret key handed to the TOS client.
    :param region: Region handed to the TOS client.
    :param version_aware: Stored on the instance; the listing and path
        splitting methods consult it to decide whether object versions
        are handled.
    :param credentials_provider: Credentials provider handed to the TOS
        client.
    :param kwargs: Forwarded untouched to the parent class initializer.
    """
    # The client takes key, secret, endpoint and region positionally, in
    # exactly this order; only the credentials provider goes by keyword.
    client = tos.TosClientV2(
        key,
        secret,
        endpoint_url,
        region,
        credentials_provider=credentials_provider,
    )
    self.tos_client = client
    self.version_aware = version_aware
    # The parent initializer runs last, once the client attributes exist.
    super().__init__(**kwargs)

def ls(self, path, detail=False, refresh=False, versions=False, **kwargs):
    """
    List objects under the given path.
    :param path: The path to list. An empty path (after the protocol and
        trailing slashes are stripped) lists the buckets.
    :param detail: When true, return the info dictionaries; otherwise
        return the sorted list of entry names.
    :param refresh: Whether to bypass the directory cache.
    :param versions: Whether to list object versions.
    :param kwargs: Additional arguments (accepted but not used here).
    :return: A list of info dictionaries if ``detail`` is true, else a
        sorted list of names. Empty when nothing matches.
    """
    path = self._strip_protocol(path).rstrip("/")

    if not path:
        # Nothing is left once trailing slashes are removed, so this is
        # the root: list the buckets.
        entries = self._lsbuckets(refresh)
    else:
        entries = self._lsdir(path, refresh, versions=versions)
        if not entries and "/" in path:
            # The path is not a non-empty directory. It may still be a
            # single object, so look for it in the parent's listing.
            try:
                entries = self._lsdir(
                    self._parent(path), refresh=refresh, versions=versions
                )
            except IOError:
                # Best effort: an unreadable parent just means no match.
                pass
            entries = [
                entry
                for entry in entries
                if entry["name"].rstrip("/") == path
                and entry["type"] != "directory"
            ]

    if detail:
        return entries
    return sorted(entry["name"] for entry in entries)

def _lsbuckets(self, refresh=False):
"""
List all buckets in the account.
:param refresh: Whether to refresh the cache.
:return: A list of buckets.
"""
if "" not in self.dircache or refresh:
try:
resp = self.tos_client.list_buckets()
except tos.exceptions.TosClientError as e:
logger.error("Tosfs failed with client error: %s", e)
return []
except tos.exceptions.TosServerError as e:
logger.error("Tosfs failed with server error: %s", e)
return []
except Exception as e:
logger.error("Tosfs failed with unknown error: %s", e)
return []

buckets = []
for bucket in resp.buckets:
buckets.append(
{
"Key": bucket.name,
"Size": 0,
"StorageClass": "BUCKET",
"size": 0,
"type": "directory",
"name": bucket.name,
}
)
self.dircache[""] = buckets

return self.dircache[""]

def _lsdir(
self,
path,
refresh=False,
max_items: int = 1000,
delimiter="/",
prefix="",
versions=False,
):
"""
List objects in a bucket, here we use cache to improve performance.
:param path: The path to list.
:param refresh: Whether to refresh the cache.
:param max_items: The maximum number of items to return, default is 1000. # noqa: E501
:param delimiter: The delimiter to use for grouping objects.
:param prefix: The prefix to use for filtering objects.
:param versions: Whether to list object versions.
:return: A list of objects in the bucket.
"""
bucket, key, _ = self.split_path(path)
if not prefix:
prefix = ""
if key:
prefix = key.lstrip("/") + "/" + prefix
if path not in self.dircache or refresh or not delimiter or versions:
logger.debug("Get directory listing for %s", path)
dirs = []
files = []
for obj in self._listdir(
bucket,
max_items=max_items,
delimiter=delimiter,
prefix=prefix,
versions=versions,
):
if isinstance(obj, CommonPrefixInfo):
dirs.append(self._fill_common_prefix_info(obj, bucket))
else:
files.append(self._fill_object_info(obj, bucket, versions))
files += dirs

if delimiter and files and not versions:
self.dircache[path] = files
return files
return self.dircache[path]

def _listdir(
self,
bucket,
max_items: int = 1000,
delimiter="/",
prefix="",
versions=False,
):
"""
List objects in a bucket.
:param bucket: The bucket name.
:param max_items: The maximum number of items to return, default is 1000. # noqa: E501
:param delimiter: The delimiter to use for grouping objects.
:param prefix: The prefix to use for filtering objects.
:param versions: Whether to list object versions.
:return: A list of objects in the bucket.
"""
if versions and not self.version_aware:
raise ValueError(
"versions cannot be specified if the filesystem is "
"not version aware."
)

all_results = []
is_truncated = True

try:
if self.version_aware:
key_marker, version_id_marker = None, None
while is_truncated:
resp = self.tos_client.list_object_versions(
bucket,
prefix,
delimiter=delimiter,
max_keys=max_items,
key_marker=key_marker,
version_id_marker=version_id_marker,
)
is_truncated = resp.is_truncated
all_results.extend(
resp.versions
+ resp.common_prefixes
+ resp.delete_markers
)
key_marker, version_id_marker = (
resp.next_key_marker,
resp.next_version_id_marker,
)
else:
continuation_token = ""
while is_truncated:
resp = self.tos_client.list_objects_type2(
bucket,
prefix,
start_after=prefix,
delimiter=delimiter,
max_keys=max_items,
continuation_token=continuation_token,
)
is_truncated = resp.is_truncated
continuation_token = resp.next_continuation_token

all_results.extend(resp.contents + resp.common_prefixes)

return all_results
except tos.exceptions.TosClientError as e:
logger.error(
"Tosfs failed with client error, message:%s, cause: %s",
e.message,
e.cause,
)
return []
except tos.exceptions.TosServerError as e:
logger.error("Tosfs failed with server error: %s", e)
return []
except Exception as e:
logger.error("Tosfs failed with unknown error: %s", e)
return []

def split_path(self, path) -> Tuple[str, str, Optional[str]]:
    """
    Normalise a tos path string into bucket, key and version id.

    Parameters
    ----------
    path : string
        Input path, like `tos://mybucket/path/to/file`

    Returns
    -------
    A ``(bucket, key, version_id)`` tuple. ``key`` is empty when the path
    names only a bucket. ``version_id`` is ``None`` unless the file system
    is version aware and the path carries a ``?versionId=`` suffix.

    Examples
    --------
    >>> split_path("tos://mybucket/path/to/file")
    ('mybucket', 'path/to/file', None)
    # pylint: disable=line-too-long
    >>> split_path("tos://mybucket/path/to/versioned_file?versionId=some_version_id")  # noqa: E501
    ('mybucket', 'path/to/versioned_file', 'some_version_id')
    """
    stripped = self._strip_protocol(path).lstrip("/")
    if "/" not in stripped:
        # No separator left: the whole string is the bucket name.
        return stripped, "", None

    bucket, keypart = find_bucket_key(stripped)
    key, _, version_id = keypart.partition("?versionId=")
    if not (self.version_aware and version_id):
        # The version id is only reported by version-aware file systems.
        version_id = None
    return bucket, key, version_id

@staticmethod
def _fill_common_prefix_info(common_prefix: CommonPrefixInfo, bucket):
return {
"name": common_prefix.prefix[:-1],
"Key": "/".join([bucket, common_prefix.prefix]),
"Size": 0,
"type": "directory",
}

@staticmethod
def _fill_object_info(obj, bucket, versions=False):
    """
    Build the info dictionary for a file (object) entry.
    :param obj: The listed object; ``key`` and ``size`` are read from it,
        and ``version_id`` when it is a ``ListedObjectVersion``.
    :param bucket: The bucket the object belongs to.
    :param versions: Whether version ids should be reflected in the name.
    :return: A file info dictionary. The name gets a ``?versionId=``
        suffix only for a ``ListedObjectVersion`` listed with
        ``versions`` set, and only when its version id is non-empty and
        not the literal ``"null"``.
    """
    qualified = f"{bucket}/{obj.key}"
    info = {
        "Key": qualified,
        "size": obj.size,
        "name": qualified,
        "type": "file",
    }
    if not (isinstance(obj, ListedObjectVersion) and versions):
        return info

    version_id = obj.version_id
    if version_id and version_id != "null":
        info["name"] = f"{qualified}?versionId={version_id}"
    return info
60 changes: 60 additions & 0 deletions tosfs/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# ByteDance Volcengine EMR, Copyright 2022.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
from tos import EnvCredentialsProvider

from tosfs.core import TosFileSystem
from tosfs.utils import random_path


@pytest.fixture(scope="module")
def tosfs_env_prepare():
    """Fail early when the TOS credential variables are not set."""
    # Checked in this order so the first missing variable is the one reported.
    for required in ("TOS_ACCESS_KEY", "TOS_SECRET_KEY"):
        if required not in os.environ:
            raise EnvironmentError(
                f"Can not find {required} in environment variables."
            )


@pytest.fixture(scope="module")
def tosfs(tosfs_env_prepare):
    """
    Provide a TosFileSystem configured from the environment.

    The endpoint and region come from TOS_ENDPOINT and TOS_REGION; the
    credentials are supplied through an EnvCredentialsProvider.
    """
    options = {
        "endpoint_url": os.environ.get("TOS_ENDPOINT"),
        "region": os.environ.get("TOS_REGION"),
        "credentials_provider": EnvCredentialsProvider(),
    }
    file_system = TosFileSystem(**options)
    yield file_system


@pytest.fixture(scope="module")
def bucket():
    """Provide the test bucket name: TOS_BUCKET, or "proton-ci" if unset."""
    bucket_name = os.environ.get("TOS_BUCKET", "proton-ci")
    yield bucket_name


@pytest.fixture(autouse=True)
def temporary_workspace(tosfs, bucket):
    """
    Create a randomly named workspace directory for a test and remove it
    afterwards. Yields the workspace name (without the trailing slash).
    """
    workspace = random_path()
    # A key ending in "/" acts as the directory marker object.
    marker_key = f"{workspace}/"
    # currently, make dir via purely tos python client,
    # will replace with tosfs.mkdir in the future
    tosfs.tos_client.put_object(bucket=bucket, key=marker_key)
    yield workspace
    # currently, remove dir via purely tos python client,
    # will replace with tosfs.rmdir in the future
    tosfs.tos_client.delete_object(bucket=bucket, key=marker_key)
Loading

0 comments on commit e46fcdc

Please sign in to comment.