From 698eb232b1cb1967b24e2fc9eae5d24fc77193d1 Mon Sep 17 00:00:00 2001
From: yanghua
Date: Wed, 25 Sep 2024 10:34:02 +0800
Subject: [PATCH] Support tag the access source

---
 poetry.lock     |   2 +-
 pyproject.toml  |   5 +-
 tosfs/consts.py |   4 +-
 tosfs/core.py   |  23 ++++
 tosfs/tag.py    | 270 ++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 301 insertions(+), 3 deletions(-)
 create mode 100644 tosfs/tag.py

diff --git a/poetry.lock b/poetry.lock
index b779fe6..ceab263 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -733,4 +733,4 @@ files = [
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "1bb1712f54089469cbb3c278bad0114a8104a28f988d82ff8d87bf88aa5d0fa5"
+content-hash = "1bb1712f54089469cbb3c278bad0114a8104a28f988d82ff8d87bf88aa5d0fa5"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 3d59f1f..e3991c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,8 +8,9 @@ readme = "README.md"
 
 [tool.poetry.dependencies]
 python = "^3.9"
-fsspec = "==2023.5.0"
+fsspec = ">=2023.5.0"
 tos = ">=2.7.0"
+volcengine= "==1.0.154"
 
 [tool.poetry.group.dev.dependencies]
 fsspec = ">=2023.5.0"
@@ -22,6 +23,7 @@ pytest-cov = "==5.0.0"
 coverage = "==7.5.0"
 ruff = "==0.6.0"
 types-requests = "==2.32.0.20240907"
+volcengine= "==1.0.154"
 
 [tool.pydocstyle]
 convention = "numpy"
@@ -64,6 +66,7 @@ select = [
 ignore = [
     "S101", # Use of `assert` detected
     "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+    "S108", # Probable insecure usage of temporary file or directory
     "D203", # no-blank-line-before-class
     "D213", # multi-line-summary-second-line
     "PLR0913", # Too many arguments in function definition
diff --git a/tosfs/consts.py b/tosfs/consts.py
index 98d5e62..556aac0 100644
--- a/tosfs/consts.py
+++ b/tosfs/consts.py
@@ -31,6 +31,8 @@
 
 LS_OPERATION_DEFAULT_MAX_ITEMS = 1000
 
+TOSFS_LOG_FORMAT = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s"  # noqa: E501
+
 # environment variable names
 ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
-TOSFS_LOG_FORMAT = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s"  # noqa: E501
+ENV_NAME_TOS_BUCKET_TAG_ENABLE = "TOS_BUCKET_TAG_ENABLE"
diff --git a/tosfs/core.py b/tosfs/core.py
index 807c1bb..947626d 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -25,6 +25,7 @@
 from fsspec import AbstractFileSystem
 from fsspec.spec import AbstractBufferedFile
 from fsspec.utils import setup_logging as setup_logger
+from tos.auth import CredentialProviderAuth
 from tos.exceptions import TosClientError, TosServerError
 from tos.models import CommonPrefixInfo
 from tos.models2 import (
@@ -54,6 +55,7 @@
 from tosfs.fsspec_utils import glob_translate
 from tosfs.mpu import MultipartUploader
 from tosfs.retry import retryable_func_executor
+from tosfs.tag import BucketTagMgr
 from tosfs.utils import find_bucket_key, get_brange
 
 logger = logging.getLogger("tosfs")
@@ -203,6 +205,13 @@ def __init__(
         if version_aware:
             raise ValueError("Currently, version_aware is not supported.")
 
+        # Environment values are strings, so "false"/"0" must be compared
+        # explicitly; tagging stays enabled when the variable is unset.
+        tag_switch = os.environ.get("TOS_TAG_ENABLED", "true").strip().lower()
+        self.tag_enabled = tag_switch not in ("", "0", "false", "no", "off")
+        if self.tag_enabled:
+            self._init_tag_manager()
+
         self.version_aware = version_aware
         self.default_block_size = (
             default_block_size or FILE_OPERATION_READ_WRITE_BUFFER_SIZE
@@ -2093,12 +2102,26 @@ def _split_path(self, path: str) -> Tuple[str, str, Optional[str]]:
 
         bucket, keypart = find_bucket_key(path)
         key, _, version_id = keypart.partition("?versionId=")
+
+        if self.tag_enabled:
+            self.bucket_tag_mgr.add_bucket_tag(bucket)
+
         return (
             bucket,
             key,
             version_id if self.version_aware and version_id else None,
         )
 
+    def _init_tag_manager(self) -> None:
+        auth = self.tos_client.auth
+        if isinstance(auth, CredentialProviderAuth):
+            credentials = auth.credentials_provider.get_credentials()
+            self.bucket_tag_mgr = BucketTagMgr(
+                credentials.get_ak(), credentials.get_sk(), auth.region
+            )
+        else:
+            raise TosfsError("Currently only support CredentialProviderAuth type")
+
     @staticmethod
     def _fill_dir_info(
         bucket: str, common_prefix: Optional[CommonPrefixInfo], key: str = ""
diff --git a/tosfs/tag.py b/tosfs/tag.py
new file mode 100644
index 0000000..0720af9
--- /dev/null
+++ b/tosfs/tag.py
@@ -0,0 +1,270 @@
+# ByteDance Volcengine EMR, Copyright 2024.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The module contains all the business logic for tagging tos buckets."""
+
+import fcntl
+import functools
+import json
+import logging
+import os
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+
+from volcengine.ApiInfo import ApiInfo
+from volcengine.base.Service import Service
+from volcengine.Credentials import Credentials
+from volcengine.ServiceInfo import ServiceInfo
+
+PUT_TAG_ACTION_NAME = "PutBucketDoubleMeterTagging"
+GET_TAG_ACTION_NAME = "GetBucketTagging"
+DEL_TAG_ACTION_NAME = "DeleteBucketTagging"
+EMR_OPEN_API_VERSION = "2022-12-29"
+OPEN_API_HOST = "open.volcengineapi.com"
+ACCEPT_HEADER_KEY = "accept"
+ACCEPT_HEADER_JSON_VALUE = "application/json"
+
+THREAD_POOL_SIZE = 2
+TAGGED_BUCKETS_FILE = "/tmp/.emr_tagged_buckets"
+
+CONNECTION_TIMEOUT_DEFAULT_SECONDS = 60 * 5
+SOCKET_TIMEOUT_DEFAULT_SECONDS = 60 * 5
+
+service_info_map = {
+    "cn-beijing": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "cn-beijing"),
+        CONNECTION_TIMEOUT_DEFAULT_SECONDS,
+        SOCKET_TIMEOUT_DEFAULT_SECONDS,
+        "http",
+    ),
+    "cn-guangzhou": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "cn-guangzhou"),
+        CONNECTION_TIMEOUT_DEFAULT_SECONDS,
+        SOCKET_TIMEOUT_DEFAULT_SECONDS,
+        "http",
+    ),
+    "cn-shanghai": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "cn-shanghai"),
+        CONNECTION_TIMEOUT_DEFAULT_SECONDS,
+        SOCKET_TIMEOUT_DEFAULT_SECONDS,
+        "http",
+    ),
+    "ap-southeast-1": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "ap-southeast-1"),
+        CONNECTION_TIMEOUT_DEFAULT_SECONDS,
+        SOCKET_TIMEOUT_DEFAULT_SECONDS,
+        "http",
+    ),
+    "cn-beijing-qa": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr_qa", "cn-beijing"),
+        CONNECTION_TIMEOUT_DEFAULT_SECONDS,
+        SOCKET_TIMEOUT_DEFAULT_SECONDS,
+        "http",
+    ),
+}
+
+api_info = {
+    PUT_TAG_ACTION_NAME: ApiInfo(
+        "POST",
+        "/",
+        {"Action": PUT_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
+        {},
+        {},
+    ),
+    GET_TAG_ACTION_NAME: ApiInfo(
+        "GET",
+        "/",
+        {"Action": GET_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
+        {},
+        {},
+    ),
+    DEL_TAG_ACTION_NAME: ApiInfo(
+        "POST",
+        "/",
+        {"Action": DEL_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
+        {},
+        {},
+    ),
+}
+
+
+class BucketTagAction(Service):
+    """BucketTagAction is a class to manage the tag of bucket."""
+
+    _instance_lock = threading.Lock()
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+        """Singleton."""
+        if not hasattr(BucketTagAction, "_instance"):
+            with BucketTagAction._instance_lock:
+                if not hasattr(BucketTagAction, "_instance"):
+                    BucketTagAction._instance = object.__new__(cls)
+        return BucketTagAction._instance
+
+    def __init__(self, key: str, secret: str, region: str = "cn-beijing") -> None:
+        """Init BucketTagAction."""
+        super().__init__(self.get_service_info(region), self.get_api_info())
+        self.set_ak(key)
+        self.set_sk(secret)
+
+    @staticmethod
+    def get_api_info() -> dict:
+        """Get api info."""
+        return api_info
+
+    @staticmethod
+    def get_service_info(region: str) -> ServiceInfo:
+        """Get service info."""
+        service_info = service_info_map.get(region)
+        if service_info:
+            return service_info
+
+        if "VOLC_REGION" in os.environ:
+            return ServiceInfo(
+                OPEN_API_HOST,
+                {
+                    ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+                },
+                Credentials("", "", "emr", region),
+                CONNECTION_TIMEOUT_DEFAULT_SECONDS,
+                SOCKET_TIMEOUT_DEFAULT_SECONDS,
+                "http",
+            )
+
+        raise ValueError("do not support region %s" % region)
+
+    def put_bucket_tag(self, bucket: str) -> tuple[str, bool]:
+        """Put tag for bucket."""
+        params = {
+            "Bucket": bucket,
+        }
+
+        try:
+            res = self.json(PUT_TAG_ACTION_NAME, params, json.dumps(""))
+            res_json = json.loads(res)
+            logging.debug("Put tag for bucket %s successfully: %s .", bucket, res_json)
+            return bucket, True
+        except Exception as e:
+            logging.debug("Put tag for bucket %s failed: %s .", bucket, e)
+            return bucket, False
+
+    def get_bucket_tag(self, bucket: str) -> bool:
+        """Get tag for bucket."""
+        params = {
+            "Bucket": bucket,
+        }
+        try:
+            res = self.get(GET_TAG_ACTION_NAME, params)
+            res_json = json.loads(res)
+            logging.debug("The result of get_Bucket_tag is %s", res_json)
+            return True
+        except Exception as e:
+            logging.debug("Get tag for %s is failed: %s", bucket, e)
+            return False
+
+    def del_bucket_tag(self, bucket: str) -> None:
+        """Delete tag for bucket."""
+        params = {
+            "Bucket": bucket,
+        }
+        try:
+            res = self.json(DEL_TAG_ACTION_NAME, params, json.dumps(""))
+            res_json = json.loads(res)
+            logging.debug("The result of del_Bucket_tag is %s", res_json)
+        except Exception as e:
+            logging.debug("Delete tag for %s is failed: %s", bucket, e)
+
+
+def singleton(cls: Any) -> Any:
+    """Singleton decorator."""
+    _instances = {}
+
+    @functools.wraps(cls)
+    def get_instance(*args: Any, **kwargs: Any) -> Any:
+        if cls not in _instances:
+            _instances[cls] = cls(*args, **kwargs)
+        return _instances[cls]
+
+    return get_instance
+
+
+@singleton
+class BucketTagMgr:
+    """BucketTagMgr is a class to manage the tag of bucket."""
+
+    def __init__(self, key: str, secret: str, region: str):
+        """Init BucketTagMgr."""
+        self.executor = ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE)
+        self.cached_bucket_set: set = set()
+        self.key = key
+        self.secret = secret
+        self.region = region
+
+    def add_bucket_tag(self, bucket: str) -> None:
+        """Tag ``bucket`` once and record it in the shared cache file.
+
+        Buckets already known in memory or in ``TAGGED_BUCKETS_FILE`` are
+        skipped. A bucket whose tagging request fails is not cached, so the
+        request is attempted again on a later call.
+        """
+        if bucket in self.cached_bucket_set:
+            return
+
+        # Pick up buckets that other processes on this host already tagged.
+        if os.path.exists(TAGGED_BUCKETS_FILE):
+            with open(TAGGED_BUCKETS_FILE, "r") as file:
+                # split() without a separator ignores the empty token that an
+                # empty file would otherwise add to the cache.
+                self.cached_bucket_set |= set(file.read().split())
+
+        need_tag_buckets = {bucket} - self.cached_bucket_set
+        bucket_tag_service = BucketTagAction(self.key, self.secret, self.region)
+
+        for tagged_bucket, succeeded in self.executor.map(
+            bucket_tag_service.put_bucket_tag, need_tag_buckets
+        ):
+            if succeeded:
+                self.cached_bucket_set.add(tagged_bucket)
+
+        # Append mode keeps the file intact until the exclusive lock is held;
+        # mode "w" would truncate it before flock() is even called.
+        with open(TAGGED_BUCKETS_FILE, "a") as fw:
+            fcntl.flock(fw, fcntl.LOCK_EX)
+            try:
+                fw.truncate(0)
+                fw.write(" ".join(self.cached_bucket_set))
+                fw.flush()
+            finally:
+                fcntl.flock(fw, fcntl.LOCK_UN)