diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index f46a2b76..50a9da96 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -14,9 +14,9 @@ jobs: uses: ./.github/workflows/test.yml publish: name: Publish to Docker Hub + if: github.repository == 'ParadoxAlarmInterface/pai' && github.ref == 'refs/heads/dev' uses: ./.github/workflows/publish_docker.yml needs: test - if: github.repository_owner == 'ParadoxAlarmInterface' secrets: DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7db3af8e..e470ec93 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ jobs: name: Publish to Docker Hub uses: ./.github/workflows/publish_docker.yml needs: test - if: github.repository_owner == 'ParadoxAlarmInterface' + if: github.repository == 'ParadoxAlarmInterface/pai' secrets: DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} @@ -20,6 +20,6 @@ jobs: name: Publish to PyPI uses: ./.github/workflows/publish_pypi.yml needs: test - if: github.repository_owner == 'ParadoxAlarmInterface' + if: github.repository == 'ParadoxAlarmInterface/pai' secrets: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4099e599..b7b4822b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,13 +16,13 @@ repos: - id: trailing-whitespace - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 + rev: v3.17.0 hooks: - id: pyupgrade args: ["--py37-plus"] - repo: https://github.com/psf/black - rev: 24.4.0 + rev: 24.8.0 hooks: - id: black args: @@ -35,7 +35,7 @@ repos: - id: isort - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + rev: 7.1.1 hooks: - id: flake8 additional_dependencies: [flake8-bugbear] diff --git a/README.md b/README.md index 060c98d1..100901eb 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,8 @@ Can be looked up in Babyware (_Right click on a panel ⇾ Properties ⇾ PC Comm We do not recommend using SWAN because of https://github.com/CriticalSecurity/paradox ## Firmware Upgrade WARNING: -**Do not upgrade EVO firmware versions to 7.50.000+ if you use Serial connection. Process is irreversible! Paradox introduces serial communication encryption which most probably will break our PAI ability to talk to the panel.** +**Do not upgrade EVO firmware versions to 7.50.000+ if you use Serial connection. Process is irreversible! Paradox introduces serial communication encryption which most probably will break our PAI ability to talk to the panel.** +Note: Paradox sells unlock code to re-enable the unencrypted serial port. ## How to use See [wiki](https://github.com/ParadoxAlarmInterface/pai/wiki/Installation) diff --git a/config/pai.conf.example b/config/pai.conf.example index 65836a5b..56d54f45 100644 --- a/config/pai.conf.example +++ b/config/pai.conf.example @@ -114,6 +114,14 @@ import logging ### MQTT Home Assistant Auto Discovery # MQTT_HOMEASSISTANT_AUTODISCOVERY_ENABLE = False # MQTT_HOMEASSISTANT_CODE = None +# HOMEASSISTANT_PUBLISH_PARTITION_PROPERTIES = [ +# 'target_state', +# 'current_state' +# ] +# HOMEASSISTANT_PUBLISH_ZONE_PROPERTIES = [ +# 'open', +# 'tamper' +# ] # ### Dash App # MQTT_DASH_PUBLISH = False @@ -137,16 +145,11 @@ import logging # ### Home Assistant Notifications (HASS.io required) # HOMEASSISTANT_NOTIFICATIONS_ENABLE = False +# HOMEASSISTANT_NOTIFICATIONS_API_URL = "http://supervisor/core/api/services/:domain/:service" +# HOMEASSISTANT_NOTIFICATIONS_API_TOKEN = "" # Long-Lived Access Token. Required if you do not use HA Supervisor +# HOMEASSISTANT_NOTIFICATIONS_LOVELACE_URI = "" # URI to open when notification is clicked # HOMEASSISTANT_NOTIFICATIONS_NOTIFIER_NAME = 'notify' # HOMEASSISTANT_NOTIFICATIONS_MIN_EVENT_LEVEL = 'INFO' -# HOMEASSISTANT_PUBLISH_PARTITION_PROPERTIES = [ -# 'target_state', -# 'current_state' -# ] -# HOMEASSISTANT_PUBLISH_ZONE_PROPERTIES = [ -# 'open', -# 'tamper' -# ] ## Event filtering by tags: # HOMEASSISTANT_NOTIFICATIONS_EVENT_FILTERS = [ # list of tags to include or exclude see hardware event.py for tag list # 'live,alarm,-restore', # or diff --git a/paradox/config.py b/paradox/config.py index aac33cb7..7ea12199 100644 --- a/paradox/config.py +++ b/paradox/config.py @@ -164,14 +164,6 @@ class Config: "armed_away": "arm", "disarmed": "disarm", }, - # Home Assistant Notifications (HASS.io required) - "HOMEASSISTANT_NOTIFICATIONS_ENABLE": False, - "HOMEASSISTANT_NOTIFICATIONS_NOTIFIER_NAME": "notify", - "HOMEASSISTANT_NOTIFICATIONS_MIN_EVENT_LEVEL": ( - "INFO", - str, - ["DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"], - ), "HOMEASSISTANT_PUBLISH_PARTITION_PROPERTIES": [ # List of partition properties to publish "target_state", "current_state", @@ -180,6 +172,17 @@ class Config: "open", "tamper", ], + # Home Assistant Notifications (HASS.io required) + "HOMEASSISTANT_NOTIFICATIONS_ENABLE": False, + "HOMEASSISTANT_NOTIFICATIONS_API_URL": "http://supervisor/core/api/services/:domain/:service", + "HOMEASSISTANT_NOTIFICATIONS_API_TOKEN": "", # Authentication token used for Home Assistant if not using Supervisor + "HOMEASSISTANT_NOTIFICATIONS_LOVELACE_URI": "", # URI to open when notification is clicked + "HOMEASSISTANT_NOTIFICATIONS_NOTIFIER_NAME": "notify", + "HOMEASSISTANT_NOTIFICATIONS_MIN_EVENT_LEVEL": ( + "INFO", + str, + ["DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"], + ), "HOMEASSISTANT_NOTIFICATIONS_IGNORE_EVENTS": [], # List of tuples or regexp matching "type,label,property=value,property2=value" eg. [(major, minor), "zone:HOME:entry_delay=True", ...] "HOMEASSISTANT_NOTIFICATIONS_ALLOW_EVENTS": [], # Same as before but as a white list. Default is use EVENT_FILTERS "HOMEASSISTANT_NOTIFICATIONS_EVENT_FILTERS": [ # list of tags, property changes to include or exclude. See event.py for tag list diff --git a/paradox/connections/ip/connection.py b/paradox/connections/ip/connection.py index 46f81265..f0fd7519 100644 --- a/paradox/connections/ip/connection.py +++ b/paradox/connections/ip/connection.py @@ -60,7 +60,7 @@ def __init__(self, host="127.0.0.1", port=10000): self.port = port async def _try_connect(self): - _, self._protocol = await self.loop.create_connection( + _, self._protocol = await asyncio.get_event_loop().create_connection( self._make_protocol, host=self.host, port=self.port ) @@ -90,7 +90,9 @@ def set_key(self, value): self._protocol.key = value def on_ip_message(self, container: Container): - return self.loop.create_task(self.ip_handler_registry.handle(container)) + return asyncio.get_event_loop().create_task( + self.ip_handler_registry.handle(container) + ) async def wait_for_ip_message(self, timeout=cfg.IO_TIMEOUT) -> Container: future = FutureHandler() @@ -115,7 +117,7 @@ def __init__( self.port = port async def _try_connect(self) -> None: - _, self._protocol = await self.loop.create_connection( + _, self._protocol = await asyncio.get_event_loop().create_connection( self._make_protocol, host=self.host, port=self.port ) @@ -146,7 +148,7 @@ def write(self, data: bytes): async def _try_connect(self) -> None: await self.stun_session.connect() - _, self._protocol = await self.loop.create_connection( + _, self._protocol = await asyncio.get_event_loop().create_connection( self._make_protocol, sock=self.stun_session.get_socket() ) diff --git a/paradox/connections/serial_connection.py b/paradox/connections/serial_connection.py index 537307a8..3234b084 100644 --- a/paradox/connections/serial_connection.py +++ b/paradox/connections/serial_connection.py @@ -1,13 +1,10 @@ -# -*- coding: utf-8 -*- - - +import asyncio import logging import os import stat -import typing -import serial_asyncio from serial import SerialException +import serial_asyncio from ..exceptions import SerialConnectionOpenFailed from .connection import Connection @@ -67,12 +64,12 @@ async def connect(self) -> bool: logger.error(f"Failed to update file {self.port_path} permissions") return False - self.connected_future = self.loop.create_future() - open_timeout_handler = self.loop.call_later(5, self.open_timeout) + self.connected_future = asyncio.get_event_loop().create_future() + open_timeout_handler = asyncio.get_event_loop().call_later(5, self.open_timeout) try: _, self._protocol = await serial_asyncio.create_serial_connection( - self.loop, self.make_protocol, self.port_path, self.baud + asyncio.get_event_loop(), self.make_protocol, self.port_path, self.baud ) return await self.connected_future @@ -81,7 +78,7 @@ async def connect(self) -> bool: raise SerialConnectionOpenFailed( "Connection to serial port failed" ) from e # PAICriticalException - except: + except Exception: logger.exception("Unable to connect to Serial") finally: open_timeout_handler.cancel() diff --git a/paradox/exceptions.py b/paradox/exceptions.py index fdd36ee7..90c816dd 100644 --- a/paradox/exceptions.py +++ b/paradox/exceptions.py @@ -35,9 +35,11 @@ class PAICriticalException(PAIException): class AuthenticationFailed(PAICriticalException): pass + class CodeLockout(PAICriticalException): pass + class PanelNotDetected(PAICriticalException): pass @@ -46,6 +48,10 @@ class SerialConnectionOpenFailed(PAICriticalException): pass +class InvalidCommand(PAIException): + pass + + def async_loop_unhandled_exception_handler(loop, context): exception = context.get("exception") diff --git a/paradox/interfaces/text/core.py b/paradox/interfaces/text/core.py index dc83df50..0cee55da 100644 --- a/paradox/interfaces/text/core.py +++ b/paradox/interfaces/text/core.py @@ -2,15 +2,15 @@ from paradox.config import config as cfg from paradox.event import Event, EventLevel, Notification -from paradox.interfaces import ThreadQueueInterface +from paradox.exceptions import InvalidCommand +from paradox.interfaces import AsyncInterface from paradox.lib import ps -from paradox.lib.event_filter import (EventFilter, EventTagFilter, - LiveEventRegexpFilter) +from paradox.lib.event_filter import EventFilter, EventTagFilter, LiveEventRegexpFilter logger = logging.getLogger("PAI").getChild(__name__) -class AbstractTextInterface(ThreadQueueInterface): +class AbstractTextInterface(AsyncInterface): """Interface Class using any Text interface""" def __init__(self, alarm, event_filter: EventFilter, min_level=EventLevel.INFO): @@ -21,12 +21,7 @@ def __init__(self, alarm, event_filter: EventFilter, min_level=EventLevel.INFO): self.min_level = min_level self.alarm = alarm - def stop(self): - super().stop() - - def _run(self): - super(AbstractTextInterface, self)._run() - + async def run(self): ps.subscribe(self.handle_panel_event, "events") ps.subscribe(self.handle_notify, "notifications") @@ -38,11 +33,17 @@ def notification_filter(self, notification: Notification): def handle_notify(self, notification: Notification): if self.notification_filter(notification): - self.send_message(notification.message, notification.level) + try: + self.send_message(notification.message, notification.level) + except Exception as e: + logger.exception(f"Error handling notification: {e}") def handle_panel_event(self, event: Event): if self.event_filter.match(event): - self.send_message(event.message, event.level) + try: + self.send_message(event.message, event.level) + except Exception as e: + logger.exception(f"Error handling event: {e}") async def handle_command(self, message_raw): message = cfg.COMMAND_ALIAS.get(message_raw, message_raw) @@ -50,7 +51,7 @@ async def handle_command(self, message_raw): tokens = message.split(" ") if len(tokens) != 3: - m = "Invalid: {}".format(message_raw) + m = f"Invalid: {message_raw}" logger.warning(m) return m @@ -61,46 +62,45 @@ async def handle_command(self, message_raw): element_type = tokens[0].lower() element = tokens[1] - command = self.normalize_payload(tokens[2].lower()) + command = self.normalize_command(tokens[2].lower()) # Process a Zone Command if element_type == "zone": if not await self.alarm.control_zone(element, command): - m = "Zone command error: {}={}".format(element, command) + m = f"Zone command error: {element}={command}" logger.warning(m) return m # Process a Partition Command elif element_type == "partition": if not await self.alarm.control_partition(element, command): - m = "Partition command error: {}={}".format(element, command) + m = f"Partition command error: {element}={command}" logger.warning(m) return m # Process an Output Command elif element_type == "output": if not await self.alarm.control_output(element, command): - m = "Output command error: {}={}".format(element, command) + m = f"Output command error: {element}={command}" logger.warning(m) return m else: - m = "Invalid control element: {}".format(element) + m = f"Invalid control element: {element}" logger.error(m) return m - logger.info("OK: {}".format(message_raw)) + logger.info(f"OK: {message_raw}") return "OK" - # TODO: Remove this (to panels?) @staticmethod - def normalize_payload(message): - message = message.strip().lower() + def normalize_command(command): + command = command.strip().lower() - if message in ["true", "on", "1", "enable"]: + if command in ["true", "on", "1", "enable"]: return "on" - elif message in ["false", "off", "0", "disable"]: + elif command in ["false", "off", "0", "disable"]: return "off" - elif message in [ + elif command in [ "pulse", "arm", "disarm", @@ -109,9 +109,9 @@ def normalize_payload(message): "bypass", "clear_bypass", ]: - return message + return command - return None + raise InvalidCommand(f'Invalid command: "{command}"') class ConfiguredAbstractTextInterface(AbstractTextInterface): diff --git a/paradox/interfaces/text/gsm.py b/paradox/interfaces/text/gsm.py index 3201817e..229e8313 100644 --- a/paradox/interfaces/text/gsm.py +++ b/paradox/interfaces/text/gsm.py @@ -1,11 +1,8 @@ -# -*- coding: utf-8 -*- - import asyncio -import datetime import json import logging import os -from concurrent import futures +from typing import Callable, Optional import serial_asyncio @@ -24,22 +21,16 @@ class SerialConnectionProtocol(ConnectionProtocol): def __init__(self, handler: ConnectionHandler): - super(SerialConnectionProtocol, self).__init__(handler) - self.buffer = b"" - self.loop = asyncio.get_event_loop() + super().__init__(handler) self.last_message = b"" - def connection_made(self, transport): - super(SerialConnectionProtocol, self).connection_made(transport) - self.handler.on_connection() - async def send_message(self, message): self.last_message = message self.transport.write(message + b"\r\n") def data_received(self, recv_data): self.buffer += recv_data - logger.debug("BUFFER: {}".format(self.buffer)) + logger.debug(f"BUFFER: {self.buffer}") while len(self.buffer) >= 0: r = self.buffer.find(b"\r\n") # not found @@ -61,25 +52,22 @@ def data_received(self, recv_data): if self.last_message == frame: self.last_message = b"" elif len(frame) > 0: - self.loop.create_task(self.handler.on_message(frame)) # Callback + self.handler.on_message(frame) # Callback def connection_lost(self, exc): logger.error("The serial port was closed") - self.buffer = b"" self.last_message = b"" - super(SerialConnectionProtocol, self).connection_lost(exc) + super().connection_lost(exc) class SerialCommunication(ConnectionHandler): - def __init__(self, loop, port, baud=9600, timeout=5, recv_callback=None): + def __init__(self, port, baud=9600, timeout=5): self.port_path = port self.baud = baud self.connected_future = None - self.recv_callback = recv_callback - self.loop = loop + self.recv_callback = None self.connected = False self.connection = None - asyncio.set_event_loop(loop) self.queue = asyncio.Queue() def clear(self): @@ -96,16 +84,14 @@ def on_connection(self): self.connected = True def on_message(self, message: bytes): - logger.debug("M->I: {}".format(message)) + logger.debug(f"M->I: {message}") if self.recv_callback is not None: - return asyncio.get_event_loop().call_soon( - self.recv_callback(message) - ) # Callback + self.recv_callback(message) # Callback else: - return self.queue.put_nowait(message) + self.queue.put_nowait(message) - def set_recv_callback(self, callback): + def set_recv_callback(self, callback: Optional[Callable[[str], bool]]): self.recv_callback = callback def open_timeout(self): @@ -120,23 +106,23 @@ def make_protocol(self): return SerialConnectionProtocol(self) async def write(self, message, timeout=15): - logger.debug("I->M: {}".format(message)) + logger.debug(f"I->M: {message}") if self.connection is not None: await self.connection.send_message(message) - return await asyncio.wait_for(self.queue.get(), timeout=5, loop=self.loop) + return await asyncio.wait_for(self.queue.get(), timeout=5) async def read(self, timeout=5): if self.connection is not None: return await asyncio.wait_for(self.queue.get(), timeout=timeout) async def connect(self): - logger.info("Connecting to serial port {}".format(self.port_path)) + logger.info(f"Connecting to serial port {self.port_path}") - self.connected_future = self.loop.create_future() - self.loop.call_later(5, self.open_timeout) + self.connected_future = asyncio.get_event_loop().create_future() + asyncio.get_event_loop().call_later(5, self.open_timeout) _, self.connection = await serial_asyncio.create_serial_connection( - self.loop, self.make_protocol, self.port_path, self.baud + asyncio.get_event_loop(), self.make_protocol, self.port_path, self.baud ) return await self.connected_future @@ -156,73 +142,64 @@ def __init__(self, alarm): self.port = None self.modem_connected = False - self.loop = asyncio.new_event_loop() self.message_cmt = None def stop(self): - """ Stops the GSM Interface Thread""" - self.stop_running.set() - - self.loop.stop() + """Stops the GSM Interface""" super().stop() + logger.debug("GSM Stopped. TODO: Implement a proper stop") - logger.debug("GSM Stopped") - - def write(self, message: str, expected: str = None) -> None: + async def write(self, message: str, expected: str = None) -> None: r = b"" while r != expected: - r = self.loop.run_until_complete(self.port.write(message)) + r = await self.port.write(message) data = b"" if r == b"ERROR": - raise Exception("Got error from modem: {}".format(r)) + raise Exception(f"Got error from modem: {r}") while r != expected: - r = self.loop.run_until_complete(self.port.read()) + r = await self.port.read() data += r + b"\n" - def connect(self): - logger.info( - "Using {} at {} baud".format(cfg.GSM_MODEM_PORT, cfg.GSM_MODEM_BAUDRATE) - ) + async def connect(self): + logger.info(f"Using {cfg.GSM_MODEM_PORT} at {cfg.GSM_MODEM_BAUDRATE} baud") try: if not os.path.exists(cfg.GSM_MODEM_PORT): - logger.error("Modem port ({}) not found".format(cfg.GSM_MODEM_PORT)) + logger.error(f"Modem port ({cfg.GSM_MODEM_PORT}) not found") return False self.port = SerialCommunication( - self.loop, cfg.GSM_MODEM_PORT, cfg.GSM_MODEM_BAUDRATE, 5 + cfg.GSM_MODEM_PORT, cfg.GSM_MODEM_BAUDRATE, 5 ) - except: - logger.exception( - "Could not open port {} for GSM modem".format(cfg.GSM_MODEM_PORT) - ) + except Exception: + logger.exception(f"Could not open port {cfg.GSM_MODEM_PORT} for GSM modem") return False self.port.set_recv_callback(None) - result = self.loop.run_until_complete(self.port.connect()) + result = await self.port.connect() if not result: logger.exception("Could not connect to GSM modem") return False try: - self.write(b"AT", b"OK") # Init - self.write(b"ATE0", b"OK") # Disable Echo - self.write(b"AT+CMEE=2", b"OK") # Increase verbosity - self.write(b"AT+CMGF=1", b"OK") # SMS Text mode - self.write(b"AT+CFUN=1", b"OK") # Enable modem - self.write( + await self.write(b"AT", b"OK") # Init + await self.write(b"ATE0", b"OK") # Disable Echo + await self.write(b"AT+CMEE=2", b"OK") # Increase verbosity + await self.write(b"AT+CMGF=1", b"OK") # SMS Text mode + await self.write(b"AT+CFUN=1", b"OK") # Enable modem + await self.write( b"AT+CNMI=1,2,0,0,0", b"OK" ) # SMS received only when modem enabled, Use +CMT with SMS, No Status Report, - self.write(b"AT+CUSD=1", b"OK") # Enable result code presentation + await self.write(b"AT+CUSD=1", b"OK") # Enable result code presentation - except futures.TimeoutError as e: + except asyncio.TimeoutError: logger.error("No reply from modem") return False - except: + except Exception: logger.exception("Modem connect error") return False @@ -234,20 +211,16 @@ def connect(self): self.modem_connected = True return True - def _run(self): - super(GSMTextInterface, self)._run() + async def run(self): + await super().run() - while not self.modem_connected and not self.stop_running.isSet(): - if not self.connect(): + while not self.modem_connected: + if not await self.connect(): logger.warning("Could not connect to modem") - self.stop_running.wait(5) - - self.loop.run_forever() - - self.stop_running.wait() + await asyncio.sleep(5) - async def data_received(self, data: str) -> bool: + def data_received(self, data: str) -> bool: logger.debug(f"Data Received: {data}") data = data.decode() @@ -262,21 +235,18 @@ async def data_received(self, data: str) -> bool: return True - def handle_message(self, timestamp: str, source: str, message: str) -> None: - """ Handle GSM message. It should be a command """ + async def handle_message(self, timestamp: str, source: str, message: str) -> None: + """Handle GSM message. It should be a command""" - logger.debug("Received: {} {} {}".format(timestamp, source, message)) + logger.debug(f"Received: {timestamp} {source} {message}") if source in cfg.GSM_CONTACTS: - future = asyncio.run_coroutine_threadsafe( - self.handle_command(message), self.alarm.work_loop - ) - ret = future.result(10) + ret = await self.handle_command(message) - m = "GSM {}: {}".format(source, ret) + m = f"GSM {source}: {ret}" logger.info(m) else: - m = "GSM {} (UNK): {}".format(source, message) + m = f"GSM {source} (UNK): {message}" logger.warning(m) self.send_message(m, EventLevel.INFO) @@ -284,7 +254,7 @@ def handle_message(self, timestamp: str, source: str, message: str) -> None: Notification(sender=self.name, message=m, level=EventLevel.INFO) ) - def send_message(self, message: str, level: EventLevel) -> None: + async def send_message(self, message: str, level: EventLevel) -> None: if self.port is None: logger.warning("GSM not available when sending message") return @@ -293,12 +263,9 @@ def send_message(self, message: str, level: EventLevel) -> None: data = b'AT+CMGS="%b"\x0d%b\x1a' % (dst.encode(), message.encode()) try: - future = asyncio.run_coroutine_threadsafe( - self.port.write(data), self.loop - ) - result = future.result() - logger.debug("SMS result: {}".format(result)) - except: + result = await self.port.write(data) + logger.debug(f"SMS result: {result}") + except Exception: logger.exception("ERROR sending SMS") def process_cmt(self, header: str, text: str) -> None: @@ -308,8 +275,8 @@ def process_cmt(self, header: str, text: str) -> None: tokens = json.loads(f"[{header[idx:]}]", strict=False) - logger.debug("On {}, {} sent {}".format(tokens[2], tokens[0], text)) - self.handle_message(tokens[2], tokens[0], text) + logger.debug(f"On {tokens[2]}, {tokens[0]} sent {text}") + asyncio.create_task(self.handle_message(tokens[2], tokens[0], text)) def process_cusd(self, message: str) -> None: idx = message.find(" ") diff --git a/paradox/interfaces/text/homeassistant_notifications.py b/paradox/interfaces/text/homeassistant_notifications.py index 89e2357a..ed69d0bd 100644 --- a/paradox/interfaces/text/homeassistant_notifications.py +++ b/paradox/interfaces/text/homeassistant_notifications.py @@ -22,7 +22,13 @@ def __init__(self, alarm): cfg.HOMEASSISTANT_NOTIFICATIONS_MIN_EVENT_LEVEL, ) + self.api_url = "http://supervisor/core/api/services/:domain/:service" + if cfg.HOMEASSISTANT_NOTIFICATIONS_API_URL: + self.api_url = cfg.HOMEASSISTANT_NOTIFICATIONS_API_URL + self.token = os.environ.get("SUPERVISOR_TOKEN") + if cfg.HOMEASSISTANT_NOTIFICATIONS_API_TOKEN: + self.token = cfg.HOMEASSISTANT_NOTIFICATIONS_API_TOKEN if not self.token: logger.error( f'"SUPERVISOR_TOKEN" environment variable must be set to use {__class__.__name__}' @@ -31,20 +37,43 @@ def __init__(self, alarm): def send_message(self, message: str, level: EventLevel): if not self.token: logger.warning( - 'Unable to send a notification to Home Assistant. "SUPERVISOR_TOKEN" environment variable is not set' + "Unable to send a notification to Home Assistant. No token is set." ) return - notifier_name = cfg.HOMEASSISTANT_NOTIFICATIONS_NOTIFIER_NAME - url = f"http://supervisor/core/api/services/notify/{notifier_name}" + url = self.api_url.replace(":domain", "notify").replace( + ":service", cfg.HOMEASSISTANT_NOTIFICATIONS_NOTIFIER_NAME + ) + + data = {} + + if cfg.HOMEASSISTANT_NOTIFICATIONS_LOVELACE_URI: + # iOS + data["url"] = cfg.HOMEASSISTANT_NOTIFICATIONS_LOVELACE_URI + # Android + data["clickAction"] = cfg.HOMEASSISTANT_NOTIFICATIONS_LOVELACE_URI + + if level == EventLevel.CRITICAL: + data.update( + { + # iOS + "push": { + "interruption-level": "critical", + }, + # Android + "ttl": 0, + "priority": "high", + "channel": "alarm_stream", + } + ) - payload = {"message": message, "title": "Paradox", "data": {"level": level}} + payload = {"message": message, "title": "Paradox", "data": data} headers = {"Authorization": f"Bearer {self.token}"} res = requests.post(url, json=payload, headers=headers) if res.status_code == 200: - logger.debug(f"Notification sent: {message}, level={level}") + logger.info(f"Notification sent: {message}, level={level}") else: logger.error( f"Failed to send notification: code={res.status_code}, text: {res.text}" diff --git a/paradox/interfaces/text/pushbullet.py b/paradox/interfaces/text/pushbullet.py index 07183030..47c972e9 100644 --- a/paradox/interfaces/text/pushbullet.py +++ b/paradox/interfaces/text/pushbullet.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import asyncio import json import logging @@ -23,7 +22,7 @@ class PushBulletWSClient(WebSocketBaseClient): name = "pushbullet" def __init__(self, interface, url): - """ Initializes the PB WS Client""" + """Initializes the PB WS Client""" super().__init__(url) self.pb = Pushbullet(cfg.PUSHBULLET_KEY) @@ -31,7 +30,7 @@ def __init__(self, interface, url): self.interface = interface self.device = None - for i, device in enumerate(self.pb.devices): + for _, device in enumerate(self.pb.devices): if device.nickname == cfg.PUSHBULLET_DEVICE: logger.debug("Device found") self.device = device @@ -45,12 +44,12 @@ def stop(self): self.manager.stop() def handshake_ok(self): - """ Callback trigger when connection succeeded""" + """Callback trigger when connection succeeded""" logger.info("Handshake OK") self.manager.add(self) self.manager.start() for chat in self.pb.chats: - logger.debug("Associated contacts: {}".format(chat)) + logger.debug(f"Associated contacts: {chat}") # Receiving pending messages self.received_message(json.dumps({"type": "tickle", "subtype": "push"})) @@ -58,12 +57,12 @@ def handshake_ok(self): self.send_message("Active") def received_message(self, message): - """ Handle Pushbullet message. It should be a command """ - logger.debug("Received Message {}".format(message)) + """Handle Pushbullet message. It should be a command""" + logger.debug(f"Received Message {message}") try: message = json.loads(str(message)) - except: + except Exception: logger.exception("Unable to parse message") return @@ -107,11 +106,11 @@ def received_message(self, message): ) def unhandled_error(self, error): - logger.error("{}".format(error)) + logger.error(f"{error}") try: self.terminate() - except: + except Exception: logger.exception("Closing Pushbullet WS") self.close() @@ -129,7 +128,7 @@ def send_message(self, msg, dstchat=None): if chat.email in cfg.PUSHBULLET_CONTACTS: try: self.pb.push_note(cfg.PUSHBULLET_DEVICE, msg, chat=chat) - except: + except Exception: logger.exception("Sending message") time.sleep(5) @@ -148,21 +147,21 @@ def __init__(self, alarm): self.name = PushBulletWSClient.name self.pb_ws = None - def _run(self): - super(PushbulletTextInterface, self)._run() + async def run(self): + await super().run() try: self.pb_ws = PushBulletWSClient( self, - "wss://stream.pushbullet.com/websocket/{}".format(cfg.PUSHBULLET_KEY), + f"wss://stream.pushbullet.com/websocket/{cfg.PUSHBULLET_KEY}", ) self.pb_ws.connect() - except: + except Exception: logger.exception("Could not connect to Pushbullet service") logger.info("Pushbullet Interface Started") def stop(self): - """ Stops the Pushbullet interface""" + """Stops the Pushbullet interface""" super().stop() if self.pb_ws is not None: self.pb_ws.stop() diff --git a/paradox/interfaces/text/pushover.py b/paradox/interfaces/text/pushover.py index 6370654b..3e6f7b3f 100644 --- a/paradox/interfaces/text/pushover.py +++ b/paradox/interfaces/text/pushover.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- - +import http.client import logging import re - -import chump +import urllib from paradox.config import config as cfg from paradox.event import EventLevel @@ -12,12 +10,12 @@ logger = logging.getLogger("PAI").getChild(__name__) _level_2_priority = { - EventLevel.NOTSET: chump.LOWEST, - EventLevel.DEBUG: chump.LOWEST, - EventLevel.INFO: chump.LOW, - EventLevel.WARN: chump.NORMAL, - EventLevel.ERROR: chump.HIGH, - EventLevel.CRITICAL: chump.EMERGENCY, + EventLevel.NOTSET: -2, + EventLevel.DEBUG: -2, + EventLevel.INFO: -1, + EventLevel.WARN: 0, + EventLevel.ERROR: 1, + EventLevel.CRITICAL: 2, } @@ -33,62 +31,45 @@ def __init__(self, alarm): cfg.PUSHOVER_MIN_EVENT_LEVEL, ) - self.app = None self.users = {} - def _run(self): - super(PushoverTextInterface, self)._run() - - self.app = chump.Application(cfg.PUSHOVER_KEY) - if not self.app.is_authenticated: - raise Exception( - "Failed to authenticate with Pushover. Please check PUSHOVER_APPLICATION_KEY" - ) - def send_message(self, message: str, level: EventLevel): for settings in cfg.PUSHOVER_BROADCAST_KEYS: user_key = settings["user_key"] devices_raw = settings["devices"] - user = self.users.get(user_key) # type: chump.User - - if user is None: - user = self.users[user_key] = self.app.get_user(user_key) - - if not user.is_authenticated: - raise Exception( - "Failed to check user key with Pushover. Please check PUSHOVER_BROADCAST_KEYS[%s]" - % user_key - ) - if devices_raw == "*" or devices_raw is None: - try: - user.send_message( - message, - title="Alarm", - priority=_level_2_priority.get(level, chump.NORMAL), - ) - except: - logger.exception("Pushover send message") - + self._send_pushover_message(user_key, message, level) else: devices = list(filter(bool, re.split(r"[\s]*,[\s]*", devices_raw))) - for elem in (elem for elem in devices if elem not in user.devices): - logger.warning( - "%s is not in the Pushover device list for the user %s" - % (elem, user_key) - ) - for device in devices: - try: - user.send_message( - message, - title="PAI", - device=device, - priority=_level_2_priority.get(level, chump.NORMAL), - ) - except: - logger.exception("Pushover send message") + self._send_pushover_message(user_key, message, level, device) + + def _send_pushover_message(self, user_key, message, level, device=None): + conn = http.client.HTTPSConnection("api.pushover.net:443") + params = { + "token": cfg.PUSHOVER_KEY, + "user": user_key, + "message": message, + "priority": _level_2_priority.get(level, 0), + "title": "Alarm", + } + if device: + params["device"] = device + + conn.request( + "POST", + "/1/messages.json", + urllib.parse.urlencode(params), + {"Content-type": "application/x-www-form-urlencoded"}, + ) - # TODO: Missing the message reception + response = conn.getresponse() + if response.status != 200: + logger.error(f"Failed to send message: {response.reason}") + else: + logger.info( + f"Notification sent: {message}, level={level}, device={device if device else 'all'}" + ) + conn.close() diff --git a/paradox/interfaces/text/signal.py b/paradox/interfaces/text/signal.py index 1bdcd28c..81f92e33 100644 --- a/paradox/interfaces/text/signal.py +++ b/paradox/interfaces/text/signal.py @@ -1,15 +1,16 @@ -# -*- coding: utf-8 -*- import asyncio import logging from gi.repository import GLib +from pydbus import SystemBus + from paradox.config import config as cfg from paradox.event import EventLevel, Notification + # Signal interface. # Only exposes critical status changes and accepts commands from paradox.interfaces.text.core import ConfiguredAbstractTextInterface from paradox.lib import ps -from pydbus import SystemBus logger = logging.getLogger("PAI").getChild(__name__) @@ -27,30 +28,29 @@ def __init__(self, alarm): ) self.signal = None - self.loop = None + self.glib_loop = None def stop(self): - - """ Stops the Signal Interface Thread""" - if self.loop is not None: - self.loop.quit() + """Stops the Signal Interface Thread""" + if self.glib_loop is not None: + self.glib_loop.quit() super().stop() logger.debug("Signal Stopped") - def _run(self): - super(SignalTextInterface, self)._run() + async def run(self): + await super().run() bus = SystemBus() self.signal = bus.get("org.asamk.Signal") self.signal.onMessageReceived = self.handle_message - self.loop = GLib.MainLoop() + self.glib_loop = GLib.MainLoop() logger.debug("Signal Interface Running") - self.loop.run() + asyncio.get_event_loop().run_in_executor(None, self.glib_loop.run) def send_message(self, message: str, level: EventLevel): if self.signal is None: @@ -59,11 +59,11 @@ def send_message(self, message: str, level: EventLevel): try: self.signal.sendMessage(str(message), [], cfg.SIGNAL_CONTACTS) - except: + except Exception: logger.exception("Signal send message") def handle_message(self, timestamp, source, groupID, message, attachments): - """ Handle Signal message. It should be a command """ + """Handle Signal message. It should be a command""" logger.debug( "Received Message {} {} {} {} {}".format( @@ -77,10 +77,10 @@ def handle_message(self, timestamp, source, groupID, message, attachments): ) ret = future.result(10) - m = "Signal {} : {}".format(source, ret) + m = f"Signal {source} : {ret}" logger.info(m) else: - m = "Signal {} (UNK): {}".format(source, message) + m = f"Signal {source} (UNK): {message}" logger.warning(m) self.send_message(m, EventLevel.INFO) diff --git a/paradox/lib/async_message_manager.py b/paradox/lib/async_message_manager.py index 5089a8c9..fbaa4f2e 100644 --- a/paradox/lib/async_message_manager.py +++ b/paradox/lib/async_message_manager.py @@ -26,13 +26,9 @@ def can_handle(self, data: Container) -> bool: class AsyncMessageManager: - def __init__(self, loop=None): + def __init__(self): super().__init__() - if not loop: - loop = asyncio.get_event_loop() - self.loop = loop - self.handler_registry = HandlerRegistry() self.raw_handler_registry = HandlerRegistry() @@ -58,7 +54,11 @@ def deregister_handler(self, name): self.handler_registry.remove_by_name(name) def schedule_message_handling(self, message: Container): - return self.loop.create_task(self.handler_registry.handle(message)) + return asyncio.get_event_loop().create_task( + self.handler_registry.handle(message) + ) def schedule_raw_message_handling(self, message: Container): - return self.loop.create_task(self.raw_handler_registry.handle(message)) + return asyncio.get_event_loop().create_task( + self.raw_handler_registry.handle(message) + ) diff --git a/paradox/lib/help.py b/paradox/lib/help.py index 5149584c..b97e4015 100644 --- a/paradox/lib/help.py +++ b/paradox/lib/help.py @@ -36,9 +36,6 @@ "yaml": dict( mandatory=False, desc="the IP150 connection", install_name="pyyaml>=5.2.0" ), - "chump": dict( - mandatory=False, desc="the Pushover interface", install_name="chump>=1.6.0" - ), "pydbus": dict( mandatory=False, desc="the Signal interface", install_name="pydbus>=0.6.0" ), diff --git a/pyproject.toml b/pyproject.toml index 86246907..83a2af82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,3 +22,4 @@ combine_as_imports = true [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/requirements.txt b/requirements.txt index 788d52b1..5234c51c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ argparse>=1.4.0 -chump>=1.6.0 construct~=2.9.43 flake8 paho_mqtt>=1.5.0,<2 diff --git a/setup.cfg b/setup.cfg index 7bdd4887..f94c7ae7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,8 +55,6 @@ YAML = Pushbullet = pushbullet.py>=0.11.0 ws4py>=0.4.2 -Pushover = - chump>=1.6.0 Signal = pygobject>=3.20.0 pydbus>=0.6.0 diff --git a/tests/connection/ip/test_bare_connection.py b/tests/connection/ip/test_bare_connection.py index 89165930..524aeefb 100644 --- a/tests/connection/ip/test_bare_connection.py +++ b/tests/connection/ip/test_bare_connection.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import AsyncMock import pytest @@ -14,7 +15,9 @@ async def test_connect(mocker): protocol.is_active.return_value = True create_connection_mock = AsyncMock(return_value=(None, protocol)) - mocker.patch.object(connection.loop, "create_connection", create_connection_mock) + mocker.patch.object( + asyncio.get_event_loop(), "create_connection", create_connection_mock + ) assert connection.connected is False diff --git a/tests/connection/ip/test_local_ip_connection.py b/tests/connection/ip/test_local_ip_connection.py index 30e37792..4d84a3d9 100644 --- a/tests/connection/ip/test_local_ip_connection.py +++ b/tests/connection/ip/test_local_ip_connection.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import AsyncMock import pytest @@ -15,7 +16,9 @@ async def test_connect(mocker): protocol.is_active.return_value = True create_connection_mock = AsyncMock(return_value=(None, protocol)) - mocker.patch.object(connection.loop, "create_connection", create_connection_mock) + mocker.patch.object( + asyncio.get_event_loop(), "create_connection", create_connection_mock + ) connect_command_execute = mocker.patch.object( IPModuleConnectCommand, "execute", AsyncMock() ) diff --git a/tests/connection/ip/test_stun_ip_connection.py b/tests/connection/ip/test_stun_ip_connection.py index 0595bc48..26593934 100644 --- a/tests/connection/ip/test_stun_ip_connection.py +++ b/tests/connection/ip/test_stun_ip_connection.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import AsyncMock import pytest @@ -17,7 +18,9 @@ async def test_connect(mocker): protocol.is_active.return_value = True create_connection_mock = AsyncMock(return_value=(None, protocol)) - mocker.patch.object(connection.loop, "create_connection", create_connection_mock) + mocker.patch.object( + asyncio.get_event_loop(), "create_connection", create_connection_mock + ) connect_command_execute = mocker.patch.object( IPModuleConnectCommand, "execute", AsyncMock() ) @@ -74,7 +77,7 @@ async def assert_session_connect(mocker, session): "serial": "bf4c1fe4", "type": "HD77", "port": 54321, - "panelSerial": "0584b067" + "panelSerial": "0584b067", }, { "lastUpdate": "2021-05-07T15:41:19Z", @@ -85,7 +88,7 @@ async def assert_session_connect(mocker, session): "serial": "465e81a0", "type": "HD88", "port": 12345, - "panelSerial": "0584b067" + "panelSerial": "0584b067", }, { "lastUpdate": "2021-05-07T15:41:19Z", @@ -98,7 +101,7 @@ async def assert_session_connect(mocker, session): "panelSerial": "a72ed4bf", "xoraddr": "9a640069cda9b317", "API": None, - "ipAddress": "0.0.0.0" + "ipAddress": "0.0.0.0", }, { "lastUpdate": "2021-05-07T15:41:19Z", @@ -111,13 +114,13 @@ async def assert_session_connect(mocker, session): "panelSerial": "0584b067", "xoraddr": "c351472f48a5e1ba", "API": None, - "ipAddress": "0.0.0.0" - } + "ipAddress": "0.0.0.0", + }, ], "paid": 1, "daysLeft": 364, "sitePanelStatus": 1, - "email": "em@em.com" + "email": "em@em.com", } ] } @@ -130,6 +133,8 @@ def json(self): mocker.patch("requests.get").return_value = StubResponse() client = mocker.patch("paradox.lib.stun.StunClient") - client.return_value.receive_response.return_value = [{"attr_body": "abcdef", "name": "BEER"}] + client.return_value.receive_response.return_value = [ + {"attr_body": "abcdef", "name": "BEER"} + ] await session.connect() - return json_data \ No newline at end of file + return json_data diff --git a/tests/interfaces/test_gsm.py b/tests/interfaces/test_gsm.py new file mode 100644 index 00000000..efc2a7ea --- /dev/null +++ b/tests/interfaces/test_gsm.py @@ -0,0 +1,117 @@ +import asyncio +from unittest import mock + +import pytest + +from paradox.interfaces.text.gsm import ( + GSMTextInterface, + SerialCommunication, + SerialConnectionProtocol, +) + + +@pytest.fixture +async def connected_serial_communication(): + port = "test_port" + baud = 9600 + timeout = 5 + comm = SerialCommunication(port, baud, timeout) + + assert comm.queue.empty() + + async def mocked_create_serial_connection(loop, protocol_factory, *args, **kwargs): + transport = mock.Mock() + protocol = comm.make_protocol() + asyncio.get_event_loop().call_soon(protocol.connection_made, transport) + return (transport, protocol) + + with mock.patch( + "serial_asyncio.create_serial_connection", + new_callable=mock.AsyncMock, + side_effect=mocked_create_serial_connection, + ): + asyncio.get_event_loop().call_soon(comm.on_connection) + result = await comm.connect() + assert result + + assert comm.connected + + return comm + + +# Test SerialConnectionProtocol class +@pytest.mark.asyncio +async def test_serial_connection_protocol(): + handler = mock.MagicMock() + protocol = SerialConnectionProtocol(handler) + + transport = mock.MagicMock() + protocol.connection_made(transport) + handler.on_connection.assert_called_once() + + message = b"test_message" + await protocol.send_message(message) + transport.write.assert_called_once_with(message + b"\r\n") + + recv_data = b"test_data\r\n" + protocol.data_received(recv_data) + handler.on_message.assert_called_once_with(b"test_data") + + exc = Exception("test_exception") + protocol.connection_lost(exc) + handler.on_connection_loss.assert_called_once_with() + + +# Test SerialCommunication class +@pytest.mark.asyncio +async def test_serial_communication(connected_serial_communication): + comm = connected_serial_communication + + write_message = b"write_message" + write_response_message = b"write_response_message" + read_message = b"read_message" + + asyncio.get_event_loop().call_soon(comm.on_message, write_response_message) + result = await comm.write(write_message) + assert result == write_response_message + + asyncio.get_event_loop().call_soon(comm.on_message, read_message) + await comm.read() + assert comm.queue.empty() + + callback = mock.MagicMock() + comm.set_recv_callback(callback) + assert comm.recv_callback == callback + comm.on_message(read_message) + callback.assert_called_once_with(read_message) + + +# Test GSMTextInterface class +@pytest.mark.asyncio +async def test_gsm_text_interface(connected_serial_communication): + alarm = mock.MagicMock() + event = asyncio.Event() + + async def control_partition(partition, command): + assert partition == "outside" + assert command == "arm" + event.set() + + return True + + interface = GSMTextInterface(alarm) + interface.port = connected_serial_communication + interface.modem_connected = True + + data = b"+CMT: test_data" + interface.data_received(data) + assert interface.message_cmt == data.decode() + + # level = EventLevel.INFO + # await interface.send_message("bla", level) + + header = '+CMT: "+1234567890","","24/09/17,10:30:00+32"' + text = "partition outside arm" + alarm.control_partition.side_effect = control_partition + interface.process_cmt(header, text) + await asyncio.wait_for(event.wait(), timeout=0.1) diff --git a/tests/lib/test_event_filter.py b/tests/lib/test_event_filter.py index 1f6eb253..c944689c 100644 --- a/tests/lib/test_event_filter.py +++ b/tests/lib/test_event_filter.py @@ -16,6 +16,7 @@ def test_tag_match(): assert EventTagFilter(["partition+arm"]).match(event) is True assert EventTagFilter(["partition+arm+restore"]).match(event) is True + assert EventTagFilter(["partition,arm,-restore"]).match(event) is False assert EventTagFilter(["partition"]).match(event) is True assert EventTagFilter(["arm"]).match(event) is True assert EventTagFilter(["arm-zone"]).match(event) is True diff --git a/tests/pai.conf b/tests/pai.conf index a8ccbf86..ba0793bd 100644 --- a/tests/pai.conf +++ b/tests/pai.conf @@ -1,3 +1,5 @@ # Just make Config class happy and use defaults. LOGGING_FILE=None + +GSM_CONTACTS = ["+1234567890"] diff --git a/tests/test_async_queue.py b/tests/test_async_queue.py index ac329101..ad0e8fa0 100644 --- a/tests/test_async_queue.py +++ b/tests/test_async_queue.py @@ -1,9 +1,9 @@ import asyncio import binascii - from unittest import mock -import pytest + from construct import Container +import pytest from paradox.hardware.evo.parsers import LiveEvent, ReadEEPROMResponse from paradox.lib.async_message_manager import AsyncMessageManager @@ -19,12 +19,11 @@ def print_beer(m): print("beer") -def test_event_handler(): +@pytest.mark.asyncio +async def test_event_handler(): eh = EventMessageHandler(print_beer) - loop = asyncio.get_event_loop() - mh = AsyncMessageManager(loop) - + mh = AsyncMessageManager() mh.register_handler(eh) assert 1 == len(mh.handler_registry) @@ -33,13 +32,13 @@ def test_event_handler(): message = LiveEvent.parse(payload) - coro = asyncio.ensure_future(mh.schedule_message_handling(message)) - loop.run_until_complete(coro) + await mh.schedule_message_handling(message) assert 1 == len(mh.handler_registry) -def test_event_handler_failure(): +@pytest.mark.asyncio +async def test_event_handler_failure(): # eeprom_request_bin = binascii.unhexlify('500800009f004037') eeprom_response_bin = binascii.unhexlify( "524700009f0041133e001e0e0400000000060a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000121510010705004e85" @@ -48,26 +47,24 @@ def test_event_handler_failure(): eh = EventMessageHandler(print_beer) eh.handle = mock.MagicMock() - loop = asyncio.get_event_loop() - mh = AsyncMessageManager(loop) - + mh = AsyncMessageManager() mh.register_handler(eh) assert 1 == len(mh.handler_registry) message = ReadEEPROMResponse.parse(eeprom_response_bin) - coro = asyncio.ensure_future(mh.schedule_message_handling(message)) - loop.run_until_complete(coro) + coro = mh.schedule_message_handling(message) + result = await coro assert ( - coro.result() is None + result is None ) # failed to parse response message return None. Maybe needs to throw something. assert 1 == len(mh.handler_registry) eh.handle.assert_not_called() -def test_handler_two_messages(): +async def test_handler_two_messages(): def event_handler(message): print("event") @@ -80,28 +77,21 @@ async def get_eeprom_result(mhm): "524700009f0041133e001e0e0400000000060a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000121510010705004e85" ) - loop = asyncio.get_event_loop() - mh = AsyncMessageManager(loop) + mh = AsyncMessageManager() event_handler = EventMessageHandler(event_handler) mh.register_handler(event_handler) # running - task_handle_wait = loop.create_task(asyncio.sleep(0.1)) - task_get_eeprom = loop.create_task(get_eeprom_result(mh)) + task_handle_wait = asyncio.create_task(asyncio.sleep(0.1)) + task_get_eeprom = asyncio.create_task(get_eeprom_result(mh)) task_handle_event1 = mh.schedule_message_handling( LiveEvent.parse(event_response_bin) ) mh.schedule_message_handling(ReadEEPROMResponse.parse(eeprom_response_bin)) - task_handle_event2 = mh.schedule_message_handling( - LiveEvent.parse(event_response_bin) - ) - - # assert 2 == len(mh.handlers) + mh.schedule_message_handling(LiveEvent.parse(event_response_bin)) - loop.run_until_complete( - asyncio.gather(task_handle_wait, task_get_eeprom) - ) + await asyncio.gather(task_handle_wait, task_get_eeprom) assert 1 == len(mh.handler_registry) @@ -111,7 +101,8 @@ async def get_eeprom_result(mhm): assert 1 == len(mh.handler_registry) -def test_handler_timeout(): +@pytest.mark.asyncio +async def test_handler_timeout(): def event_handler(message): print("event received") @@ -131,23 +122,22 @@ async def post_eeprom_message(mhm): ReadEEPROMResponse.parse(eeprom_response_bin) ) - loop = asyncio.get_event_loop() - mh = AsyncMessageManager(loop) + mh = AsyncMessageManager() # running - task_get_eeprom = loop.create_task(get_eeprom_result(mh)) - loop.create_task(post_eeprom_message(mh)) + task_get_eeprom = asyncio.create_task(get_eeprom_result(mh)) + asyncio.create_task(post_eeprom_message(mh)) assert 0 == len(mh.handler_registry) with pytest.raises(asyncio.TimeoutError): - loop.run_until_complete(task_get_eeprom) + await task_get_eeprom assert 0 == len(mh.handler_registry) # Also test EventMessageHandler - event_handler = EventMessageHandler(event_handler) - mh.register_handler(event_handler) + event_handler_instance = EventMessageHandler(event_handler) + mh.register_handler(event_handler_instance) event_response_bin = b"\xe2\xff\xad\x06\x14\x13\x01\x04\x0e\x10\x00\x01\x05\x00\x00\x00\x00\x00\x02Living room \x00\xcc" task_handle_event1 = mh.schedule_message_handling( @@ -156,6 +146,6 @@ async def post_eeprom_message(mhm): assert 1 == len(mh.handler_registry) - loop.run_until_complete(task_handle_event1) + await task_handle_event1 assert 1 == len(mh.handler_registry)