Skip to content

Commit

Permalink
Use root ca explicitly in http-based adapters (#262)
Browse files Browse the repository at this point in the history
* Use root ca explicitly in http-based adapters

* Check cert's file encoding

* Add typing

* Fix imports
  • Loading branch information
thenno authored Feb 2, 2024
1 parent b1e6bdb commit 4fc6f55
Show file tree
Hide file tree
Showing 26 changed files with 99 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __attrs_post_init__(self): # type: ignore # 2024-01-24 # TODO: Function is
self._session = self._make_session()

def _make_session(self) -> aiohttp.ClientSession:
ssl_context = ssl.create_default_context(cadata=self._ca_data.decode("utf-8"))
ssl_context = ssl.create_default_context(cadata=self._ca_data.decode("ascii"))
return aiohttp.ClientSession(
cookies=self.cookies,
headers=self.headers,
Expand Down
2 changes: 1 addition & 1 deletion lib/dl_api_lib_testing/dl_api_lib_testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def control_api_app_settings(
rqe_config_subprocess=rqe_config_subprocess,
)

@pytest.fixture(scope="session")
def ca_data(self) -> bytes:
    """Root CA certificate bundle, read once per test session.

    The certificates on disk do not change while the tests run, so
    session scope avoids re-reading the file for every test function.
    """
    return get_root_certificates()

Expand Down
19 changes: 18 additions & 1 deletion lib/dl_configs/dl_configs/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from typing import (
Callable,
Expand All @@ -12,6 +13,9 @@
TEMP_ROOT_CERTIFICATES_FOLDER_PATH = "/tmp/ssl/certs/"


LOGGER = logging.getLogger(__name__)


_T = TypeVar("_T")


Expand All @@ -25,8 +29,21 @@ def validator(value: _T) -> _T:


def get_root_certificates(path: str = DEFAULT_ROOT_CERTIFICATES_FILENAME) -> bytes:
    """
    Read root CA certificates from a PEM file and return them as raw bytes.

    Expects ``path`` to point to a file with PEM certificates.
    aiohttp-based clients expect the certificates as an ascii string to create
    an ``ssl.SSLContext``, while grpc clients expect them as a byte
    representation of an ascii string, so the raw bytes are returned here and
    decoded by the callers that need text.

    :param path: path to the PEM bundle on disk.
    :return: file content as bytes.
    :raises UnicodeDecodeError: if the content is not ascii-decodable,
        i.e. most likely not PEM at all (fail fast instead of failing later
        inside an SSL context constructor).
    """
    with open(path, "rb") as fobj:
        ca_data = fobj.read()
    # Fail fast: PEM is plain ascii, so a decode failure means the wrong
    # (e.g. DER/binary) file was supplied.
    try:
        ca_data.decode("ascii")
    except UnicodeDecodeError:
        LOGGER.exception("Looks like the certificates are not in PEM format")
        raise
    return ca_data


def get_root_certificates_path() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,6 @@ async def _make_target_conn_dto_pool(self) -> Sequence[BitrixGDSConnTargetDTO]:
connect_timeout=self._conn_options.connect_timeout, # type: ignore # TODO: fix
redis_conn_params=conn_params,
redis_caches_ttl=caches_ttl,
ca_data=self._ca_data.decode("ascii"),
)
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import attr

from dl_core.connection_executors.models.connection_target_dto_base import ConnTargetDTO
from dl_core.connection_executors.models.connection_target_dto_base import BaseAiohttpConnTargetDTO
from dl_core.utils import secrepr


Expand All @@ -15,7 +15,7 @@ def hide_pass(value: Optional[dict]) -> str:


@attr.s(frozen=True)
class BitrixGDSConnTargetDTO(ConnTargetDTO):
class BitrixGDSConnTargetDTO(BaseAiohttpConnTargetDTO):
portal: str = attr.ib(kw_only=True)
token: str = attr.ib(kw_only=True, repr=secrepr)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def _make_target_conn_dto_pool(self) -> list[BaseFileS3ConnTargetDTO]: #
access_key_id=self._conn_dto.access_key_id,
secret_access_key=self._conn_dto.secret_access_key,
replace_secret=self._conn_dto.replace_secret,
ca_data=self._ca_data.decode("ascii"),
)
)
return dto_pool
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import attr

from dl_core.connection_executors.models.connection_target_dto_base import BaseSQLConnTargetDTO
from dl_core.connection_executors.models.connection_target_dto_base import (
BaseAiohttpConnTargetDTO,
BaseSQLConnTargetDTO,
)
from dl_core.utils import secrepr


