Skip to content

Commit

Permalink
Support tag the access source
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Sep 26, 2024
1 parent 698eb23 commit a3be857
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 37 deletions.
42 changes: 5 additions & 37 deletions tosfs/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

PUT_TAG_ACTION_NAME = "PutBucketDoubleMeterTagging"
GET_TAG_ACTION_NAME = "GetBucketTagging"
DEL_TAG_ACTION_NAME = "DeleteBucketTagging"
EMR_OPEN_API_VERSION = "2022-12-29"
OPEN_API_HOST = "open.volcengineapi.com"
ACCEPT_HEADER_KEY = "accept"
Expand Down Expand Up @@ -110,13 +109,6 @@
{},
{},
),
DEL_TAG_ACTION_NAME: ApiInfo(
"POST",
"/",
{"Action": DEL_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
{},
{},
),
}


Expand All @@ -133,7 +125,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
BucketTagAction._instance = object.__new__(cls)
return BucketTagAction._instance

def __init__(self, key: str, secret: str, region: str = "cn-beijing") -> None:
def __init__(self, key: str, secret: str, region: str) -> None:
"""Init BucketTagAction."""
super().__init__(self.get_service_info(region), self.get_api_info())
self.set_ak(key)
Expand All @@ -151,19 +143,7 @@ def get_service_info(region: str) -> ServiceInfo:
if service_info:
return service_info

if "VOLC_REGION" in os.environ:
return ServiceInfo(
OPEN_API_HOST,
{
ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
},
Credentials("", "", "emr", region),
CONNECTION_TIMEOUT_DEFAULT_SECONDS,
SOCKET_TIMEOUT_DEFAULT_SECONDS,
"http",
)

raise Exception("do not support region %s" % region)
raise Exception(f"Do not support region: {region}")

def put_bucket_tag(self, bucket: str) -> tuple[str, bool]:
"""Put tag for bucket."""
Expand All @@ -188,24 +168,12 @@ def get_bucket_tag(self, bucket: str) -> bool:
try:
res = self.get(GET_TAG_ACTION_NAME, params)
res_json = json.loads(res)
logging.debug("The result of get_Bucket_tag is %s", res_json)
logging.debug("The get bucket tag's response is %s", res_json)
return True
except Exception as e:
logging.debug("Get tag for %s is failed: %s", bucket, e)
return False

def del_bucket_tag(self, bucket: str) -> None:
"""Delete tag for bucket."""
params = {
"Bucket": bucket,
}
try:
res = self.json(DEL_TAG_ACTION_NAME, params, json.dumps(""))
res_json = json.loads(res)
logging.debug("The result of del_Bucket_tag is %s", res_json)
except Exception as e:
logging.debug("Delete tag for %s is failed: %s", bucket, e)


def singleton(cls: Any) -> Any:
"""Singleton decorator."""
Expand All @@ -231,6 +199,7 @@ def __init__(self, key: str, secret: str, region: str):
self.key = key
self.secret = secret
self.region = region
self.bucket_tag_service = BucketTagAction(self.key, self.secret, self.region)

def add_bucket_tag(self, bucket: str) -> None:
"""Add tag for bucket."""
Expand All @@ -245,10 +214,9 @@ def add_bucket_tag(self, bucket: str) -> None:
self.cached_bucket_set |= tagged_bucket_from_file_set

need_tag_buckets = collect_bucket_set - self.cached_bucket_set
bucket_tag_service = BucketTagAction(self.key, self.secret, self.region)

for res in self.executor.map(
bucket_tag_service.put_bucket_tag, need_tag_buckets
self.bucket_tag_service.put_bucket_tag, need_tag_buckets
):
if res[1]:
self.cached_bucket_set.add(res[0])
Expand Down
46 changes: 46 additions & 0 deletions tosfs/tests/test_tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from time import sleep

import pytest

from tosfs.tag import TAGGED_BUCKETS_FILE


@pytest.fixture
def _prepare_tag_env():
if os.path.exists(TAGGED_BUCKETS_FILE):
os.remove(TAGGED_BUCKETS_FILE)
yield
if os.path.exists(TAGGED_BUCKETS_FILE):
os.remove(TAGGED_BUCKETS_FILE)


@pytest.mark.usefixtures("_prepare_tag_env")
def test_bucket_tag_action(tosfs, bucket, temporary_workspace):
tag_mgr = tosfs.bucket_tag_mgr
if tag_mgr is None:
return

tag_mgr.cached_bucket_set = set()
tag_mgr.add_bucket_tag(bucket)
sleep(10)
assert os.path.exists(TAGGED_BUCKETS_FILE)
with open(TAGGED_BUCKETS_FILE, "r") as f:
tagged_buckets = f.read()
assert bucket in tagged_buckets

assert tag_mgr.bucket_tag_service.get_bucket_tag(bucket)

0 comments on commit a3be857

Please sign in to comment.