diff --git a/firebase_messaging/__init__.py b/firebase_messaging/__init__.py index 3797f54..df63dbb 100644 --- a/firebase_messaging/__init__.py +++ b/firebase_messaging/__init__.py @@ -1,3 +1,9 @@ from .fcmpushclient import FcmPushClient, FcmPushClientConfig, FcmPushClientRunState +from .fcmregister import FcmRegisterConfig -__all__ = ["FcmPushClientConfig", "FcmPushClient", "FcmPushClientRunState"] +__all__ = [ + "FcmPushClientConfig", + "FcmPushClient", + "FcmPushClientRunState", + "FcmRegisterConfig", +] diff --git a/firebase_messaging/const.py b/firebase_messaging/const.py index e7c3b46..939ba49 100644 --- a/firebase_messaging/const.py +++ b/firebase_messaging/const.py @@ -25,8 +25,16 @@ + "j8TM4W88jITfq7ZmPvIM1Iv-4_l2LxQcYwhqby2xGpWwzjfAnG4" ) -FCM_SUBSCRIBE_URL = "https://fcm.googleapis.com/fcm/connect/subscribe" -FCM_SEND_URL = "https://fcm.googleapis.com/fcm/send" +FCM_SUBSCRIBE_URL = "https://fcm.googleapis.com/fcm/connect/subscribe/" +FCM_SEND_URL = "https://fcm.googleapis.com/fcm/send/" + +FCM_API = "https://fcm.googleapis.com/v1/" +FCM_REGISTRATION = "https://fcmregistrations.googleapis.com/v1/" +FCM_INSTALLATION = "https://firebaseinstallations.googleapis.com/v1/" +AUTH_VERSION = "FIS_v2" +SDK_VERSION = "w:0.6.6" + +DOORBELLS_ENDPOINT = "/clients_api/doorbots/{0}" MCS_VERSION = 41 MCS_HOST = "mtalk.google.com" diff --git a/firebase_messaging/fcmpushclient.py b/firebase_messaging/fcmpushclient.py index 2ea1f68..abe226e 100644 --- a/firebase_messaging/fcmpushclient.py +++ b/firebase_messaging/fcmpushclient.py @@ -1,7 +1,8 @@ import asyncio -import functools +import contextlib import json import logging +import ssl import struct import time import traceback @@ -9,9 +10,7 @@ from contextlib import suppress as contextlib_suppress from dataclasses import dataclass from enum import Enum -from ssl import SSLError -from threading import Thread -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple from aiohttp import ClientSession from cryptography.hazmat.backends import default_backend @@ -27,7 +26,7 @@ MCS_SELECTIVE_ACK_ID, MCS_VERSION, ) -from .fcmregister import FcmRegister +from .fcmregister import FcmRegister, FcmRegisterConfig from .proto.mcs_pb2 import ( # pylint: disable=no-name-in-module Close, DataMessageStanza, @@ -41,8 +40,8 @@ _logger = logging.getLogger(__name__) -OnNotificationCallable = Callable[[Dict[str, Dict[str, str]], str, Any], None] -CredentialsUpdatedCallable = Callable[[Dict[str, Any]], None] +OnNotificationCallable = Callable[[dict[str, Any], str, Any], None] +CredentialsUpdatedCallable = Callable[[dict[str, Any]], None] class ErrorType(Enum): @@ -85,7 +84,7 @@ class FcmPushClientConfig: # pylint:disable=too-many-instance-attributes """Time in seconds to wait before attempting to retry the connection after failure.""" - reset_interval: float = 1 + reset_interval: float = 3 """Time in seconds to wait between resets after errors or disconnection.""" heartbeat_ack_timeout: float = 5 @@ -120,15 +119,21 @@ class FcmPushClient: # pylint:disable=too-many-instance-attributes def __init__( self, + callback: Callable[[dict, str, Any | None], None], + fcm_config: FcmRegisterConfig, + credentials: dict | None = None, + credentials_updated_callback: CredentialsUpdatedCallable | None = None, *, - credentials: Optional[dict] = None, - credentials_updated_callback: Optional[CredentialsUpdatedCallable] = None, - received_persistent_ids: Optional[List[str]] = None, - config: Optional[FcmPushClientConfig] = None, - http_client_session: Optional[ClientSession] = None, + callback_context: object | None = None, + received_persistent_ids: list[str] | None = None, + config: FcmPushClientConfig | None = None, + http_client_session: ClientSession | None = None, ): """Initializes the receiver.""" - self.credentials: Optional[Dict[str, Dict[str, str]]] = credentials + self.callback = callback + self.callback_context = callback_context + self.fcm_config = fcm_config + self.credentials = credentials self.credentials_updated_callback = credentials_updated_callback self.persistent_ids = received_persistent_ids if received_persistent_ids else [] self.config = config if config else FcmPushClientConfig() @@ -136,28 +141,21 @@ def __init__( _logger.setLevel(logging.DEBUG) self._http_client_session = http_client_session - self.reader: Optional[asyncio.StreamReader] = None - self.writer: Optional[asyncio.StreamWriter] = None + self.reader: asyncio.StreamReader | None = None + self.writer: asyncio.StreamWriter | None = None self.do_listen = False self.sequential_error_counters: Dict[ErrorType, int] = {} - self.log_warn_counters: Dict[str, int] = {} + self.log_warn_counters: dict[str, int] = {} # reset variables self.input_stream_id = 0 self.last_input_stream_id_reported = -1 self.first_message = True - self.last_login_time: Optional[float] = None - self.last_message_time: Optional[float] = None + self.last_login_time: float | None = None + self.last_message_time: float | None = None self.run_state: FcmPushClientRunState = FcmPushClientRunState.CREATED - self.tasks: List[asyncio.Task] = [] - - self.listen_event_loop: Optional[asyncio.AbstractEventLoop] = None - self.callback_event_loop: Optional[asyncio.AbstractEventLoop] = None - self.fcm_thread: Optional[Thread] = None - - self.app_id: Optional[str] = None - self.sender_id: Optional[int] = None + self.tasks: list[asyncio.Task] = [] self.reset_lock: Optional[asyncio.Lock] = None self.stopping_lock: Optional[asyncio.Lock] = None @@ -182,16 +180,12 @@ def _log_warn_with_limit(self, msg: str, *args: object) -> None: _logger.warning(msg, *args) async def _do_writer_close(self) -> None: - try: - if ( - self.listen_event_loop - and self.writer - and self.listen_event_loop.is_running() - ): - self.writer.close() - await self.writer.wait_closed() - except OSError as e: - _logger.debug("%s Error while trying to close writer", type(e).__name__) + writer = self.writer + self.writer = None + if writer: + writer.close() + with contextlib.suppress(Exception): + await writer.wait_closed() async def _reset(self) -> None: if ( @@ -205,19 +199,21 @@ async def _reset(self) -> None: _logger.debug("Resetting connection") self.run_state = FcmPushClientRunState.RESETTING + + await self._do_writer_close() + now = time.time() time_since_last_login = now - self.last_login_time # type: ignore[operator] if time_since_last_login < self.config.reset_interval: _logger.debug("%ss since last reset attempt.", time_since_last_login) await asyncio.sleep(self.config.reset_interval - time_since_last_login) - await self._do_writer_close() - _logger.debug("Reestablishing connection") if not await self._connect_with_retry(): _logger.error( "Unable to connect to MCS endpoint " - + "after %s tries, shutting down" + + "after %s tries, shutting down", + self.config.connection_retry_count, ) self._terminate() return @@ -319,14 +315,14 @@ async def _login(self) -> None: self.last_login_time = now try: - android_id = self.credentials["gcm"]["androidId"] # type: ignore[index] + android_id = self.credentials["gcm"]["android_id"] # type: ignore[index] req = LoginRequest() req.adaptive_heartbeat = False req.auth_service = LoginRequest.ANDROID_ID # 2 - req.auth_token = self.credentials["gcm"]["securityToken"] # type: ignore[index] - req.id = "chrome-63.0.3234.0" + req.auth_token = self.credentials["gcm"]["security_token"] # type: ignore[index] + req.id = self.fcm_config.chrome_version req.domain = "mcs.android.com" - req.device_id = "android-%x" % int(android_id) + req.device_id = f"android-{int(android_id):x}" req.network_type = 1 req.resource = android_id req.user = android_id @@ -376,18 +372,20 @@ def _decrypt_raw_data( ) return decrypted - def _app_data_by_key(self, p: DataMessageStanza, key: str) -> str: + def _app_data_by_key( + self, p: DataMessageStanza, key: str, do_not_raise: bool = False + ) -> str: for x in p.app_data: if x.key == key: return x.value + if do_not_raise: + return "" raise RuntimeError(f"couldn't find in app_data {key}") def _handle_data_message( self, - callback: Optional[OnNotificationCallable], msg: DataMessageStanza, - obj: Any, ) -> None: _logger.debug( "Received data message Stream ID: %s, Last: %s, Status: %s", @@ -396,15 +394,23 @@ def _handle_data_message( msg.status, ) + if ( + self._app_data_by_key(msg, "message_type", do_not_raise=True) + == "deleted_messages" + ): + # The deleted_messages message does not contain data. + return crypto_key = self._app_data_by_key(msg, "crypto-key")[3:] # strip dh= salt = self._app_data_by_key(msg, "encryption")[5:] # strip salt= subtype = self._app_data_by_key(msg, "subtype") - if subtype != self.app_id: + if TYPE_CHECKING: + assert self.credentials + if subtype != self.credentials["gcm"]["app_id"]: self._log_warn_with_limit( "Subtype %s in data message does not match" + "app id client was registered with %s", subtype, - self.app_id, + self.credentials["gcm"]["app_id"], ) if not self.credentials: return @@ -418,39 +424,12 @@ def _handle_data_message( self._log_verbose( "Decrypted data for message %s is: %s", msg.persistent_id, ret_val ) - if callback and self.listen_event_loop != self.callback_event_loop: - if ( - callback is not None - and self.callback_event_loop - and self.callback_event_loop.is_running() - ): - on_error = functools.partial( - self._try_increment_error_count, ErrorType.NOTIFY - ) - on_success = functools.partial( - self._reset_error_count, ErrorType.NOTIFY - ) - self.callback_event_loop.call_soon_threadsafe( - functools.partial( - FcmPushClient._wrapped_callback, - self.listen_event_loop, - on_error, - on_success, - callback, - ret_val, - msg.persistent_id, - obj, - ) - ) - elif callback: - try: - callback(ret_val, msg.persistent_id, obj) - self._reset_error_count(ErrorType.NOTIFY) - except Exception: - _logger.exception( - "Unexpected exception calling notification callback\n" - ) - self._try_increment_error_count(ErrorType.NOTIFY) + try: + self.callback(ret_val, msg.persistent_id, self.callback_context) + self._reset_error_count(ErrorType.NOTIFY) + except Exception: + _logger.exception("Unexpected exception calling notification callback\n") + self._try_increment_error_count(ErrorType.NOTIFY) def _new_input_stream_id_available(self) -> bool: return self.last_input_stream_id_reported != self.input_stream_id @@ -489,7 +468,6 @@ async def _send_selective_ack(self, persistent_id: str) -> None: iqs = IqStanza() iqs.type = IqStanza.IqType.SET iqs.id = "" - # iqs.extension = Extension() iqs.extension.id = MCS_SELECTIVE_ACK_ID sa = SelectiveAck() sa.id.extend([persistent_id]) @@ -517,18 +495,10 @@ def _terminate(self) -> None: ): # cancel return if task is done so no need to check task.cancel() - async def _do_monitor(self, callback: Optional[OnNotificationCallable]) -> None: + async def _do_monitor(self) -> None: while self.do_listen: await asyncio.sleep(self.config.monitor_interval) - if callback and ( - not self.callback_event_loop - or not self.callback_event_loop.is_running() - ): - _logger.debug("Callback loop no longer running, terminating FcmClient") - self._terminate() - return - if self.run_state == FcmPushClientRunState.STARTED: # if server_heartbeat_interval is set and less than # client_heartbeat_interval then the last_message_time @@ -581,9 +551,7 @@ def _try_increment_error_count(self, error_type: ErrorType) -> bool: return False return True - async def _handle_message( - self, msg: Message, callback: Optional[OnNotificationCallable], obj: Any - ) -> None: + async def _handle_message(self, msg: Message) -> None: self.last_message_time = time.time() self.input_stream_id += 1 @@ -606,7 +574,7 @@ async def _handle_message( return if isinstance(msg, DataMessageStanza): - self._handle_data_message(callback, msg, obj) + self._handle_data_message(msg) self.persistent_ids.append(msg.persistent_id) if self.config.send_selective_acknowledgements: await self._send_selective_ack(msg.persistent_id) @@ -624,14 +592,17 @@ async def _handle_message( @staticmethod async def _open_connection( - host: str, port: int, ssl: bool + host: str, port: int, ssl_context: ssl.SSLContext ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: - return await asyncio.open_connection(host=host, port=port, ssl=ssl) + return await asyncio.open_connection(host=host, port=port, ssl=ssl_context) async def _connect(self) -> bool: try: + loop = asyncio.get_running_loop() + # create_default_context() blocks the event loop + ssl_context = await loop.run_in_executor(None, ssl.create_default_context) self.reader, self.writer = await self._open_connection( - host=MCS_HOST, port=MCS_PORT, ssl=True + host=MCS_HOST, port=MCS_PORT, ssl_context=ssl_context ) _logger.debug("Connected to MCS endpoint (%s,%s)", MCS_HOST, MCS_PORT) return True @@ -673,41 +644,20 @@ async def _connect_with_retry(self) -> bool: ) return connected - async def _listen( - self, callback: Optional[OnNotificationCallable], obj: Any = None - ) -> None: - """ - listens for push notifications - - callback(obj, notification, data_message): called on notifications - obj: optional arbitrary value passed to callback - """ - + async def _listen(self) -> None: + """listens for push notifications.""" if not await self._connect_with_retry(): return try: await self._login() - while ( - self.do_listen - and self.listen_event_loop - and self.listen_event_loop.is_running() - ): - if callback and ( - not self.callback_event_loop - or not self.callback_event_loop.is_running() - ): - _logger.debug( - "Callback loop no longer running, terminating FcmClient" - ) - self._terminate() - return + while self.do_listen: try: if self.run_state == FcmPushClientRunState.RESETTING: await asyncio.sleep(1) elif msg := await self._receive_msg(): - await self._handle_message(msg, callback, obj) + await self._handle_message(msg) except (OSError, EOFError) as osex: if ( @@ -717,13 +667,13 @@ async def _listen( ConnectionResetError, TimeoutError, asyncio.IncompleteReadError, - SSLError, + ssl.SSLError, ), ) and self.run_state == FcmPushClientRunState.RESETTING ): if ( - isinstance(osex, SSLError) # pylint: disable=no-member + isinstance(osex, ssl.SSLError) # pylint: disable=no-member and osex.reason != "APPLICATION_DATA_AFTER_CLOSE_NOTIFY" ): self._log_warn_with_limit( @@ -739,9 +689,6 @@ async def _listen( _logger.exception("Unexpected exception during read\n") if self._try_increment_error_count(ErrorType.CONNECTION): await self._reset() - - except asyncio.CancelledError as cex: - raise cex except Exception as ex: _logger.error( "Unknown error: %s, shutting down FcmPushClient.\n%s", @@ -752,58 +699,7 @@ async def _listen( finally: await self._do_writer_close() - async def _run_tasks( - self, callback: Optional[OnNotificationCallable], obj: Any - ) -> None: - self.reset_lock = asyncio.Lock() - self.stopping_lock = asyncio.Lock() - self.do_listen = True - self.run_state = FcmPushClientRunState.STARTING_TASKS - try: - self.tasks = [ - asyncio.create_task(self._listen(callback, obj)), - asyncio.create_task(self._do_monitor(callback)), - ] - await asyncio.gather(*self.tasks, return_exceptions=True) - _logger.info("FCMClient has shutdown") - except Exception as ex: - _logger.error("Unexpected error running FcmPushClient: %s", ex) - - def _start_on_new_loop( - self, callback: Optional[OnNotificationCallable], obj: Any - ) -> None: - self.listen_event_loop = asyncio.new_event_loop() - if not self.callback_event_loop: - self.callback_event_loop = self.listen_event_loop - - asyncio.set_event_loop(self.listen_event_loop) - self.listen_event_loop.run_until_complete(self._run_tasks(callback, obj)) - - def _start_on_existing_loop( - self, callback: Optional[OnNotificationCallable], obj: Any - ) -> None: - self.listen_event_loop.create_task(self._run_tasks(callback, obj)) # type: ignore[union-attr] - - @staticmethod - def _wrapped_callback( - fcm_client_loop: asyncio.AbstractEventLoop, - on_error: functools.partial, - on_success: functools.partial, - callback: OnNotificationCallable, - notification: Dict[str, Dict[str, str]], - persistent_id: str, - obj: Any, - ) -> None: # pylint: disable=too-many-arguments - # Should be running on callback loop - - try: - callback(notification, persistent_id, obj) - fcm_client_loop.call_soon_threadsafe(on_success) - except Exception: - _logger.exception("Unexpected exception calling notification callback\n") - fcm_client_loop.call_soon_threadsafe(on_error) - - async def checkin(self, sender_id: int, app_id: str) -> str: + async def checkin_or_register(self) -> str: """Check in if you have credentials otherwise register as a new client. :param sender_id: sender id identifying push service you are connecting to. @@ -811,69 +707,32 @@ async def checkin(self, sender_id: int, app_id: str) -> str: :return: The FCM token which is used to identify you with the push end point application. """ - self.app_id = app_id - self.sender_id = sender_id - register = FcmRegister( + self.register = FcmRegister( + self.fcm_config, self.credentials, self.credentials_updated_callback, http_client_session=self._http_client_session, ) - self.credentials = await register.checkin(sender_id, app_id) - await register.close() - return self.credentials["fcm"]["token"] - - def start( - self, - callback: Optional[Callable[[dict, str, Optional[Any]], None]], - obj: Any = None, - *, - listen_event_loop: Optional[asyncio.AbstractEventLoop] = None, - callback_event_loop: Optional[asyncio.AbstractEventLoop] = None, - ) -> None: - """Connect to FCM and start listening for push - messages on a seperate service thread. - - :param callback: Optional callback to call when a message is received. - Will callback on the loop used to start the connection. - Callback expects parameters of: - dict: which will be a decrypted dictionary of the war payload.\n - persistent_id: unique message identifier from the FCM server.\n - obj: returns the arbitrary object if supplied to this function. - :param obj: Arbitrary object to be returned in the callback. - :param listen_event_loop: If supplied the client will use this event loop - for asyncio communication with the fcm server, otherwise it will create - it's own thread and start an event loop on it. - :param callback_event_loop: If supplied the client will run the callback - on the supplied loop, otherwise it will run the callback on it's own - thread loop or the listen_event_loop if set. - """ - self.listen_event_loop = listen_event_loop - self.callback_event_loop = callback_event_loop + self.credentials = await self.register.checkin_or_register() + # await self.register.fcm_refresh_install() + await self.register.close() + return self.credentials["fcm"]["registration"]["token"] - if self.listen_event_loop: - if not self.callback_event_loop: - self.callback_event_loop = self.listen_event_loop - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop == self.listen_event_loop: - self._start_on_existing_loop(callback, obj) - else: - self.listen_event_loop.call_soon_threadsafe( - self._start_on_existing_loop, callback, obj - ) - else: - self.fcm_thread = Thread( - target=self._start_on_new_loop, - args=[callback, obj], - daemon=True, - name="FcmClientThread", - ) - self.fcm_thread.start() + async def start(self) -> None: + """Connect to FCM and start listening for push notifications.""" + self.reset_lock = asyncio.Lock() + self.stopping_lock = asyncio.Lock() + self.do_listen = True + self.run_state = FcmPushClientRunState.STARTING_TASKS + try: + self.tasks = [ + asyncio.create_task(self._listen()), + asyncio.create_task(self._do_monitor()), + ] + except Exception as ex: + _logger.error("Unexpected error running FcmPushClient: %s", ex) - async def _stop_connection(self) -> None: + async def stop(self) -> None: if ( self.stopping_lock and self.stopping_lock.locked() @@ -903,40 +762,9 @@ async def _stop_connection(self) -> None: def is_started(self) -> bool: return self.run_state == FcmPushClientRunState.STARTED - def stop(self) -> None: - """Disconnects from FCM and shuts down the service thread.""" - if self.fcm_thread: - if ( - self.listen_event_loop - and self.listen_event_loop.is_running() - and self.fcm_thread.is_alive() - ): - _logger.debug("Shutting down FCMClient") - asyncio.run_coroutine_threadsafe( - self._stop_connection(), self.listen_event_loop - ) - - elif self.listen_event_loop and self.listen_event_loop.is_running(): - self.listen_event_loop.create_task(self._stop_connection()) - - async def _send_data_message(self, raw_data_: bytes, persistent_id: str) -> None: + async def send_message(self, raw_data: bytes, persistent_id: str) -> None: + """Not implemented, does nothing atm.""" dms = DataMessageStanza() dms.persistent_id = persistent_id # Not supported yet - - def send_message(self, raw_data: bytes, persistent_id: str) -> None: - """Not implemented, does nothing atm.""" - if self.fcm_thread: - asyncio.run_coroutine_threadsafe( - self._send_data_message(raw_data, persistent_id), - self.listen_event_loop, # type: ignore[arg-type] - ) - else: - self.listen_event_loop.create_task( # type: ignore[union-attr] - self._send_data_message(raw_data, persistent_id) - ) - - def __del__(self) -> None: - if self.listen_event_loop and self.listen_event_loop.is_running(): - self.stop() diff --git a/firebase_messaging/fcmregister.py b/firebase_messaging/fcmregister.py index 7495a5d..4f41a60 100644 --- a/firebase_messaging/fcmregister.py +++ b/firebase_messaging/fcmregister.py @@ -1,8 +1,15 @@ +from __future__ import annotations + import asyncio +import json import logging import os -from base64 import urlsafe_b64encode -from typing import Any, Callable, Dict, Optional, Union +import secrets +import time +import uuid +from base64 import b64encode, urlsafe_b64encode +from dataclasses import dataclass +from typing import Any, Callable from aiohttp import ClientSession from cryptography.hazmat.primitives import serialization @@ -10,11 +17,14 @@ from google.protobuf.json_format import MessageToDict, MessageToJson from .const import ( + AUTH_VERSION, + FCM_INSTALLATION, + FCM_REGISTRATION, FCM_SEND_URL, - FCM_SUBSCRIBE_URL, GCM_CHECKIN_URL, GCM_REGISTER_URL, GCM_SERVER_KEY_B64, + SDK_VERSION, ) from .proto.android_checkin_pb2 import ( DEVICE_CHROME_BROWSER, @@ -29,32 +39,49 @@ _logger = logging.getLogger(__name__) +@dataclass +class FcmRegisterConfig: + project_id: str + app_id: str + api_key: str + messaging_sender_id: str + bundle_id: str = "receiver.push.com" + chrome_id: str = "org.chromium.linux" + chrome_version: str = "94.0.4606.51" + vapid_key: str | None = GCM_SERVER_KEY_B64 + persistend_ids: list[str] | None = None + heartbeat_interval_ms: int = 5 * 60 * 1000 # 5 mins + + def __postinit__(self) -> None: + if self.persistend_ids is None: + self.persistend_ids = [] + + class FcmRegister: def __init__( self, - credentials: Optional[dict] = None, - credentials_updated_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + config: FcmRegisterConfig, + credentials: dict | None = None, + credentials_updated_callback: Callable[[dict[str, Any]], None] | None = None, *, - http_client_session: Optional[ClientSession] = None, + http_client_session: ClientSession | None = None, log_debug_verbose: bool = False, ): + self.config = config self.credentials = credentials self.credentials_updated_callback = credentials_updated_callback self._log_debug_verbose = log_debug_verbose self._http_client_session = http_client_session - self._local_session: Optional[ClientSession] = None - - self.app_id: Optional[str] = None - self.sender_id: Optional[int] = None + self._local_session: ClientSession | None = None def _get_checkin_payload( - self, android_id: Optional[int] = None, security_token: Optional[int] = None + self, android_id: int | None = None, security_token: int | None = None ) -> AndroidCheckinRequest: chrome = ChromeBuildProto() chrome.platform = ChromeBuildProto.Platform.PLATFORM_LINUX # 3 - chrome.chrome_version = "63.0.3234.0" + chrome.chrome_version = self.config.chrome_version chrome.channel = ChromeBuildProto.Channel.CHANNEL_STABLE # 1 checkin = AndroidCheckinProto() @@ -71,12 +98,20 @@ def _get_checkin_payload( return payload + async def gcm_check_in_and_register( + self, + ) -> dict[str, Any] | None: + options = await self.gcm_check_in() + if not options: + raise RuntimeError("Unable to register and check in to gcm") + gcm_credentials = await self.gcm_register(options) + return gcm_credentials + async def gcm_check_in( self, - android_id: Optional[int] = None, - security_token: Optional[int] = None, - log_debug_verbose: bool = False, - ) -> Optional[Dict[str, Any]]: + android_id: int | None = None, + security_token: int | None = None, + ) -> dict[str, Any] | None: """ perform check-in request @@ -88,7 +123,7 @@ async def gcm_check_in( payload = self._get_checkin_payload(android_id, security_token) - if log_debug_verbose: + if self._log_debug_verbose: _logger.debug("GCM check in payload:\n%s", payload) retries = 3 @@ -136,47 +171,50 @@ async def gcm_check_in( return None acir.ParseFromString(content) - if log_debug_verbose: + if self._log_debug_verbose: msg = MessageToJson(acir, indent=4) _logger.debug("GCM check in response (raw):\n%s", msg) return MessageToDict(acir) async def gcm_register( - self, app_id: str, retries: int = 5, log_debug_verbose: bool = False - ) -> Optional[Dict[str, str]]: + self, + options: dict[str, Any], + retries: int = 2, + ) -> dict[str, str] | None: """ obtains a gcm token app_id: app id as an integer retries: number of failed requests before giving up - returns {"token": "...", "appId": 123123, "androidId":123123, + returns {"token": "...", "gcm_app_id": 123123, "androidId":123123, "securityToken": 123123} """ # contains android_id, security_token and more - chk = await self.gcm_check_in(log_debug_verbose=log_debug_verbose) - if not chk: - raise RuntimeError("Unable to register and check in to gcm") - if log_debug_verbose: - _logger.debug("GCM check in response %s", chk) + gcm_app_id = f"wp:{self.config.bundle_id}#{uuid.uuid4()}" + android_id = options["androidId"] + security_token = options["securityToken"] + + headers = { + "Authorization": f"AidLogin {android_id}:{security_token}", + "Content-Type": "application/x-www-form-urlencoded", + } body = { "app": "org.chromium.linux", - "X-subtype": app_id, - "device": chk["androidId"], - # "sender": urlsafe_base64(GCM_SERVER_KEY), + "X-subtype": gcm_app_id, + "device": android_id, "sender": GCM_SERVER_KEY_B64, } - if log_debug_verbose: + if self._log_debug_verbose: _logger.debug("GCM Registration request: %s", body) - auth = "AidLogin {}:{}".format(chk["androidId"], chk["securityToken"]) - last_error: Optional[Union[str, Exception]] = None + last_error: str | Exception | None = None for try_num in range(retries): try: async with self._session.post( url=GCM_REGISTER_URL, - headers={"Authorization": auth}, + headers=headers, data=body, timeout=2, ) as resp: @@ -192,11 +230,14 @@ async def gcm_register( await asyncio.sleep(1) continue token = response_text.split("=")[1] - # get only the fields we need from the check in response - chkfields = {k: chk[k] for k in ["androidId", "securityToken"]} - res = {"token": token, "appId": app_id} - res.update(chkfields) - return res + + return { + "token": token, + "app_id": gcm_app_id, + "android_id": android_id, + "security_token": security_token, + } + except Exception as e: last_error = e _logger.warning( @@ -215,26 +256,106 @@ async def gcm_register( _logger.error(errorstr) return None - async def fcm_register( - self, - sender_id: int, - token: str, - retries: int = 5, - log_debug_verbose: bool = False, - ) -> Optional[Dict[str, Any]]: - """ - generates key pair and obtains a fcm token - - sender_id: sender id as an integer - token: the subscription token in the dict returned by gcm_register + async def fcm_install_and_register( + self, gcm_data: dict[str, Any], keys: dict[str, Any] + ) -> dict[str, Any] | None: + if installation := await self.fcm_install(): + registration = await self.fcm_register(gcm_data, installation, keys) + return { + "registration": registration, + "installation": installation, + } + return None - returns {"keys": keys, "fcm": {...}} - """ - # I used this analyzer to figure out how to slice the asn1 structs - # https://lapo.it/asn1js - # first byte of public key is skipped for some reason - # maybe it's always zero + async def fcm_install(self) -> dict | None: + fid = bytearray(secrets.token_bytes(17)) + # Replace the first 4 bits with the FID header 0b0111. + fid[0] = 0b01110000 + (fid[0] % 0b00010000) + fid64 = b64encode(fid).decode() + + hb_header = b64encode( + json.dumps({"heartbeats": [], "version": 2}).encode() + ).decode() + headers = { + "x-firebase-client": hb_header, + "x-goog-api-key": self.config.api_key, + } + payload = { + "appId": self.config.app_id, + "authVersion": AUTH_VERSION, + "fid": fid64, + "sdkVersion": SDK_VERSION, + } + url = FCM_INSTALLATION + f"projects/{self.config.project_id}/installations" + async with self._session.post( + url=url, + headers=headers, + data=json.dumps(payload), + timeout=2, + ) as resp: + if resp.status == 200: + fcm_install = await resp.json() + + return { + "token": fcm_install["authToken"]["token"], + "expires_in": int(fcm_install["authToken"]["expiresIn"][:-1:]), + "refresh_token": fcm_install["refreshToken"], + "fid": fcm_install["fid"], + "created_at": time.monotonic(), + } + else: + text = await resp.text() + _logger.error( + "Error during fcm_install: %s ", + text, + ) + return None + + async def fcm_refresh_install_token(self) -> dict | None: + hb_header = b64encode( + json.dumps({"heartbeats": [], "version": 2}).encode() + ).decode() + if not self.credentials: + raise RuntimeError("Credentials must be set to refresh install token") + fcm_refresh_token = self.credentials["fcm"]["installation"]["refresh_token"] + + headers = { + "Authorization": f"{AUTH_VERSION} {fcm_refresh_token}", + "x-firebase-client": hb_header, + "x-goog-api-key": self.config.api_key, + } + payload = { + "installation": { + "sdkVersion": SDK_VERSION, + "appId": self.config.app_id, + } + } + url = ( + FCM_INSTALLATION + f"projects/{self.config.project_id}/" + "installations/{fid}/authTokens:generate" + ) + async with self._session.post( + url=url, + headers=headers, + data=json.dumps(payload), + timeout=5, + ) as resp: + if resp.status == 200: + fcm_refresh = await resp.json() + return { + "token": fcm_refresh["token"], + "expires_in": int(fcm_refresh["expiresIn"][:-1:]), + "created_at": time.monotonic(), + } + else: + text = await resp.text() + _logger.error( + "Error during fcm_refresh_install_token: %s ", + text, + ) + return None + def generate_keys(self) -> dict: private_key = ec.generate_private_key(ec.SECP256R1()) public_key = private_key.public_key() @@ -248,31 +369,64 @@ async def fcm_register( format=serialization.PublicFormat.SubjectPublicKeyInfo, ) - keys = { + return { "public": urlsafe_b64encode(serialized_public[26:]).decode( "ascii" ), # urlsafe_base64(serialized_public[26:]), "private": urlsafe_b64encode(serialized_private).decode("ascii"), "secret": urlsafe_b64encode(os.urandom(16)).decode("ascii"), } - data = { - "authorized_entity": sender_id, - "endpoint": f"{FCM_SEND_URL}/{token}", - "encryption_key": keys["public"], - "encryption_auth": keys["secret"], + + async def fcm_register( + self, + gcm_data: dict, + installation: dict, + keys: dict, + retries: int = 2, + ) -> dict[str, Any] | None: + headers = { + "x-goog-api-key": self.config.api_key, + "x-goog-firebase-installations-auth": installation["token"], + } + # If vapid_key is the default do not send it here or it will error + vapid_key = ( + self.config.vapid_key + if self.config.vapid_key != GCM_SERVER_KEY_B64 + else None + ) + payload = { + "web": { + "applicationPubKey": vapid_key, + "auth": keys["secret"], + "endpoint": FCM_SEND_URL + gcm_data["token"], + "p256dh": keys["public"], + } } - if log_debug_verbose: - _logger.debug("FCM registration data: %s", data) + url = FCM_REGISTRATION + f"projects/{self.config.project_id}/registrations" + if self._log_debug_verbose: + _logger.debug("FCM registration data: %s", payload) for try_num in range(retries): try: async with self._session.post( - url=FCM_SUBSCRIBE_URL, - data=data, - timeout=2, + url=url, + headers=headers, + data=json.dumps(payload), + # timeout=2, ) as resp: - fcm = await resp.json() - return {"keys": keys, "fcm": fcm} + if resp.status == 200: + fcm = await resp.json() + return fcm + else: + text = await resp.text() + _logger.error( # pylint: disable=duplicate-code + "Error during fmc register request " + "attempt %s out of %s: %s", + try_num + 1, + retries, + text, + ) + except Exception as e: _logger.error( # pylint: disable=duplicate-code "Error during fmc register request attempt %s out of %s", @@ -283,7 +437,7 @@ async def fcm_register( await asyncio.sleep(1) return None - async def checkin(self, sender_id: int, app_id: str) -> Dict[str, Any]: + async def checkin_or_register(self) -> dict[str, Any]: """Check in if you have credentials otherwise register as a new client. :param sender_id: sender id identifying push service you are connecting to. @@ -291,24 +445,21 @@ async def checkin(self, sender_id: int, app_id: str) -> Dict[str, Any]: :return: The FCM token which is used to identify you with the push end point application. """ - self.sender_id = sender_id - self.app_id = app_id if self.credentials: gcm_response = await self.gcm_check_in( - self.credentials["gcm"]["androidId"], - self.credentials["gcm"]["securityToken"], - log_debug_verbose=self._log_debug_verbose, + self.credentials["gcm"]["android_id"], + self.credentials["gcm"]["security_token"], ) if gcm_response: return self.credentials - self.credentials = await self.register(sender_id, app_id) + self.credentials = await self.register() if self.credentials_updated_callback: self.credentials_updated_callback(self.credentials) return self.credentials - async def register(self, sender_id: int, app_id: str) -> Dict: + async def register(self) -> dict: """Register gcm and fcm tokens for sender_id. Typically you would call checkin instead of register which does not do a full registration @@ -318,26 +469,30 @@ async def register(self, sender_id: int, app_id: str) -> Dict: :param app_id: identifier for your application. :return: The dict containing all credentials. """ - self.sender_id = sender_id - self.app_id = app_id - subscription = await self.gcm_register( - app_id=app_id, log_debug_verbose=self._log_debug_verbose - ) - if subscription is None: + + keys = self.generate_keys() + + gcm_data = await self.gcm_check_in_and_register() + if gcm_data is None: raise RuntimeError( "Unable to establish subscription with Google Cloud Messaging." ) - self._log_verbose("GCM subscription: %s", subscription) - fcm = await self.fcm_register( - sender_id=sender_id, - token=subscription["token"], - log_debug_verbose=self._log_debug_verbose, - ) - if not fcm: + self._log_verbose("GCM subscription: %s", gcm_data) + + fcm_data = await self.fcm_install_and_register(gcm_data, keys) + if not fcm_data: raise RuntimeError("Unable to register with fcm") - self._log_verbose("FCM registration: %s", fcm) - res: Dict[str, Any] = {"gcm": subscription} - res.update(fcm) + self._log_verbose("FCM registration: %s", fcm_data) + res: dict[str, Any] = { + "keys": keys, + "gcm": gcm_data, + "fcm": fcm_data, + "config": { + "bundle_id": self.config.bundle_id, + "project_id": self.config.project_id, + "vapid_key": self.config.vapid_key, + }, + } self._log_verbose("Credential: %s", res) _logger.info("Registered with FCM") return res diff --git a/tests/conftest.py b/tests/conftest.py index 0311df4..467cf3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,12 +4,17 @@ import logging import os import threading +from unittest.mock import patch import pytest from aioresponses import CallbackResult, aioresponses from google.protobuf.json_format import Parse as JsonParse -from firebase_messaging.fcmpushclient import FcmPushClient, FcmPushClientConfig +from firebase_messaging.fcmpushclient import ( + FcmPushClient, + FcmPushClientConfig, + FcmRegisterConfig, +) from firebase_messaging.proto.checkin_pb2 import AndroidCheckinResponse from firebase_messaging.proto.mcs_pb2 import LoginResponse from tests.fakes import FakeMcsEndpoint @@ -36,59 +41,56 @@ def load_fixture_as_msg(filename, msg_class): @pytest.fixture() async def fake_mcs_endpoint(): - # async with McsEndpoint() as ep: - ep = FakeMcsEndpoint() - yield ep + fmce = FakeMcsEndpoint() - ep.close() + async def _mock_open_conn(*_, **__): + return fmce.client_reader, fmce.client_writer + with patch("asyncio.open_connection", side_effect=_mock_open_conn, autospec=True): + yield fmce -@pytest.fixture(params=[None, "loop"], ids=["loop_created", "loop_provided"]) + fmce.close() + + +@pytest.fixture() async def logged_in_push_client(request, fake_mcs_endpoint, mocker, caplog): clients = {} caplog.set_level(logging.DEBUG) - listen_loop = asyncio.get_running_loop() if request.param else None - async def _logged_in_push_client( - credentials, msg_callback, - callback_obj=None, - callback_loop=None, + credentials, *, + callback_obj=None, supress_disconnect=False, **config_kwargs, ): config = FcmPushClientConfig(**config_kwargs) - pr = FcmPushClient(credentials=credentials, config=config) - await pr.checkin(1234, 4321) - - cb_loop = asyncio.get_running_loop() if callback_loop else None - pr.start( + fcm_config = FcmRegisterConfig("project-1234", "bar", "foobar", "foobar") + pr = FcmPushClient( msg_callback, - callback_obj, - listen_event_loop=listen_loop, - callback_event_loop=cb_loop, + fcm_config, + credentials, + None, + callback_context=callback_obj, + config=config, ) + await pr.checkin_or_register() + + await pr.start() await fake_mcs_endpoint.get_message() lr = load_fixture_as_msg("login_response.json", LoginResponse) await fake_mcs_endpoint.put_message(lr) clients[pr] = supress_disconnect - tc = 1 if listen_loop else 2 - assert len(threading.enumerate()) == tc - if listen_loop: - assert pr.listen_event_loop == asyncio.get_running_loop() - else: - assert pr.listen_event_loop != asyncio.get_running_loop() return pr yield _logged_in_push_client for k, v in clients.items(): if not v: - k.stop() + await k.stop() @pytest.fixture(autouse=True, name="aioresponses_mock") @@ -105,7 +107,11 @@ def aioresponses_mock_fixture(): body=load_fixture("gcm_register_response.txt"), ) mock.post( - "https://fcm.googleapis.com/fcm/connect/subscribe", + "https://firebaseinstallations.googleapis.com/v1/projects/project-1234/installations", + payload=load_fixture_as_dict("fcm_install_response.json"), + ) + mock.post( + "https://fcmregistrations.googleapis.com/v1/projects/project-1234/registrations", payload=load_fixture_as_dict("fcm_register_response.json"), ) yield mock diff --git a/tests/fakes.py b/tests/fakes.py index eb9cbd4..686d208 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -13,63 +13,36 @@ class FakeMcsEndpoint: def __init__(self): - self.connection_mock = patch( - "asyncio.open_connection", side_effect=self.open_connection, autospec=True - ) - self.connection_mock.start() - - self.client_loop = None - self.init_loop = None - self.client_writer = None - self.client_reader = None - self.init_loop = asyncio.get_running_loop() + # self.connection_mock = patch( + # "asyncio.open_connection", side_effect=self.open_connection, autospec=True + # ) + # self.connection_mock.start() + self.client_writer = self.FakeWriter() + self.client_reader = self.FakeReader() def close(self): - self.connection_mock.stop() + # self.connection_mock.stop() + pass async def open_connection(self, *_, **__): # Queues should be created on the loop that will be accessing them self.client_writer = self.FakeWriter() self.client_reader = self.FakeReader() - self.client_loop = asyncio.get_running_loop() return self.client_reader, self.client_writer - async def wait_for_connection(self, timeout=10): - async with asyncio_timeout(timeout): - while not self.client_loop: - await asyncio.sleep(0.1) - async def put_message(self, message): - await self.wait_for_connection() - if self.init_loop != self.client_loop: - asyncio.run_coroutine_threadsafe( - self.client_reader.put_message(message), self.client_loop - ) - else: - await self.client_reader.put_message(message) + await self.client_reader.put_message(message) async def put_error(self, error): - await self.wait_for_connection() - if self.init_loop != self.client_loop: - asyncio.run_coroutine_threadsafe( - self.client_reader.put_error(error), self.client_loop - ) - else: - await self.client_reader.put_error(error) + await self.client_reader.put_error(error) async def get_message(self): - await self.wait_for_connection() - if self.init_loop != self.client_loop: - fut = asyncio.run_coroutine_threadsafe( - self.client_writer.get_message(), self.client_loop - ) - return fut.result() - else: - return await self.client_writer.get_message() + return await self.client_writer.get_message() class FakeReader: def __init__(self): self.queue = asyncio.Queue() + self.lock = asyncio.Lock() async def readexactly(self, size): if size == 0: @@ -85,17 +58,20 @@ async def readexactly(self, size): async def put_message(self, message): include_version = isinstance(message, LoginResponse) packet = FcmPushClient._make_packet(message, include_version) - for p in packet: - b = bytes([p]) - await self.queue.put(b) + async with self.lock: + for p in packet: + b = bytes([p]) + await self.queue.put(b) async def put_error(self, error): - await self.queue.put(error) + async with self.lock: + await self.queue.put(error) class FakeWriter: def __init__(self): self.queue = asyncio.Queue() self.buf = "" + self.lock = asyncio.Lock() def write(self, buffer): for i in buffer: @@ -112,10 +88,11 @@ async def wait_closed(self): pass async def get_bytes(self, size): - val = b"" - for _ in range(size): - val += await self.queue.get() - return val + async with self.lock: + val = b"" + for _ in range(size): + val += await self.queue.get() + return val async def get_message(self, timeout=2): async with asyncio_timeout(timeout): diff --git a/tests/fixtures/credentials.json b/tests/fixtures/credentials.json index f4999b1..edd6514 100644 --- a/tests/fixtures/credentials.json +++ b/tests/fixtures/credentials.json @@ -1,9 +1,9 @@ { "gcm": { "token": "XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz", - "appId": "abcdef01-ef12-1234-12ef-abcdef012345", - "androidId": "5678901234567890123", - "securityToken": "0123456789012345678" + "app_id": "abcdef01-ef12-1234-12ef-abcdef012345", + "android_id": "5678901234567890123", + "security_token": "0123456789012345678" }, "keys": { "public": "BPEHm32RWI4db8FCk0IM6G9f9vz_uJeRiuU64Y5dkyZjkBXmyGgzwzZMylPaLNvg50EuoQmNlU7sSMUf0mYctn0=", @@ -11,7 +11,26 @@ "secret": "_zGHqwp9rRP5cgzilMLvCA" }, "fcm": { - "token": "zyxXYZ01234:APA91b01234abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789_abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789", - "pushSet": "01234zyxXYZ:APA91b43210zyxwvutsrqpomlkjihgfedcba_ZYXWVUTSRQPONMLKJIHGFEDCBA-9876543210_zyxwvutsrqpomlkjihgfedcba_ZYXWVUTSRQPONMLKJIHGFEDCBA-9876543210" + "registration": { + "name": "projects/project-1234/registrations/abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz_XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "token": "abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz_XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "web": { + "endpoint": "https://fcm.googleapis.com/fcm/send/XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz", + "p256dh": "123456967219824KDJDOWFNAFW=", + "auth": "DSFNDSAKGFAGFA==" + } + }, + "installation": { + "token": "OPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijkl", + "expires_in": 604800, + "refresh_token": "1_OPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijkl", + "fid": "1234AMXDRTTYODLsd-nb14", + "created_at": 36245.436300371 + } + }, + "config": { + "bundle_id": "project.push.com", + "project_id": "project-1234", + "vapid_key": "BDOU99-h67HcA6JeFXHbSNMu7e2yNNu3RzoMj8TM4W88jITfq7ZmPvIM1Iv-4_l2LxQcYwhqby2xGpWwzjfAnG4" } } \ No newline at end of file diff --git a/tests/fixtures/fcm_install_response.json b/tests/fixtures/fcm_install_response.json new file mode 100644 index 0000000..673b0fe --- /dev/null +++ b/tests/fixtures/fcm_install_response.json @@ -0,0 +1,9 @@ +{ + "name": "projects/123456789101/installations/1234AMXDRTTYODLsd-nb14", + "fid": "1234AMXDRTTYODLsd-nb14", + "refreshToken": "1_OPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijkl", + "authToken": { + "token": "OPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijklOPQRSTUVWXYZ-01234abcdefghijkl", + "expiresIn": "604800s" + } +} \ No newline at end of file diff --git a/tests/fixtures/fcm_register_response.json b/tests/fixtures/fcm_register_response.json index 26aea9e..92f1600 100644 --- a/tests/fixtures/fcm_register_response.json +++ b/tests/fixtures/fcm_register_response.json @@ -1,4 +1,9 @@ { - "token": "zyxXYZ01234:APA91b01234abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789_abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789", - "pushSet": "01234zyxXYZ:APA91b43210zyxwvutsrqponmlkjihgfedcba_ZYXWVUTSRQPONMLKJIHGFEDCBA-9876543210_zyxwvutsrqponmlkjihgfedcba_ZYXWVUTSRQPONMLKJIHGFEDCBA-9876543210" + "name": "projects/project-1234/registrations/abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz_XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "token": "abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz_XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "web": { + "endpoint": "https://fcm.googleapis.com/fcm/send/XYZ01234zyx:APA91b0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz-0123456789_ABCDEFGHIJKLMNOPQRSTUVWXYZ-01234abcdefghijklmnopqrstuvwxyz", + "p256dh": "123456967219824KDJDOWFNAFW=", + "auth": "DSFNDSAKGFAGFA==" + } } \ No newline at end of file diff --git a/tests/test_fcmpushclient.py b/tests/test_fcmpushclient.py index 514fdff..43a4876 100644 --- a/tests/test_fcmpushclient.py +++ b/tests/test_fcmpushclient.py @@ -7,7 +7,7 @@ from cryptography.hazmat.primitives.serialization import load_der_private_key from http_ece import encrypt -from firebase_messaging import FcmPushClient +from firebase_messaging import FcmPushClient, FcmRegisterConfig from firebase_messaging.proto.mcs_pb2 import ( Close, DataMessageStanza, @@ -20,22 +20,10 @@ async def test_register(): - pr = FcmPushClient(credentials=None) - await pr.checkin(1234, 4321) - - -async def test_no_disconnect(logged_in_push_client, fake_mcs_endpoint, mocker, caplog): - pr = await logged_in_push_client(None, None, supress_disconnect=True) - - pr.__del__() - await asyncio.sleep(0.1) - assert ( - len([record for record in caplog.records if record.levelname == "ERROR"]) == 0 + pr = FcmPushClient( + None, FcmRegisterConfig("project-1234", "bar", "foobar", "foobar"), None ) - - assert "FCMClient has shutdown" in [ - record.message for record in caplog.records if record.levelname == "INFO" - ] + await pr.checkin_or_register() async def test_login(logged_in_push_client, fake_mcs_endpoint, mocker, caplog): @@ -50,11 +38,8 @@ async def test_login(logged_in_push_client, fake_mcs_endpoint, mocker, caplog): ] -@pytest.mark.parametrize( - "callback_loop", [None, "loop"], ids=["no_cb_loop_param", "cb_loop_param"] -) async def test_data_message_receive( - logged_in_push_client, fake_mcs_endpoint, mocker, caplog, callback_loop + logged_in_push_client, fake_mcs_endpoint, mocker, caplog ): notification = None persistent_id = None @@ -73,8 +58,8 @@ def on_msg(ntf, psid, obj=None): credentials = load_fixture_as_dict("credentials.json") obj = "Foobar" - cb_loop_param = asyncio.get_running_loop() if callback_loop else None - pr = await logged_in_push_client(credentials, on_msg, obj, cb_loop_param) + + await logged_in_push_client(on_msg, credentials, callback_obj=obj) dms = load_fixture_as_msg("data_message_stanza.json", DataMessageStanza) await fake_mcs_endpoint.put_message(dms) @@ -89,11 +74,6 @@ def on_msg(ntf, psid, obj=None): assert persistent_id == dms.persistent_id assert obj == callback_obj - if callback_loop: - assert cb_loop == asyncio.get_running_loop() - else: - assert cb_loop == pr.listen_event_loop - async def test_connection_reset(logged_in_push_client, fake_mcs_endpoint, mocker): # ConnectionResetError, TimeoutError, SSLError @@ -101,14 +81,12 @@ async def test_connection_reset(logged_in_push_client, fake_mcs_endpoint, mocker None, None, abort_on_sequential_error_count=3, reset_interval=0.1 ) - mocker.patch.object(FcmPushClient, "_reset", wraps=pr._reset) - - assert pr._reset.call_count == 0 + reset_spy = mocker.spy(pr, "_reset") await fake_mcs_endpoint.put_error(ConnectionResetError()) await asyncio.sleep(0.1) - assert pr._reset.call_count == 1 + assert reset_spy.call_count == 1 msg = await fake_mcs_endpoint.get_message() assert isinstance(msg, LoginRequest) @@ -118,15 +96,12 @@ async def test_connection_reset(logged_in_push_client, fake_mcs_endpoint, mocker async def test_terminate( logged_in_push_client, fake_mcs_endpoint, mocker, error_count, caplog ): - # ConnectionResetError, TimeoutError, SSLError pr = await logged_in_push_client( None, None, abort_on_sequential_error_count=error_count, reset_interval=0 ) - mocker.patch.object(FcmPushClient, "_reset", wraps=pr._reset) - mocker.patch.object(FcmPushClient, "_terminate", wraps=pr._terminate) - - assert pr._reset.call_count == 0 + reset_spy = mocker.spy(pr, "_reset") + term_spy = mocker.spy(pr, "_terminate") for i in range(1, error_count + 1): await fake_mcs_endpoint.put_error(ConnectionResetError()) @@ -134,13 +109,13 @@ async def test_terminate( await asyncio.sleep(0.1) # client should reset while it gets errors < abort_on_sequential_error_count then it should terminate if i < error_count: - assert pr._reset.call_count == i - assert pr._terminate.call_count == 0 + assert reset_spy.call_count == i + assert term_spy.call_count == 0 msg = await fake_mcs_endpoint.get_message() assert isinstance(msg, LoginRequest) else: - assert pr._reset.call_count == i - 1 - assert pr._terminate.call_count == 1 + assert reset_spy.call_count == i - 1 + assert term_spy.call_count == 1 async def test_heartbeat_receive(logged_in_push_client, fake_mcs_endpoint, caplog):