diff --git a/lib/dl_api_lib/dl_api_lib/app_connectors.py b/lib/dl_api_lib/dl_api_lib/app_connectors.py index 70a5badc3..7089a28df 100644 --- a/lib/dl_api_lib/dl_api_lib/app_connectors.py +++ b/lib/dl_api_lib/dl_api_lib/app_connectors.py @@ -25,8 +25,6 @@ class ApiConnectorEntrypointManager(EntrypointClassManager[ApiConnector]): def register_all_connectors(connector_ep_names: Optional[Collection[str]] = None) -> None: - connectors = ApiConnectorEntrypointManager().get_all_ep_classes() + connectors = ApiConnectorEntrypointManager().get_all_ep_classes(connector_ep_names) for ep_name, connector_cls in sorted(connectors.items()): - if connector_ep_names is not None and ep_name not in connector_ep_names: - continue CONN_REG_BI_API.register_connector(connector_cls) diff --git a/lib/dl_core/dl_core/connectors/registry.py b/lib/dl_core/dl_core/connectors/registry.py index 4be619f71..17a88d208 100644 --- a/lib/dl_core/dl_core/connectors/registry.py +++ b/lib/dl_core/dl_core/connectors/registry.py @@ -1,4 +1,8 @@ -from typing import Type +from typing import ( + Collection, + Optional, + Type, +) import attr @@ -15,9 +19,9 @@ class CoreConnectorEntrypointManager(EntrypointClassManager[CoreConnector]): entrypoint_group_name = attr.ib(init=False, default=_CONNECTOR_EP_GROUP) -def get_all_connectors() -> dict[str, Type[CoreConnector]]: +def get_all_connectors(ep_filter: Optional[Collection[str]] = None) -> dict[str, Type[CoreConnector]]: ep_mgr = CoreConnectorEntrypointManager() - return ep_mgr.get_all_ep_classes() + return ep_mgr.get_all_ep_classes(ep_filter) def get_connector_cls(name: str) -> Type[CoreConnector]: diff --git a/lib/dl_core/dl_core/core_connectors.py b/lib/dl_core/dl_core/core_connectors.py index 7e61bfe0a..cfea100ea 100644 --- a/lib/dl_core/dl_core/core_connectors.py +++ b/lib/dl_core/dl_core/core_connectors.py @@ -18,7 +18,5 @@ def load_all_connectors() -> None: def register_all_connectors(connector_ep_names: Optional[Collection[str]] = None) -> None: - for ep_name, connector_cls in sorted(get_all_connectors().items()): - if connector_ep_names is not None and ep_name not in connector_ep_names: - continue + for ep_name, connector_cls in sorted(get_all_connectors(connector_ep_names).items()): _register_connector(connector_cls) diff --git a/lib/dl_utils/dl_utils/entrypoints.py b/lib/dl_utils/dl_utils/entrypoints.py index ba3d436dc..36a50d70c 100644 --- a/lib/dl_utils/dl_utils/entrypoints.py +++ b/lib/dl_utils/dl_utils/entrypoints.py @@ -1,7 +1,9 @@ import abc from importlib import metadata from typing import ( + Collection, Generic, + Optional, Type, TypeVar, ) @@ -20,8 +22,10 @@ class EntrypointClassManager(abc.ABC, Generic[_EP_CLS_TV]): entrypoint_group_name: str = attr.ib(kw_only=True) - def get_all_ep_classes(self) -> dict[str, Type[_EP_CLS_TV]]: + def get_all_ep_classes(self, ep_filter: Optional[Collection[str]] = None) -> dict[str, Type[_EP_CLS_TV]]: entrypoints = list(metadata.entry_points().select(group=self.entrypoint_group_name)) # type: ignore + if ep_filter is not None: + entrypoints = [ep for ep in entrypoints if ep.name in ep_filter] return {ep.name: ep.load() for ep in entrypoints} def get_ep_class(self, name: str) -> Type[_EP_CLS_TV]: