BI-4975: Construct file-connector s3 path dynamically #31

Merged 2 commits on Oct 24, 2023
@@ -117,7 +117,11 @@ def get_sql_source(self, alias: Optional[str] = None) -> Any:
         origin_src = self._get_origin_src()
         status = origin_src.status
         raw_schema = self.spec.raw_schema
-        s3_filename = origin_src.s3_filename
+        if origin_src.s3_filename_suffix is not None:
+            s3_filename = self.connection.get_full_s3_filename(origin_src.s3_filename_suffix)
+        else:
+            # TODO: Remove this fallback after old connections migration to s3_filename_suffix
+            s3_filename = origin_src.s3_filename
 
         self._handle_component_errors()
 
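The suffix-or-legacy fallback above recurs at every place this PR reads an object key (see schedule_sources_delete below). A minimal sketch of the shared resolution logic, with a hypothetical helper name (resolve_s3_filename is not part of this change):

    from typing import Optional

    def resolve_s3_filename(
        connection: "BaseFileS3Connection",
        s3_filename_suffix: Optional[str],
        legacy_s3_filename: Optional[str],
    ) -> Optional[str]:
        # New-style sources persist only a random suffix; the connection
        # prepends its tenant id and uuid to build the full object key.
        if s3_filename_suffix is not None:
            return connection.get_full_s3_filename(s3_filename_suffix)
        # Old-style sources still carry the full key; may be None if never set.
        return legacy_s3_filename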
@@ -35,16 +35,23 @@ def schedule_sources_delete(
         source_to_del = (src.id for src in conn._saved_sources or [])
         for src_id in source_to_del:
             source = conn.get_saved_source_by_id(src_id)
-            if source.s3_filename is None:
+
+            if source.s3_filename_suffix is not None:
+                s3_filename = conn.get_full_s3_filename(source.s3_filename_suffix)
+            else:
+                # TODO: Remove this fallback after old connections migration to s3_filename_suffix
+                s3_filename = source.s3_filename
+
+            if s3_filename is None:
                 LOGGER.warning(f"Cannot schedule file deletion for source_id {source.id} - s3_filename not set")
                 continue
             task = DeleteFileTask(
-                s3_filename=source.s3_filename,
+                s3_filename=s3_filename,
                 preview_id=source.preview_id,
             )
             task_instance = await_sync(self._task_processor.schedule(task))
             LOGGER.info(
-                f"Scheduled task DeleteFileTask for source_id {source.id}, filename {source.s3_filename}. "
+                f"Scheduled task DeleteFileTask for source_id {source.id}, filename {s3_filename}. "
                 f"instance_id: {task_instance.instance_id}"
             )
 
@@ -15,6 +15,7 @@ class BaseFileConnectionSourceStorageSchema(DefaultStorageSchema):
     preview_id = fields.String(allow_none=True, load_default=None)
     title = fields.String()
     s3_filename = fields.String(allow_none=True, load_default=None)
+    s3_filename_suffix = fields.String(allow_none=True, load_default=None)
     raw_schema = fields.Nested(SchemaColumnStorageSchema, many=True, allow_none=True, load_default=None)
     status = fields.Enum(FileProcessingStatus)
 
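Because s3_filename_suffix is declared with load_default=None, sources persisted before this change deserialize with the suffix unset, which is exactly what routes them through the legacy fallback above. A self-contained marshmallow sketch of that behavior (simplified schema, not the real one):

    from marshmallow import Schema, fields

    class SourceSchema(Schema):
        s3_filename = fields.String(allow_none=True, load_default=None)
        s3_filename_suffix = fields.String(allow_none=True, load_default=None)

    loaded = SourceSchema().load({"s3_filename": "org123_oldstylekey"})
    assert loaded["s3_filename_suffix"] is None  # pre-migration record: fall back to s3_filename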
@@ -59,6 +59,7 @@ class BaseFileS3Connection(ConnectionHardcodedDataMixin[FileS3ConnectorSettings]
         "file_id",
         "title",
         "s3_filename",
+        "s3_filename_suffix",
         "status",
         "preview_id",
     )
@@ -71,6 +72,7 @@ class FileDataSource:
     preview_id: Optional[str] = attr.ib(default=None)
     status: FileProcessingStatus = attr.ib(default=FileProcessingStatus.in_progress)
     s3_filename: Optional[str] = attr.ib(default=None)
+    s3_filename_suffix: Optional[str] = attr.ib(default=None)
     raw_schema: Optional[list[SchemaColumn]] = attr.ib(factory=list[SchemaColumn])
 
     def str_for_hash(self) -> str:
@@ -80,6 +82,7 @@ def str_for_hash(self) -> str:
                 self.file_id,
                 self.title,
                 str(self.s3_filename),
+                str(self.s3_filename_suffix),
                 self.status.name,
             ]
         )
@@ -119,6 +122,10 @@ def s3_access_key_id(self) -> str:
     def s3_secret_access_key(self) -> str:
         return self._connector_settings.SECRET_ACCESS_KEY
 
+    def get_full_s3_filename(self, s3_filename_suffix: str) -> str:
+        assert self.uuid and self.raw_tenant_id
+        return "_".join((self.raw_tenant_id, self.uuid, s3_filename_suffix))
+
     def get_conn_dto(self) -> BaseFileS3ConnDTO: # type: ignore
         cs = self._connector_settings
         conn_dto = BaseFileS3ConnDTO(
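get_full_s3_filename composes the object key as <raw_tenant_id>_<connection uuid>_<suffix>; the assert presumably guards against building a key for a connection that has not been saved yet. A quick illustration with made-up values:

    raw_tenant_id = "org123"   # hypothetical tenant id
    conn_uuid = "a1b2c3d4"     # hypothetical connection uuid
    suffix = "NqYzR2tWvKj5"    # random per-source suffix

    full_key = "_".join((raw_tenant_id, conn_uuid, suffix))
    assert full_key == "org123_a1b2c3d4_NqYzR2tWvKj5"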
@@ -51,6 +51,7 @@ def __hash__(self) -> int:
                 self.file_id,
                 self.title,
                 self.s3_filename,
+                self.s3_filename_suffix,
                 raw_schema,
                 self.status,
                 self.sheet_id,
@@ -117,6 +118,7 @@ def restore_source_params_from_orig(self, src_id: str, original_version: BaseFil
             raw_schema=orig_src.raw_schema,
             file_id=orig_src.file_id,
             s3_filename=orig_src.s3_filename,
+            s3_filename_suffix=orig_src.s3_filename_suffix,
             status=orig_src.status,
             preview_id=orig_src.preview_id,
             first_line_is_header=orig_src.first_line_is_header,
@@ -40,6 +40,7 @@ def restore_source_params_from_orig(self, src_id: str, original_version: BaseFil
             raw_schema=orig_src.raw_schema,
             file_id=orig_src.file_id,
             s3_filename=orig_src.s3_filename,
+            s3_filename_suffix=orig_src.s3_filename_suffix,
             status=orig_src.status,
             preview_id=orig_src.preview_id,
         )
@@ -106,7 +106,7 @@ def saved_connection_id(
     for src in conn.data.sources:
         src.status = sample_file_data_source.status
         src.raw_schema = sample_file_data_source.raw_schema
-        src.s3_filename = sample_file_data_source.s3_filename
+        src.s3_filename_suffix = sample_file_data_source.s3_filename_suffix
     sync_us_manager.save(conn)
     yield conn_id
 
@@ -65,7 +65,7 @@ def test_build_from_clause(
         expected = (
             f"s3("
             f"'{self.connection_settings.S3_ENDPOINT}/{self.connection_settings.BUCKET}/"
-            f"{sample_file_data_source.s3_filename}', "
+            f"{sample_file_data_source.s3_filename_suffix}', "
            f"'key_id_{replace_secret}', 'secret_key_{replace_secret}', 'Native', "
            f"'c1 Nullable(Int64)')"
        )
@@ -1,7 +1,23 @@
+from typing import (
+    Any,
+    Optional,
+)
+
+import pytest
+
 from dl_api_lib_testing.initialization import initialize_api_lib_test
 
+from dl_connector_bundle_chs3.chs3_base.core.us_connection import BaseFileS3Connection
 from dl_connector_bundle_chs3_tests.db.config import API_TEST_CONFIG
 
 
 def pytest_configure(config): # noqa
     initialize_api_lib_test(pytest_config=config, api_test_config=API_TEST_CONFIG)
+
+
+@pytest.fixture(autouse=True)
+def patch_get_full_s3_filename(monkeypatch: pytest.MonkeyPatch) -> None:
+    def _patched(self: Any, s3_filename_suffix: str) -> str: # type: ignore
+        return s3_filename_suffix
+
+    monkeypatch.setattr(BaseFileS3Connection, "get_full_s3_filename", _patched)
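With this autouse fixture, the API-lib tests treat the bare suffix as the full S3 key, so fixtures and expected SQL (such as the s3(...) clause above) stay independent of tenant id and connection uuid. The effect, for any BaseFileS3Connection instance conn while the patch is active (a sketch with a made-up suffix, not a real test):

    assert conn.get_full_s3_filename("sample.native") == "sample.native"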
@@ -30,7 +30,7 @@ def sample_file_data_source(
         id=str(uuid.uuid4()),
         file_id=str(uuid.uuid4()),
         title=sample_s3_file,
-        s3_filename=sample_s3_file,
+        s3_filename_suffix=sample_s3_file,
         raw_schema=raw_schema,
         status=FileProcessingStatus.ready,
         column_types=[{"name": col[0], "user_type": col[1].name} for col in sample_table_spec.table_schema],
@@ -109,6 +109,7 @@ def test_component_error(
         conn.data.sources[0].id,
         role=DataSourceRole.origin,
         s3_filename=None,
+        s3_filename_suffix=None,
         status=FileProcessingStatus.failed,
         preview_id=None,
         data_updated_at=datetime.datetime.now(datetime.timezone.utc),
@@ -30,7 +30,7 @@ def sample_file_data_source(
         id=str(uuid.uuid4()),
         file_id=str(uuid.uuid4()),
         title=sample_s3_file,
-        s3_filename=sample_s3_file,
+        s3_filename_suffix=sample_s3_file,
         raw_schema=raw_schema,
         status=FileProcessingStatus.ready,
     )
@@ -23,7 +23,7 @@
 )
 from dl_file_uploader_task_interface.context import FileUploaderTaskContext
 import dl_file_uploader_task_interface.tasks as task_interface
-from dl_file_uploader_worker_lib.tasks.save import _make_source_s3_filename
+from dl_file_uploader_worker_lib.tasks.save import make_source_s3_filename_suffix
 from dl_task_processor.task import (
     BaseExecutorTask,
     Retry,
@@ -267,11 +267,12 @@ async def run(self) -> TaskResult:
                 continue
 
             old_fname_parts = old_s3_filename.split("_")
-            if len(old_fname_parts) == 2 and old_fname_parts[0] and old_fname_parts[1]:
+            if len(old_fname_parts) >= 2 and all(part for part in old_fname_parts):
                 # assume that first part is old tenant id
                 old_tenants.add(old_fname_parts[0])
 
-            new_s3_filename = _make_source_s3_filename(tenant_id)
+            s3_filename_suffix = make_source_s3_filename_suffix()
+            new_s3_filename = conn.get_full_s3_filename(s3_filename_suffix)
             await s3_client.copy_object(
                 CopySource=dict(
                     Bucket=s3_service.persistent_bucket_name,
@@ -291,6 +292,7 @@
                 source.id,
                 role=DataSourceRole.origin,
                 s3_filename=new_s3_filename,
+                s3_filename_suffix=s3_filename_suffix,
             )
 
         if conn_changed:
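The relaxed length check lets the tenant-rename task recognize both legacy two-part keys and new three-part keys; in either shape the first part is taken as the old tenant id. A sketch with made-up keys, assuming tenant ids contain no underscores:

    legacy_key = "org123_NqYzR2tWvKj5"        # <tenant>_<suffix>
    new_key = "org123_a1b2c3d4_NqYzR2tWvKj5"  # <tenant>_<conn uuid>_<suffix>

    for key in (legacy_key, new_key):
        parts = key.split("_")
        assert len(parts) >= 2 and all(parts)
        assert parts[0] == "org123"           # first part read as the old tenant id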
@@ -54,8 +54,8 @@
 LOGGER = logging.getLogger(__name__)
 
 
-def _make_source_s3_filename(tenant_id: str) -> str:
-    return f"{tenant_id}_{shortuuid.uuid()}"
+def make_source_s3_filename_suffix() -> str:
+    return str(shortuuid.uuid())
 
 
 def _make_tmp_source_s3_filename(source_id: str) -> str:
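With the renamed helper, only a random suffix is generated at save time; the tenant id and connection uuid are contributed by the connection on access via get_full_s3_filename. A before/after contrast (tenant and uuid values are made up):

    import shortuuid

    # Before: the tenant id was baked into the stored filename at save time.
    old_stored = f"org123_{shortuuid.uuid()}"

    # After: the stored value is tenant-agnostic; the full key is derived on access.
    new_stored = str(shortuuid.uuid())
    derived_key = "_".join(("org123", "a1b2c3d4", new_stored))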
@@ -185,7 +185,9 @@ async def run(self) -> TaskResult:
                 raw_schema_override if raw_schema_override is not None else conn_raw_schema,
             )
 
-            new_s3_filename = _make_source_s3_filename(tenant_id=self.meta.tenant_id)
+            s3_filename_suffix = make_source_s3_filename_suffix()
+            new_s3_filename = conn.get_full_s3_filename(s3_filename_suffix)
+            assert new_s3_filename
 
             def _construct_insert_from_select_query(for_debug: bool = False) -> str:
                 src_sql = make_s3_table_func_sql_source(
@@ -262,6 +264,7 @@ def _construct_insert_from_select_query(for_debug: bool = False) -> str:
             dst_source_id,
             role=DataSourceRole.origin,
             s3_filename=new_s3_filename,
+            s3_filename_suffix=s3_filename_suffix,
             status=FileProcessingStatus.ready,
             preview_id=preview.id,
             **extra_dsrc_params,
@@ -39,6 +39,8 @@
 from dl_file_uploader_worker_lib.utils import parsing_utils
 from dl_task_processor.state import wait_task
 
+from dl_connector_bundle_chs3.chs3_base.core.us_connection import BaseFileS3Connection
+
 from .utils import create_file_connection
 
 
@@ -617,8 +619,10 @@ async def test_rename_tenant_files(
     usm = default_async_usm_per_test
     conn = await usm.get_by_id(saved_file_connection_id)
     source = conn.data.sources[0]
-    assert source.s3_filename
-    s3_obj = await s3_client.get_object(Bucket=s3_persistent_bucket, Key=source.s3_filename)
+    assert source.s3_filename_suffix
+    s3_obj = await s3_client.get_object(
+        Bucket=s3_persistent_bucket, Key=conn.get_full_s3_filename(source.s3_filename_suffix)
+    )
     s3_data = await s3_obj["Body"].read()
     preview_set = PreviewSet(redis=redis_model_manager._redis, id=conn.raw_tenant_id)
     ps_vals = {_ async for _ in preview_set.sscan_iter()}
@@ -630,10 +634,13 @@
     assert result[-1] == "success"
 
     updated_conn = await usm.get_by_id(saved_file_connection_id)
+    assert isinstance(conn, BaseFileS3Connection)
     updated_source = updated_conn.get_file_source_by_id(source.id)
-    assert updated_source.s3_filename
-    assert updated_source.s3_filename.startswith(new_tenant_id)
-    updated_s3_obj = await s3_client.get_object(Bucket=s3_persistent_bucket, Key=updated_source.s3_filename)
+    assert updated_source.s3_filename_suffix
+    assert updated_source.s3_filename_suffix != source.s3_filename_suffix
+    assert updated_source.s3_filename and updated_source.s3_filename.startswith(updated_conn.raw_tenant_id)
+    new_s3_filename = updated_conn.get_full_s3_filename(updated_source.s3_filename_suffix)
+    updated_s3_obj = await s3_client.get_object(Bucket=s3_persistent_bucket, Key=new_s3_filename)
     updated_s3_obj_data = await updated_s3_obj["Body"].read()
     assert s3_data == updated_s3_obj_data
 