diff --git a/tests/components/smlight/__init__.py b/tests/components/smlight/__init__.py index 37184226507e73..e518e0573bab0d 100644 --- a/tests/components/smlight/__init__.py +++ b/tests/components/smlight/__init__.py @@ -1 +1,21 @@ """Tests for the SMLIGHT Zigbee adapter integration.""" + +from collections.abc import Callable +from unittest.mock import MagicMock + +from pysmlight.const import Events as SmEvents +from pysmlight.sse import MessageEvent + + +def get_mock_event_function( + mock: MagicMock, event: SmEvents +) -> Callable[[MessageEvent], None]: + """Extract event function from mock call_args.""" + return next( + ( + call_args[0][1] + for call_args in mock.sse.register_callback.call_args_list + if call_args[0][0] == event + ), + None, + ) diff --git a/tests/components/smlight/test_binary_sensor.py b/tests/components/smlight/test_binary_sensor.py index 1b1c0358c37061..b1d72b66dcfef9 100644 --- a/tests/components/smlight/test_binary_sensor.py +++ b/tests/components/smlight/test_binary_sensor.py @@ -1,6 +1,5 @@ """Tests for the SMLIGHT binary sensor platform.""" -from collections.abc import Callable from unittest.mock import MagicMock from freezegun.api import FrozenDateTimeFactory @@ -14,6 +13,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er +from . import get_mock_event_function from .conftest import setup_integration from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform @@ -95,13 +95,8 @@ async def test_internet_sensor_event( assert len(mock_smlight_client.get_param.mock_calls) == 2 mock_smlight_client.get_param.assert_called_with("inetState") - event_function: Callable[[MessageEvent], None] = next( - ( - call_args[0][1] - for call_args in mock_smlight_client.sse.register_callback.call_args_list - if call_args[0][0] == Events.EVENT_INET_STATE - ), - None, + event_function = get_mock_event_function( + mock_smlight_client, Events.EVENT_INET_STATE ) event_function(MOCK_INET_STATE) diff --git a/tests/components/smlight/test_update.py b/tests/components/smlight/test_update.py index b0b8910ef9bbcc..7bff12bb027a81 100644 --- a/tests/components/smlight/test_update.py +++ b/tests/components/smlight/test_update.py @@ -1,6 +1,5 @@ """Tests for the SMLIGHT update platform.""" -from collections.abc import Callable from unittest.mock import MagicMock from freezegun.api import FrozenDateTimeFactory @@ -23,6 +22,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er +from . import get_mock_event_function from .conftest import setup_integration from tests.common import MockConfigEntry, async_fire_time_changed, snapshot_platform @@ -67,18 +67,6 @@ ] -def get_callback_function(mock: MagicMock, trigger: SmEvents): - """Extract the callback function for a given trigger.""" - return next( - ( - call_args[0][1] - for call_args in mock.sse.register_callback.call_args_list - if trigger == call_args[0][0] - ), - None, - ) - - @pytest.fixture def platforms() -> list[Platform]: """Platforms, which should be loaded during the test.""" @@ -122,17 +110,13 @@ async def test_update_firmware( assert len(mock_smlight_client.fw_update.mock_calls) == 1 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.ZB_FW_prgs - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_prgs) event_function(MOCK_FIRMWARE_PROGRESS) state = hass.states.get(entity_id) assert state.attributes[ATTR_IN_PROGRESS] == 50 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.FW_UPD_done - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.FW_UPD_done) event_function(MOCK_FIRMWARE_DONE) @@ -178,9 +162,7 @@ async def test_update_legacy_firmware_v2( assert len(mock_smlight_client.fw_update.mock_calls) == 1 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.ESP_UPD_done - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.ESP_UPD_done) event_function(MOCK_FIRMWARE_DONE) @@ -220,9 +202,7 @@ async def test_update_firmware_failed( assert len(mock_smlight_client.fw_update.mock_calls) == 1 - event_function: Callable[[MessageEvent], None] = get_callback_function( - mock_smlight_client, SmEvents.ZB_FW_err - ) + event_function = get_mock_event_function(mock_smlight_client, SmEvents.ZB_FW_err) async def _call_event_function(event: MessageEvent): event_function(event)