Skip to content

Commit

Permalink
Avoid unnecessary connectors loading
Browse files Browse the repository at this point in the history
  • Loading branch information
asnytin committed Nov 8, 2023
1 parent 66807a2 commit 287595a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
4 changes: 1 addition & 3 deletions lib/dl_api_lib/dl_api_lib/app_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class ApiConnectorEntrypointManager(EntrypointClassManager[ApiConnector]):


def register_all_connectors(connector_ep_names: Optional[Collection[str]] = None) -> None:
connectors = ApiConnectorEntrypointManager().get_all_ep_classes()
connectors = ApiConnectorEntrypointManager().get_all_ep_classes(connector_ep_names)
for ep_name, connector_cls in sorted(connectors.items()):
if connector_ep_names is not None and ep_name not in connector_ep_names:
continue
CONN_REG_BI_API.register_connector(connector_cls)
10 changes: 7 additions & 3 deletions lib/dl_core/dl_core/connectors/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Type
from typing import (
Collection,
Optional,
Type,
)

import attr

Expand All @@ -15,9 +19,9 @@ class CoreConnectorEntrypointManager(EntrypointClassManager[CoreConnector]):
entrypoint_group_name = attr.ib(init=False, default=_CONNECTOR_EP_GROUP)


def get_all_connectors() -> dict[str, Type[CoreConnector]]:
def get_all_connectors(ep_filter: Optional[Collection[str]] = None) -> dict[str, Type[CoreConnector]]:
ep_mgr = CoreConnectorEntrypointManager()
return ep_mgr.get_all_ep_classes()
return ep_mgr.get_all_ep_classes(ep_filter)


def get_connector_cls(name: str) -> Type[CoreConnector]:
Expand Down
4 changes: 1 addition & 3 deletions lib/dl_core/dl_core/core_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,5 @@ def load_all_connectors() -> None:


def register_all_connectors(connector_ep_names: Optional[Collection[str]] = None) -> None:
for ep_name, connector_cls in sorted(get_all_connectors().items()):
if connector_ep_names is not None and ep_name not in connector_ep_names:
continue
for ep_name, connector_cls in sorted(get_all_connectors(connector_ep_names).items()):
_register_connector(connector_cls)
6 changes: 5 additions & 1 deletion lib/dl_utils/dl_utils/entrypoints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
from importlib import metadata
from typing import (
Collection,
Generic,
Optional,
Type,
TypeVar,
)
Expand All @@ -20,8 +22,10 @@ class EntrypointClassManager(abc.ABC, Generic[_EP_CLS_TV]):

entrypoint_group_name: str = attr.ib(kw_only=True)

def get_all_ep_classes(self) -> dict[str, Type[_EP_CLS_TV]]:
def get_all_ep_classes(self, ep_filter: Optional[Collection[str]] = None) -> dict[str, Type[_EP_CLS_TV]]:
entrypoints = list(metadata.entry_points().select(group=self.entrypoint_group_name)) # type: ignore
if ep_filter is not None:
entrypoints = [ep for ep in entrypoints if ep.name in ep_filter]
return {ep.name: ep.load() for ep in entrypoints}

def get_ep_class(self, name: str) -> Type[_EP_CLS_TV]:
Expand Down

0 comments on commit 287595a

Please sign in to comment.