@attr.s
class BaseFileS3ConnTargetDTO(BaseSQLConnTargetDTO):
class BaseFileS3ConnTargetDTO(BaseAiohttpConnTargetDTO, BaseSQLConnTargetDTO):
protocol: str = attr.ib(kw_only=True)
disable_value_processing: bool = attr.ib(kw_only=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,25 +107,29 @@ def task_processor_factory(self) -> TaskProcessorFactory:
@pytest.fixture(scope="session")
def conn_sync_service_registry(
    self,
    root_certificates: bytes,
    conn_bi_context: RequestContextInfo,
    task_processor_factory: TaskProcessorFactory,
) -> ServicesRegistry:
    """Session-scoped services registry using the synchronous connection executor factory."""
    registry_kwargs = dict(
        conn_exec_factory_async_env=False,
        conn_bi_context=conn_bi_context,
        task_processor_factory=task_processor_factory,
        root_certificates_data=root_certificates,
    )
    return self.service_registry_factory(**registry_kwargs)

@pytest.fixture(scope="session")
def conn_async_service_registry(
    self,
    root_certificates: bytes,
    conn_bi_context: RequestContextInfo,
    task_processor_factory: TaskProcessorFactory,
) -> ServicesRegistry:
    """Session-scoped services registry using the asynchronous connection executor factory."""
    registry_kwargs = dict(
        conn_exec_factory_async_env=True,
        conn_bi_context=conn_bi_context,
        task_processor_factory=task_processor_factory,
        root_certificates_data=root_certificates,
    )
    return self.service_registry_factory(**registry_kwargs)

@pytest.fixture(scope="function")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from dl_api_lib_testing.initialization import initialize_api_lib_test
from dl_testing.utils import get_root_certificates

from dl_connector_bundle_chs3.chs3_base.core.us_connection import BaseFileS3Connection
from dl_connector_bundle_chs3_tests.db.config import API_TEST_CONFIG
Expand All @@ -21,3 +22,8 @@ def _patched(self: Any, s3_filename_suffix: str) -> str:
return s3_filename_suffix

monkeypatch.setattr(BaseFileS3Connection, "get_full_s3_filename", _patched)


@pytest.fixture(scope="session")
def root_certificates() -> bytes:
    """Provide the root CA certificate bundle once for the whole test session."""
    ca_bundle: bytes = get_root_certificates()
    return ca_bundle
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ async def _get_target_conn_dto(self) -> CHYTConnTargetDTO:
insert_quorum=self._conn_options.insert_quorum,
insert_quorum_timeout=self._conn_options.insert_quorum_timeout,
disable_value_processing=self._conn_options.disable_value_processing,
ca_data=self._ca_data.decode("ascii"),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dl_core.connection_executors.adapters.common_base import CommonBaseDirectAdapter
from dl_core.connection_executors.models.connection_target_dto_base import BaseSQLConnTargetDTO
from dl_core_testing.executors import ExecutorFactoryBase
from dl_testing.utils import get_root_certificates

from dl_connector_clickhouse.core.clickhouse_base.adapters import ClickHouseAdapter
from dl_connector_clickhouse.core.clickhouse_base.target_dto import ClickHouseConnTargetDTO
Expand All @@ -30,4 +31,5 @@ def get_dto_kwargs(self) -> dict[str, Any]:
insert_quorum=None,
insert_quorum_timeout=None,
disable_value_processing=False,
ca_data=get_root_certificates(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def _make_target_conn_dto_pool(self) -> list[ClickHouseConnTargetDTO]: #
disable_value_processing=self._conn_options.disable_value_processing,
secure=self._conn_dto.secure,
ssl_ca=self._conn_dto.ssl_ca,
ca_data=self._ca_data.decode("ascii"),
)
)
return dto_pool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import attr

from dl_core.connection_executors.models.connection_target_dto_base import BaseSQLConnTargetDTO
from dl_core.connection_executors.models.connection_target_dto_base import (
BaseAiohttpConnTargetDTO,
BaseSQLConnTargetDTO,
)


@attr.s(frozen=True)
class BaseClickHouseConnTargetDTO(BaseSQLConnTargetDTO):
class BaseClickHouseConnTargetDTO(BaseSQLConnTargetDTO, BaseAiohttpConnTargetDTO):
protocol: str = attr.ib()
# TODO CONSIDER: Is really optional?
endpoint: Optional[str] = attr.ib()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def _make_target_conn_dto_pool(self) -> Sequence[PromQLConnTargetDTO]:
password=self._conn_dto.password,
protocol=self._conn_dto.protocol,
db_name=self._conn_dto.db_name,
ca_data=self._ca_data.decode("ascii"),
)
]

Expand All @@ -62,5 +63,6 @@ async def _make_target_conn_dto_pool(self) -> Sequence[PromQLConnTargetDTO]:
password=self._conn_dto.password,
protocol=self._conn_dto.protocol,
db_name=self._conn_dto.db_name,
ca_data=self._ca_data.decode("ascii"),
)
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

