Skip to content

Commit

Permalink
feat(general): BI-5779 Checking connection can be used in public hand…
Browse files Browse the repository at this point in the history
…ler (#619)

* feat(general): BI-5779 Checking connection can be used in public handler

* feat(general): BI-5779 Checking connection can be used in public handler

* Fix TODO
  • Loading branch information
vallbull authored Sep 26, 2024
1 parent 278ed31 commit e8bae23
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 5 deletions.
66 changes: 61 additions & 5 deletions lib/dl_api_lib/dl_api_lib/app/control_api/resources/info.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import typing

from marshmallow import Schema
from marshmallow import fields as ma_fields

Expand All @@ -20,6 +22,7 @@
UserDataType,
)
from dl_core.exc import EntityUsageNotAllowed
from dl_core.us_connection_base import ConnectionBase
from dl_core.us_dataset import Dataset


Expand All @@ -32,19 +35,33 @@ class GetFieldTypeCollectionResponseSchema(Schema):
types = ma_fields.Nested(FieldTypeInfoSchema, many=True)


class BasePublicityCheckerResponseSchema(Schema):
allowed = ma_fields.Boolean()
reason = ma_fields.String()


class DatasetsPublicityCheckerRequestSchema(Schema):
datasets = ma_fields.List(ma_fields.String())


class DatasetsPublicityCheckerResponseSchema(Schema):
class DatasetResponseSchema(Schema):
class DatasetResponseSchema(BasePublicityCheckerResponseSchema):
dataset_id = ma_fields.String()
allowed = ma_fields.Boolean()
reason = ma_fields.String()

result = ma_fields.Nested(DatasetResponseSchema, many=True)


class ConnectionsPublicityCheckerRequestSchema(Schema):
connections = ma_fields.List(ma_fields.String())


class ConnectionsPublicityCheckerResponseSchema(Schema):
class ConnectionResponseSchema(BasePublicityCheckerResponseSchema):
connection_id = ma_fields.String()

result = ma_fields.Nested(ConnectionResponseSchema, many=True)


ns = API.namespace("Info", path="/info")


Expand Down Expand Up @@ -75,10 +92,11 @@ class DatasetsPublicityChecker(BIResource):
body=DatasetsPublicityCheckerRequestSchema(),
responses={200: ("Success", DatasetsPublicityCheckerResponseSchema())},
)
def post(self, body): # type: ignore # TODO: fix
def post(self, body: typing.Mapping[str, typing.Any]) -> typing.Mapping[str, typing.Any]:
ds_ids = body["datasets"]
responses = []
us_manager = self.get_us_manager()
reason = None

public_usage_checker = PublicEnvEntityUsageChecker()

Expand All @@ -97,7 +115,6 @@ def post(self, body): # type: ignore # TODO: fix
reason = exc.message
else:
allowed = True
reason = None # type: ignore # TODO: fix

responses.append(
{
Expand All @@ -110,6 +127,45 @@ def post(self, body): # type: ignore # TODO: fix
return {"result": responses}


@ns.route("/connections_publicity_checker")
class ConnectionsPublicityChecker(BIResource):
@schematic_request(
ns=ns,
body=ConnectionsPublicityCheckerRequestSchema(),
responses={200: ("Success", ConnectionsPublicityCheckerResponseSchema())},
)
def post(self, body: typing.Mapping[str, typing.Any]) -> typing.Mapping[str, typing.Any]:
conn_ids = body["connections"]
responses = []
reason = None

public_usage_checker = PublicEnvEntityUsageChecker()

for conn in self.get_us_manager().get_collection(
ConnectionBase, raise_on_broken_entry=True, include_data=True, ids=conn_ids
):
try:
public_usage_checker.ensure_data_connection_can_be_used(
rci=self.get_current_rci(),
conn=conn,
)
except EntityUsageNotAllowed as exc:
allowed = False
reason = exc.message
else:
allowed = True

responses.append(
{
"connection_id": conn.uuid,
"allowed": allowed,
"reason": reason,
}
)

return {"result": responses}


@ns.route("/connectors")
class AvailableConnectorsCollection(BIResource):
def get(self) -> dict:
Expand Down
35 changes: 35 additions & 0 deletions lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import json

import pytest

Expand Down Expand Up @@ -118,3 +119,37 @@ def test_get_connector_icon(self, client, conn_type_name):
def test_get_connector_icon_not_found(self, client):
icons_resp = client.get("/api/v1/info/connectors/icons/unknown_conn_type")
assert icons_resp.status_code == 404, icons_resp.json

def test_public_usage_checker(self, client, saved_dataset, saved_connection_id):
data = dict(datasets=[saved_dataset.id])
response = client.post(
"/api/v1/info/datasets_publicity_checker",
content_type="application/json",
data=json.dumps(data),
)
expected_resp = [
dict(
reason=None,
dataset_id=saved_dataset.id,
allowed=True,
)
]

assert response.status_code == 200
assert response.json["result"] == expected_resp

data = dict(connections=[saved_connection_id])
response = client.post(
"/api/v1/info/connections_publicity_checker",
content_type="application/json",
data=json.dumps(data),
)
expected_resp = [
dict(
reason=None,
connection_id=saved_connection_id,
allowed=True,
)
]
assert response.status_code == 200
assert response.json["result"] == expected_resp
3 changes: 3 additions & 0 deletions lib/dl_utils/dl_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class AddressableData:
data: dict[str, Any] = attr.ib()

def contains(self, key: DataKey) -> bool:
if not self.data:
return False

try:
self.get(key)
return True
Expand Down

0 comments on commit e8bae23

Please sign in to comment.