Skip to content

Commit

Permalink
Refactored service registry creation in core tests
Browse files Browse the repository at this point in the history
  • Loading branch information
altvod committed Nov 15, 2023
1 parent d9da83c commit 8ef1a55
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def conn_sync_service_registry(
conn_bi_context: RequestContextInfo,
task_processor_factory: TaskProcessorFactory,
) -> ServicesRegistry:
return self.service_registry_factory(
return self.make_service_registry(
conn_exec_factory_async_env=False,
conn_bi_context=conn_bi_context,
task_processor_factory=task_processor_factory,
Expand All @@ -122,7 +122,7 @@ def conn_async_service_registry(
conn_bi_context: RequestContextInfo,
task_processor_factory: TaskProcessorFactory,
) -> ServicesRegistry:
return self.service_registry_factory(
return self.make_service_registry(
conn_exec_factory_async_env=True,
conn_bi_context=conn_bi_context,
task_processor_factory=task_processor_factory,
Expand Down
22 changes: 15 additions & 7 deletions lib/dl_core/dl_core/services_registry/sr_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
)
from dl_core.us_manager.mutation_cache.usentry_mutation_cache_factory import USEntryMutationCacheFactory
from dl_core.utils import FutureRef
from dl_task_processor.processor import ARQTaskProcessorFactory
from dl_task_processor.processor import (
ARQTaskProcessorFactory,
TaskProcessorFactory,
)


if TYPE_CHECKING:
Expand All @@ -54,6 +57,7 @@
from dl_core.services_registry.inst_specific_sr import InstallationSpecificServiceRegistryFactory
from dl_core.services_registry.typing import ConnectOptionsFactory
from dl_core.us_connection_base import ExecutorBasedMixin
from dl_utils.aio import ContextVarExecutor


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -100,9 +104,17 @@ class DefaultSRFactory(SRFactory[SERVICE_REGISTRY_TV]): # type: ignore # TODO:
rqe_caches_settings: Optional[RQECachesSetting] = attr.ib(default=None)
required_services: set[RequiredService] = attr.ib(factory=set)
inst_specific_sr_factory: Optional[InstallationSpecificServiceRegistryFactory] = attr.ib(default=None)
task_processor_factory: Optional[TaskProcessorFactory] = attr.ib()
tpe: Optional[ContextVarExecutor] = attr.ib(default=None)

service_registry_cls: ClassVar[Type[SERVICE_REGISTRY_TV]] = DefaultServicesRegistry # type: ignore # TODO: fix

@task_processor_factory.default
def _make_task_processor_factory(self) -> Optional[TaskProcessorFactory]:
if self.redis_pool_settings:
return ARQTaskProcessorFactory(redis_pool_settings=self.redis_pool_settings)
return None

def is_bleeding_edge_user(self, request_context_info: RequestContextInfo) -> bool:
return request_context_info.user_name in self.bleeding_edge_users

Expand All @@ -116,7 +128,7 @@ def make_conn_executor_factory(
LOGGER.info("ATTENTION! It's bleeding edge user")
return DefaultConnExecutorFactory(
async_env=self.async_env,
tpe=None,
tpe=self.tpe,
conn_sec_mgr=self.env_manager_factory.make_security_manager(),
rqe_config=self.rqe_config,
services_registry_ref=sr_ref, # type: ignore # TODO: fix
Expand Down Expand Up @@ -172,11 +184,7 @@ def make_service_registry(
)
if self.file_uploader_settings
else None,
task_processor_factory=ARQTaskProcessorFactory(
redis_pool_settings=self.redis_pool_settings,
)
if self.redis_pool_settings
else None,
task_processor_factory=self.task_processor_factory,
rqe_caches_settings=self.rqe_caches_settings,
required_services=self.required_services,
inst_specific_sr=(
Expand Down
51 changes: 48 additions & 3 deletions lib/dl_core_testing/dl_core_testing/testcases/service_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from dl_constants.enums import ConnectionType
from dl_core.connections_security.base import InsecureConnectionSecurityManager
from dl_core.services_registry.conn_executor_factory import DefaultConnExecutorFactory
from dl_core.services_registry.env_manager_factory import InsecureEnvManagerFactory
from dl_core.services_registry.inst_specific_sr import InstallationSpecificServiceRegistryFactory
from dl_core.services_registry.sr_factories import (
DefaultSRFactory,
SRFactory,
)
from dl_core.services_registry.top_level import (
DefaultServicesRegistry,
ServicesRegistry,
Expand All @@ -37,6 +42,10 @@
from dl_core_testing.fixtures.dispenser import DbCsvTableDispenser
from dl_db_testing.database.engine_wrapper import DbEngineConfig
from dl_utils.aio import ContextVarExecutor
from dl_task_processor.processor import (
DummyTaskProcessorFactory,
TaskProcessorFactory,
)


class USConfig(NamedTuple):
Expand Down Expand Up @@ -69,7 +78,43 @@ def conn_bi_context(self) -> RequestContextInfo:
def conn_exec_factory_async_env(self) -> bool:
return False

def service_registry_factory(
def make_service_registry_factory(
self,
async_env: bool,
task_processor_factory: TaskProcessorFactory = DummyTaskProcessorFactory(),
) -> SRFactory:
return DefaultSRFactory(
async_env=async_env,
rqe_config=RQEConfig.get_default(), # Not used because RQE is disabled
connectors_settings={self.conn_type: self.connection_settings} if self.connection_settings else {},
inst_specific_sr_factory=self.inst_specific_sr_factory,
env_manager_factory=InsecureEnvManagerFactory(),
force_non_rqe_mode=True,
tpe=ContextVarExecutor(),
task_processor_factory=task_processor_factory,
)

def make_service_registry(
self,
conn_exec_factory_async_env: bool,
conn_bi_context: RequestContextInfo,
task_processor_factory: TaskProcessorFactory = DummyTaskProcessorFactory(),
**kwargs: Any,
) -> ServicesRegistry:
sr_factory = self.make_service_registry_factory(
async_env=conn_exec_factory_async_env,
task_processor_factory=task_processor_factory,
)
return sr_factory.make_service_registry(
request_context_info=conn_bi_context,
mutations_cache_factory=DefaultUSEntryMutationCacheFactory(),
reporting_registry=DefaultReportingRegistry(
rci=conn_bi_context,
),
**kwargs,
)

def make_service_registry_legacy(
self,
conn_exec_factory_async_env: bool,
conn_bi_context: RequestContextInfo,
Expand Down Expand Up @@ -106,7 +151,7 @@ def conn_sync_service_registry(
self,
conn_bi_context: RequestContextInfo,
) -> ServicesRegistry:
return self.service_registry_factory(
return self.make_service_registry(
conn_exec_factory_async_env=False,
conn_bi_context=conn_bi_context,
)
Expand All @@ -116,7 +161,7 @@ def conn_async_service_registry(
self,
conn_bi_context: RequestContextInfo,
) -> ServicesRegistry:
return self.service_registry_factory(
return self.make_service_registry(
conn_exec_factory_async_env=True,
conn_bi_context=conn_bi_context,
)
Expand Down

0 comments on commit 8ef1a55

Please sign in to comment.