import attr

from dl_core.connection_executors.models.connection_target_dto_base import BaseSQLConnTargetDTO
from dl_core.connection_executors.models.connection_target_dto_base import (
BaseAiohttpConnTargetDTO,
BaseSQLConnTargetDTO,
)


@attr.s
class PromQLConnTargetDTO(BaseSQLConnTargetDTO, BaseAiohttpConnTargetDTO):
    """Target DTO for PromQL connections.

    Inherits ``ca_data`` (root CA bundle as an ascii string) from
    BaseAiohttpConnTargetDTO so the aiohttp-based adapter can build its
    SSL context from it.
    """

    path: str = attr.ib()
    protocol: str = attr.ib()
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
import aiohttp.client_exceptions
import attr

from dl_configs.utils import get_root_certificates_path
from dl_core import exc
from dl_core.connection_executors.adapters.async_adapters_base import AsyncDirectDBAdapter


if TYPE_CHECKING:
from aiohttp import BasicAuth

from dl_core.connection_executors.models.connection_target_dto_base import ConnTargetDTO
from dl_core.connection_executors.models.connection_target_dto_base import BaseAiohttpConnTargetDTO
from dl_core.connection_executors.models.db_adapter_data import DBAdapterQuery
from dl_core.connection_executors.models.scoped_rci import DBAdapterScopedRCI

Expand All @@ -43,7 +42,7 @@ class AiohttpDBAdapter(AsyncDirectDBAdapter, metaclass=abc.ABCMeta):
"""Common base for adapters that primarily use an aiohttp client"""

# TODO?: commonize some of the CTDTO attributes such as connect_timeout?
_target_dto: ConnTargetDTO = attr.ib()
_target_dto: BaseAiohttpConnTargetDTO = attr.ib()
_req_ctx_info: DBAdapterScopedRCI = attr.ib()

_http_read_chunk_size: int = attr.ib(init=False, default=(1024 * 64))
Expand All @@ -57,9 +56,9 @@ def __attrs_post_init__(self) -> None:
auth=self.get_session_auth(),
headers=self.get_session_headers(),
connector=self.create_aiohttp_connector(
# TODO: pass ca_data through *DTO
# https://github.com/datalens-tech/datalens-backend/issues/233
ssl_context=ssl.create_default_context(cafile=get_root_certificates_path())
ssl_context=ssl.create_default_context(
cadata=self._target_dto.ca_data,
)
),
)

Expand All @@ -71,7 +70,7 @@ def create_aiohttp_connector(self, ssl_context: Optional[ssl.SSLContext]) -> aio
@classmethod
def create(
cls: Type[_DBA_TV],
target_dto: ConnTargetDTO,
target_dto: BaseAiohttpConnTargetDTO,
req_ctx_info: DBAdapterScopedRCI,
default_chunk_size: int,
) -> _DBA_TV:
Expand Down
1 change: 1 addition & 0 deletions lib/dl_core/dl_core/connection_executors/common_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class ConnExecutorBase(metaclass=abc.ABCMeta):
_exec_mode: ExecutionMode = attr.ib()
_sec_mgr: "ConnectionSecurityManager" = attr.ib()
_remote_qe_data: Optional[RemoteQueryExecutorData] = attr.ib()
_ca_data: bytes = attr.ib()
_services_registry: Optional[ServicesRegistry] = attr.ib(
kw_only=True, default=None
) # Do not use. To be deprecated. Somehow.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,8 @@ class BaseSQLConnTargetDTO(ConnTargetDTO):

def get_effective_host(self) -> Optional[str]:
return self.host


@attr.s(frozen=True)
class BaseAiohttpConnTargetDTO(ConnTargetDTO, metaclass=abc.ABCMeta):
    """Base target DTO for aiohttp-based adapters.

    Carries the root CA bundle as an ascii (PEM) string; AiohttpDBAdapter
    passes it to ``ssl.create_default_context(cadata=...)`` when building
    the client session's SSL context.
    """

    # PEM certificates as ascii text; repr goes through secrepr — presumably
    # to keep the certificate data out of logs (confirm with secrepr impl).
    ca_data: str = attr.ib(repr=secrepr)
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class DefaultConnExecutorFactory(BaseClosableExecutorFactory):
rqe_config: Optional[RQEConfig] = attr.ib()
tpe: Optional[ContextVarExecutor] = attr.ib()
conn_sec_mgr: ConnectionSecurityManager = attr.ib()
ca_data: bytes = attr.ib()

