Skip to content

Commit

Permalink
add ability to register to on ws connect and disconnect events of opa…
Browse files Browse the repository at this point in the history
…l client
  • Loading branch information
Asaf Cohen committed Nov 5, 2024
1 parent 6b72ec8 commit 1b2e9b6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
10 changes: 10 additions & 0 deletions packages/opal-client/opal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import websockets
from fastapi import FastAPI, status
from fastapi.responses import JSONResponse
from fastapi_websocket_pubsub.pub_sub_client import PubSubOnConnectCallback
from fastapi_websocket_rpc.rpc_channel import OnDisconnectCallback
from opal_client.callbacks.api import init_callbacks_api
from opal_client.callbacks.register import CallbacksRegister
from opal_client.config import PolicyStoreTypes, opal_client_config
Expand Down Expand Up @@ -54,6 +56,10 @@ def __init__(
store_backup_interval: Optional[int] = None,
offline_mode_enabled: bool = False,
shard_id: Optional[str] = None,
on_data_updater_connect: List[PubSubOnConnectCallback] = None,
on_data_updater_disconnect: List[OnDisconnectCallback] = None,
on_policy_updater_connect: List[PubSubOnConnectCallback] = None,
on_policy_updater_disconnect: List[OnDisconnectCallback] = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -119,6 +125,8 @@ def __init__(
policy_store=self.policy_store,
callbacks_register=self._callbacks_register,
opal_client_id=opal_client_identifier,
on_connect=on_policy_updater_connect,
on_disconnect=on_policy_updater_disconnect,
)
else:
self.policy_updater = None
Expand All @@ -140,6 +148,8 @@ def __init__(
callbacks_register=self._callbacks_register,
opal_client_id=opal_client_identifier,
shard_id=self._shard_id,
on_connect=on_data_updater_connect,
on_disconnect=on_data_updater_disconnect,
)
else:
self.data_updater = None
Expand Down
10 changes: 8 additions & 2 deletions packages/opal-client/opal_client/data/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import aiohttp
from aiohttp.client import ClientError, ClientSession
from fastapi_websocket_pubsub import PubSubClient
from fastapi_websocket_rpc.rpc_channel import RpcChannel
from fastapi_websocket_pubsub.pub_sub_client import PubSubOnConnectCallback
from fastapi_websocket_rpc.rpc_channel import OnDisconnectCallback, RpcChannel
from opal_client.callbacks.register import CallbacksRegister
from opal_client.callbacks.reporter import CallbacksReporter
from opal_client.config import opal_client_config
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(
callbacks_register: Optional[CallbacksRegister] = None,
opal_client_id: str = None,
shard_id: Optional[str] = None,
on_connect: List[PubSubOnConnectCallback] = None,
on_disconnect: List[OnDisconnectCallback] = None,
):
"""Keeps policy-stores (e.g. OPA) up to date with relevant data Obtains
data configuration on startup from OPAL-server Uses Pub/Sub to
Expand Down Expand Up @@ -132,6 +135,8 @@ def __init__(
self._updates_storing_queue = TakeANumberQueue(logger)
self._tasks = TasksPool()
self._polling_update_tasks = []
self._on_connect_callbacks = on_connect or []
self._on_disconnect_callbacks = on_disconnect or []

async def __aenter__(self):
await self.start()
Expand Down Expand Up @@ -278,7 +283,8 @@ async def _subscriber(self):
self._data_topics,
self._update_policy_data_callback,
methods_class=TenantAwareRpcEventClientMethods,
on_connect=[self.on_connect],
on_connect=[self.on_connect, *self._on_connect_callbacks],
on_disconnect=[self.on_disconnect, *self._on_disconnect_callbacks],
extra_headers=self._extra_headers,
keep_alive=opal_client_config.KEEP_ALIVE_INTERVAL,
server_uri=self._server_url,
Expand Down
11 changes: 8 additions & 3 deletions packages/opal-client/opal_client/policy/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import pydantic
from fastapi_websocket_pubsub import PubSubClient
from fastapi_websocket_rpc.rpc_channel import RpcChannel
from fastapi_websocket_pubsub.pub_sub_client import PubSubOnConnectCallback
from fastapi_websocket_rpc.rpc_channel import OnDisconnectCallback, RpcChannel
from opal_client.callbacks.register import CallbacksRegister
from opal_client.callbacks.reporter import CallbacksReporter
from opal_client.config import opal_client_config
Expand Down Expand Up @@ -43,6 +44,8 @@ def __init__(
data_fetcher: Optional[DataFetcher] = None,
callbacks_register: Optional[CallbacksRegister] = None,
opal_client_id: str = None,
on_connect: List[PubSubOnConnectCallback] = None,
on_disconnect: List[OnDisconnectCallback] = None,
):
"""inits the policy updater.
Expand Down Expand Up @@ -104,6 +107,8 @@ def __init__(
)
self._policy_update_queue = asyncio.Queue()
self._tasks = TasksPool()
self._on_connect_callbacks = on_connect or []
self._on_disconnect_callbacks = on_disconnect or []

async def __aenter__(self):
await self.start()
Expand Down Expand Up @@ -243,8 +248,8 @@ async def _subscriber(self):
self._client = PubSubClient(
topics=self._topics,
callback=self._update_policy_callback,
on_connect=[self._on_connect],
on_disconnect=[self._on_disconnect],
on_connect=[self._on_connect, *self._on_connect_callbacks],
on_disconnect=[self._on_disconnect, *self._on_disconnect_callbacks],
extra_headers=self._extra_headers,
keep_alive=opal_client_config.KEEP_ALIVE_INTERVAL,
server_uri=self._server_url,
Expand Down

0 comments on commit 1b2e9b6

Please sign in to comment.