feat(connectors): BI-5975 Connection export and import (#736)
* feat(connectors): BI-5975 Connection export

* Fix

* Fix

* Add connection import

* Lint fix

* Add import test

* Fix ch readonly default

* Fix readonly

* Fix

* Delete id from export response

* Review fixes

* Review fixes

* Review fixes

* Attempt to fix file-connectors import

* Fix file-connectors import
vallbull authored Feb 19, 2025
1 parent 4876f11 commit c89ac6a
Showing 25 changed files with 296 additions and 15 deletions.
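The change adds two control-API endpoints: GET /api/v1/connections/export/<connection_id> and POST /api/v1/connections/import. A rough client-side sketch of how they fit together (the base URL, auth header, and connection id below are placeholders, not part of this commit):

import requests

BASE_URL = "https://datalens.example.com"      # placeholder host, assumption
HEADERS = {"Authorization": "Bearer <token>"}  # auth depends on the deployment, assumption
connection_id = "some-connection-id"           # placeholder

# Export: secret fields are masked with "******" and "id" is stripped from the dump.
export_resp = requests.get(
    f"{BASE_URL}/api/v1/connections/export/{connection_id}",
    headers=HEADERS,
).json()

# Import: re-create a connection from the exported data; "workbook_id" may also be passed in "data".
import_resp = requests.post(
    f"{BASE_URL}/api/v1/connections/import",
    headers=HEADERS,
    json={"data": {"connection": export_resp["connection"]}},
).json()
print(import_resp["id"], import_resp["notifications"])
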
@@ -47,7 +47,7 @@ def secret_string_field(
     required: bool = True,
     allow_none: bool = False,
     default: Optional[str] = None,
-    bi_extra: FieldExtra = FieldExtra(editable=True),  # noqa: B008
+    bi_extra: FieldExtra = FieldExtra(editable=True, export_fake=True),  # noqa: B008
 ) -> ma_fields.String:
     return ma_fields.String(
         attribute=attribute,
@@ -75,7 +75,7 @@ class ClassicSQLConnectionSchema(ConnectionSchema):
     host = DBHostField(attribute="data.host", required=True, bi_extra=FieldExtra(editable=True))
     port = ma_fields.Integer(attribute="data.port", required=True, bi_extra=FieldExtra(editable=True))
     username = ma_fields.String(attribute="data.username", required=True, bi_extra=FieldExtra(editable=True))
-    password = secret_string_field(attribute="data.password", bi_extra=FieldExtra(editable=True))
+    password = secret_string_field(attribute="data.password")
     db_name = ma_fields.String(
         attribute="data.db_name", allow_none=True, bi_extra=FieldExtra(editable=True), validate=db_name_no_query_params
     )
9 changes: 9 additions & 0 deletions lib/dl_api_connector/dl_api_connector/api_schema/extras.py
@@ -20,11 +20,19 @@ class CreateMode(OperationsMode):
     test = enum.auto()


+class ImportMode(OperationsMode):
+    create_from_import = enum.auto()
+
+
 class EditMode(OperationsMode):
     edit = enum.auto()
     test = enum.auto()


+class ExportMode(OperationsMode):
+    export = enum.auto()
+
+
 class SchemaKWArgs(TypedDict):
     only: Optional[Sequence[str]]
     partial: Union[Sequence[str], bool]
@@ -38,3 +46,4 @@ class FieldExtra:
     partial_in: Sequence[OperationsMode] = ()
     exclude_in: Sequence[OperationsMode] = ()
     editable: Union[bool, Sequence[OperationsMode]] = ()
+    export_fake: Optional[bool] = False
42 changes: 40 additions & 2 deletions lib/dl_api_connector/dl_api_connector/api_schema/top_level.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import copy
+from copy import deepcopy
 import itertools
 import logging
 import os
@@ -19,6 +20,7 @@
 import marshmallow
 from marshmallow import (
     missing,
+    post_dump,
     post_load,
     pre_load,
 )
@@ -28,7 +30,9 @@
 from dl_api_connector.api_schema.extras import (
     CreateMode,
     EditMode,
+    ExportMode,
     FieldExtra,
+    ImportMode,
     OperationsMode,
     SchemaKWArgs,
 )
@@ -84,7 +88,7 @@ def get_field_extra(f: ma_fields.Field) -> Optional[FieldExtra]:
         return f.metadata.get("bi_extra", None)

     @property
-    def operations_mode(self) -> Optional[CreateMode]:
+    def operations_mode(self) -> Optional[CreateMode | ImportMode]:
         return self.context.get(self.CTX_KEY_OPERATIONS_MODE)

     @classmethod
@@ -98,6 +102,13 @@ def all_fields_with_extra_info(cls) -> Iterable[tuple[str, ma_fields.Field, FieldExtra]]:
             if extra is not None:
                 yield field_name, field, extra

+    @classmethod
+    def fieldnames_with_extra_export_fake_info(cls) -> Iterable[str]:
+        for field_name, field in cls.all_fields_dict().items():
+            extra = cls.get_field_extra(field)
+            if extra is not None and extra.export_fake is True:
+                yield field_name
+
     def _refine_init_kwargs(self, kw_args: SchemaKWArgs, operations_mode: Optional[OperationsMode]) -> SchemaKWArgs:
         if operations_mode is None:
             return kw_args
@@ -159,7 +170,7 @@ def post_load(self, data: dict[str, Any], **_: Any) -> Union[_TARGET_OBJECT_TV, dict[str, Any]]:
                 return data
             assert isinstance(editable_object, self.TARGET_CLS)
             return self.update_object(editable_object, data)
-        if isinstance(self.operations_mode, CreateMode):
+        if isinstance(self.operations_mode, CreateMode | ImportMode):
             return self.create_object(data)
         raise ValueError(f"Can not perform load. Unknown operations mode: {self.operations_mode!r}")
@@ -171,6 +182,21 @@ def get_allowed_unknown_fields(self) -> set[str]:
         """
         return set()

+    @final
+    def delete_unknown_fields(self, data: dict[str, Any]) -> dict[str, Any]:
+        LOGGER.info(
+            "Got unknown fields for schema %s/%s. Unknown fields will be removed.",
+            type(self).__qualname__,
+            self.operations_mode,
+        )
+
+        cleaned_data = {}
+        for field_name, field_value in data.items():
+            if field_name in self.fields and not self.fields[field_name].dump_only:
+                cleaned_data[field_name] = field_value
+
+        return cleaned_data
+
     @final
     def handle_unknown_fields(self, data: dict[str, Any]) -> dict[str, Any]:
         refined_data = {}
@@ -230,8 +256,20 @@ def pre_load(self, data: dict[str, Any], **_: Any) -> dict[str, Any]:
                 schema_input_keys=all_data_keys,
             ),
         )
+
+        if isinstance(self.operations_mode, ImportMode):
+            return self.delete_unknown_fields(data)
+
         return self.handle_unknown_fields(data)

+    @post_dump(pass_many=False)
+    def post_dump(self, data: dict[str, Any], **_: Any) -> dict[str, Any]:
+        if isinstance(self.operations_mode, ExportMode):
+            data = deepcopy(data)
+            for secret_field in self.fieldnames_with_extra_export_fake_info():
+                data[secret_field] = "******"
+        return data
+

 _US_ENTRY_TV = TypeVar("_US_ENTRY_TV", bound=USEntry)

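The export masking above boils down to a marshmallow post_dump hook that rewrites every field flagged via FieldExtra(export_fake=True). A standalone toy sketch of the same idea (ToyConnectionSchema and its fields are invented for illustration and are not part of the repo):

from marshmallow import Schema, fields, post_dump


class ToyConnectionSchema(Schema):
    host = fields.String()
    password = fields.String(metadata={"export_fake": True})  # flag mirrors FieldExtra(export_fake=True)

    @post_dump(pass_many=False)
    def mask_secrets(self, data, **_):
        # Replace every flagged field with a placeholder, like the export dump above.
        for name, field in self.fields.items():
            if field.metadata.get("export_fake"):
                data[name] = "******"
        return data


print(ToyConnectionSchema().dump({"host": "db.example.com", "password": "hunter2"}))
# {'host': 'db.example.com', 'password': '******'}
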
16 changes: 16 additions & 0 deletions lib/dl_api_lib/dl_api_lib/app/control_api/app.py
@@ -6,6 +6,7 @@
     Generic,
     Optional,
     TypeVar,
+    final,
 )

 import attr
@@ -49,6 +50,12 @@
 from dl_core.connection_models import ConnectOptions
 from dl_core.us_connection_base import ConnectionBase

+from dl_api_lib.app.control_api.resources.connections import (
+    BIResource,
+    ConnectionExportItem,
+)
+from dl_api_lib.app.control_api.resources.connections import ns as connections_namespace
+

 @attr.s(frozen=True)
 class EnvSetupResult:
@@ -62,6 +69,14 @@ class EnvSetupResult:
 class ControlApiAppFactory(SRFactoryBuilder, Generic[TControlApiAppSettings], abc.ABC):
     _settings: TControlApiAppSettings = attr.ib()

+    def get_connection_export_resource(self) -> type[BIResource]:
+        return ConnectionExportItem
+
+    @final
+    def register_additional_handlers(self) -> None:
+        connection_export_resource = self.get_connection_export_resource()
+        connections_namespace.add_resource(connection_export_resource, "/export/<connection_id>")
+
     @abc.abstractmethod
     def set_up_environment(
         self,
@@ -159,6 +174,7 @@ def create_app(
         ma = Marshmallow()
         ma.init_app(app)

+        app.before_first_request(self.register_additional_handlers)
         init_apis(app)

         return app
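register_additional_handlers attaches the export resource to the existing connections namespace before the first request, and get_connection_export_resource lets app factories substitute their own resource class. A minimal sketch of that registration pattern, assuming a flask_restx-style namespace (the exact REST framework behind BIResource/ns is not shown in this diff):

from flask import Flask
from flask_restx import Api, Namespace, Resource

ns = Namespace("connections", path="/api/v1/connections")


class ConnectionExportStub(Resource):
    # Stand-in for ConnectionExportItem from the diff below.
    def get(self, connection_id: str):
        return {"connection": {}, "notifications": []}


# Same shape as register_additional_handlers: attach the resource under a new route.
ns.add_resource(ConnectionExportStub, "/export/<connection_id>")

app = Flask(__name__)
api = Api(app)
api.add_namespace(ns)
# GET /api/v1/connections/export/<some-id> now dispatches to ConnectionExportStub.get
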
71 changes: 71 additions & 0 deletions lib/dl_api_lib/dl_api_lib/app/control_api/resources/connections.py
@@ -17,6 +17,8 @@
 from dl_api_connector.api_schema.extras import (
     CreateMode,
     EditMode,
+    ExportMode,
+    ImportMode,
 )
 from dl_api_lib import exc
 from dl_api_lib.api_decorators import schematic_request
@@ -125,6 +127,48 @@ def post(self, connection_id):  # type: ignore  # TODO: fix
             _handle_conn_test_exc(e)


+@ns.route("/import")
+class ConnectionsImportList(BIResource):
+    @put_to_request_context(endpoint_code="ConnectionImport")
+    @schematic_request(ns=ns)
+    def post(self):  # type: ignore  # TODO: fix
+        us_manager = self.get_us_manager()
+        notifications = []
+
+        conn_data = request.json and request.json["data"]["connection"]
+        assert conn_data
+
+        conn_type = conn_data["db_type"]
+        if not conn_type or conn_type not in ConnectionType:
+            raise exc.BadConnectionType(f"Invalid connection type value: {conn_type}")
+
+        conn_availability = self.get_service_registry().get_connector_availability()
+        conn_type_is_available = conn_availability.check_connector_is_available(ConnectionType[conn_type])
+        if not conn_type_is_available:
+            raise exc.UnsupportedForEntityType("Connector %s is not available in current env", conn_type)
+
+        conn_data["workbook_id"] = request.json and request.json["data"].get("workbook_id", None)
+        conn_data["type"] = conn_type
+
+        schema = GenericConnectionSchema(
+            context=self.get_schema_ctx(schema_operations_mode=ImportMode.create_from_import)
+        )
+        try:
+            conn: ConnectionBase = schema.load(conn_data)
+        except MValidationError as e:
+            return e.messages, 400
+
+        conn.validate_new_data_sync(services_registry=self.get_service_registry())
+
+        conn_warnings = conn.get_import_warnings_list()
+        if conn_warnings:
+            notifications.extend(conn_warnings)
+
+        us_manager.save(conn)
+
+        return dict(id=conn.uuid, notifications=notifications)
+
+
 @ns.route("/")
 class ConnectionsList(BIResource):
     @put_to_request_context(endpoint_code="ConnectionCreate")
@@ -211,6 +255,33 @@ def put(self, connection_id):  # type: ignore  # TODO: fix
         us_manager.save(conn)


+class ConnectionExportItem(BIResource):
+    @put_to_request_context(endpoint_code="ConnectionExport")
+    @schematic_request(
+        ns=ns,
+        responses={},
+    )
+    def get(self, connection_id: str) -> dict:
+        notifications: list[dict] = []
+
+        conn = self.get_us_manager().get_by_id(connection_id, expected_type=ConnectionBase)
+        need_permission_on_entry(conn, USPermissionKind.read)
+        assert isinstance(conn, ConnectionBase)
+
+        if not conn.allow_export:
+            raise exc.UnsupportedForEntityType(f"Connector {conn.conn_type.name} does not support export")
+
+        result = GenericConnectionSchema(context=self.get_schema_ctx(ExportMode.export)).dump(conn)
+        result.update(options=ConnectionOptionsSchema().dump(conn.get_options()))
+        result.pop("id")
+
+        conn_warnings = conn.get_export_warnings_list()
+        if conn_warnings:
+            notifications.extend(conn_warnings)
+
+        return dict(connection=result, notifications=notifications)
+
+
 def _dump_source_templates(tpls: Optional[list[DataSourceTemplate]]) -> Optional[list[dict[str, Any]]]:
     if tpls is None:
         return None
@@ -6,6 +6,8 @@

 from dl_api_client.dsmaker.api.http_sync_base import SyncHttpClientBase
 from dl_api_lib_testing.connection_base import ConnectionTestBase
+from dl_core.us_connection_base import ConnectionBase
+from dl_core.us_manager.us_manager_sync import SyncUSManager
 from dl_testing.regulated_test import RegulatedTestCase


@@ -23,6 +25,76 @@ def test_create_connection(
         )
         assert resp.status_code == 200, resp.json

+    def test_export_connection(
+        self,
+        control_api_sync_client: SyncHttpClientBase,
+        saved_connection_id: str,
+        bi_headers: Optional[dict[str, str]],
+        sync_us_manager: SyncUSManager,
+    ) -> None:
+        conn = sync_us_manager.get_by_id(saved_connection_id, expected_type=ConnectionBase)
+        assert isinstance(conn, ConnectionBase)
+
+        resp = control_api_sync_client.get(
+            url=f"/api/v1/connections/export/{saved_connection_id}",
+            headers=bi_headers,
+        )
+
+        if not conn.allow_export:
+            assert resp.status_code == 400
+            return
+
+        assert resp.status_code == 200, resp.json
+        if hasattr(conn.data, "password"):
+            password = resp.json["connection"]["password"]
+            assert password == "******"
+
+    def test_import_connection(
+        self,
+        control_api_sync_client: SyncHttpClientBase,
+        saved_connection_id: str,
+        bi_headers: Optional[dict[str, str]],
+        sync_us_manager: SyncUSManager,
+    ) -> None:
+        conn = sync_us_manager.get_by_id(saved_connection_id, expected_type=ConnectionBase)
+        assert isinstance(conn, ConnectionBase)
+        if not conn.allow_export:
+            return
+
+        export_resp = control_api_sync_client.get(
+            url=f"/api/v1/connections/export/{saved_connection_id}",
+            headers=bi_headers,
+        )
+
+        export_resp.json["connection"][
+            "name"
+        ] = f"{self.conn_type.name} conn {uuid.uuid4()}"  # in case of response with workbook, 'name'-field is in export response by default
+
+        import_request = json.dumps(
+            {
+                "data": {
+                    # "workbook_id" : "1234567890000",  # can't test with workbook in case of ERR.DS_API.US.OBJ_NOT_FOUND
+                    "connection": export_resp.json["connection"]
+                }
+            }
+        )
+
+        import_response = control_api_sync_client.post(
+            url="/api/v1/connections/import",
+            headers=bi_headers,
+            data=import_request,
+            content_type="application/json",
+        )
+        assert import_response.status_code == 200, import_response.json
+        assert import_response.json["id"]
+        assert import_response.json["id"] != saved_connection_id
+        assert import_response.json["notifications"]
+
+        export_resp = control_api_sync_client.delete(
+            url=f"/api/v1/connections/{import_response.json['id']}",
+            headers=bi_headers,
+        )
+
     def test_test_connection(
         self,
         control_api_sync_client: SyncHttpClientBase,
@@ -35,6 +35,7 @@ class BitrixGDSConnectOptions(ConnectOptions):
 class BitrixGDSConnection(ConnectionBase):
     allow_cache: ClassVar[bool] = True
     source_type = SOURCE_TYPE_BITRIX_GDS
+    allow_export: ClassVar[bool] = True

     @attr.s(kw_only=True)
     class DataModel(ConnCacheableDataModelMixin, ConnectionBase.DataModel):
(diff for the remaining changed files not shown)