is_bleeding_edge_user: bool = attr.ib(default=False)
conn_cls_whitelist: Optional[FrozenSet[Type[ExecutorBasedMixin]]] = attr.ib(default=None)
Expand Down Expand Up @@ -149,6 +150,7 @@ def _get_async_conn_executor_recipe(
rqe_data=rqe_data,
exec_mode=exec_mode,
conn_hosts_pool=conn_hosts_pool, # type: ignore # TODO: fix
ca_data=self.ca_data,
)

def _cook_conn_executor(self, recipe: ConnExecutorRecipe, with_tpe: bool) -> AsyncConnExecutorBase:
Expand All @@ -170,6 +172,7 @@ def _conn_host_fail_callback_func(host: str): # type: ignore # TODO: fix
conn_hosts_pool=recipe.conn_hosts_pool,
host_fail_callback=_conn_host_fail_callback_func,
services_registry=self._services_registry_ref.ref, # Do not use. To be deprecated. Somehow.
ca_data=recipe.ca_data,
)
else:
raise CEFactoryError(f"Can not instantiate {executor_cls}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ConnExecutorRecipe:
exec_mode: ExecutionMode
rqe_data: Optional[RemoteQueryExecutorData]
conn_hosts_pool: Sequence[str] = attr.ib(kw_only=True, converter=tuple)
ca_data: bytes


@attr.s(frozen=True)
Expand Down
1 change: 1 addition & 0 deletions lib/dl_core/dl_core/services_registry/sr_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def make_conn_executor_factory(
connect_options_mutator=self.env_manager_factory.mutate_conn_opts,
entity_usage_checker=self.entity_usage_checker,
force_non_rqe_mode=self.force_non_rqe_mode,
ca_data=self.ca_data,
)

def additional_sr_constructor_kwargs(
Expand Down
2 changes: 1 addition & 1 deletion lib/dl_core/dl_core_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
pytest_plugins = ("aiohttp.pytest_plugin",) # and it, in turn, includes 'pytest_asyncio.plugin'


@pytest.fixture(scope="session")
def root_certificates() -> bytes:
    """Root CA certificate bundle, read once per test session.

    The certificates on disk are static for the whole test run, so there is
    no need to re-read them per test function.
    """
    return get_root_certificates()
14 changes: 11 additions & 3 deletions lib/dl_core/dl_core_tests/db/compeng/test_compeng_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def test_compeng_cache(
sync_us_manager,
caches_redis_client_factory,
data_processor_service_factory,
root_certificates,
):
dataset = saved_dataset
us_manager = sync_us_manager
Expand Down Expand Up @@ -122,6 +123,7 @@ async def get_data_from_processor(input_data, data_key, operations):
conn_bi_context=conn_bi_context,
caches_redis_client_factory=caches_redis_client_factory,
data_processor_service_factory=data_processor_service_factory,
root_certificates_data=root_certificates,
)
dto = ClickHouseConnDTO(
conn_id="123",
Expand Down Expand Up @@ -190,7 +192,9 @@ async def get_data_from_processor(input_data, data_key, operations):
operations_c10 = get_operations(coeff=10)
data_key_l10 = LocalKeyRepresentation(key_parts=(DataKeyPart(part_type="part", part_content="value_l10"),))
output_data = await get_data_from_processor(
input_data=input_data_l10, data_key=data_key_l10, operations=operations_c10
input_data=input_data_l10,
data_key=data_key_l10,
operations=operations_c10,
)
expected_data_l10_c10 = get_expected_data(length=10, coeff=10)
assert output_data == expected_data_l10_c10
Expand All @@ -204,7 +208,9 @@ async def get_data_from_processor(input_data, data_key, operations):
caplog.clear()
input_data_l5 = [[i, f"str_{i}"] for i in range(5)] # up to 5 instead of 10
output_data = await get_data_from_processor(
input_data=input_data_l5, data_key=data_key_l10, operations=operations_c10
input_data=input_data_l5,
data_key=data_key_l10,
operations=operations_c10,
)
assert output_data == expected_data_l10_c10
# Check cache flags in reporting
Expand All @@ -218,7 +224,9 @@ async def get_data_from_processor(input_data, data_key, operations):
data_key_l5 = LocalKeyRepresentation(key_parts=(DataKeyPart(part_type="part", part_content="value_l5"),))
operations_c5 = get_operations(coeff=5)
output_data = await get_data_from_processor(
input_data=input_data_l5, data_key=data_key_l5, operations=operations_c5
input_data=input_data_l5,
data_key=data_key_l5,
operations=operations_c5,
)
expected_data_l5_c5 = get_expected_data(length=5, coeff=5)
assert output_data == expected_data_l5_c5
Expand Down
Loading

0 comments on commit 4fc6f55

Please sign in to comment.