From 30538d072ff4e6e6ae996b6d797b583dd1f46fef Mon Sep 17 00:00:00 2001 From: Wiktor Latanowicz Date: Fri, 24 Mar 2023 14:43:09 +0100 Subject: [PATCH 1/4] Add tests --- indi/device/events.py | 51 +++++- tests/indi/device/events/__init__.py | 0 tests/indi/device/events/test_change_event.py | 73 ++++++++ tests/indi/device/events/test_read_event.py | 44 +++++ tests/indi/device/events/test_write_event.py | 75 +++++++++ tests/indi/device/test_device.py | 2 +- tests/indi/transport/__indi__.py | 0 tests/indi/transport/test_buffer.py | 156 ++++++++++++++++++ 8 files changed, 398 insertions(+), 3 deletions(-) create mode 100644 tests/indi/device/events/__init__.py create mode 100644 tests/indi/device/events/test_change_event.py create mode 100644 tests/indi/device/events/test_read_event.py create mode 100644 tests/indi/device/events/test_write_event.py create mode 100644 tests/indi/transport/__indi__.py create mode 100644 tests/indi/transport/test_buffer.py diff --git a/indi/device/events.py b/indi/device/events.py index d566fe7..8663680 100644 --- a/indi/device/events.py +++ b/indi/device/events.py @@ -136,13 +136,34 @@ def __init__( self.vector = vector self.element = element self.prevent_default = False - self.propagate = True + + def __eq__(self, __value: object) -> bool: + return ( + isinstance(__value, self.__class__) + and self.__class__ == __value.__class__ + and self.vector == __value.vector + and self.element == __value.element + and self.prevent_default == __value.prevent_default + ) class Write(BaseEvent): """Event raised after receiving new value from client. Can be used to write new value to physical device. + + Event is raised after a message with new value is received. + It is not raised after new value is assigned in code. + To raise this event in code you should use element's + `set_value` method. + + After the write logic is done you can set `event.prevent_default` + to True in event handler's body to skip setting of element's internal + state if you want to update internal state after the device confirms + the change of state (in a separate callback). + + If you don't change `event.prevent_default` (default is False), the state + get's updated according to the new value and `Change` event is raised. """ def __init__(self, element: Element, new_value) -> None: @@ -151,11 +172,22 @@ def __init__(self, element: Element, new_value) -> None: ) self.new_value = new_value + def __eq__(self, __value: object) -> bool: + return ( + super().__eq__(__value) + and isinstance(__value, self.__class__) + and self.new_value == __value.new_value + ) + class Read(BaseEvent): """Event raised before sending value to client. Can be used to read value from physical device. + + After the read logic is complete, use `reset_value()` on the element to + synchronize it's internal state with device's state. + Don't assign value directly nor use `set_value()` - this will cause infinite recursion. """ def __init__(self, element: Element) -> None: @@ -165,7 +197,14 @@ def __init__(self, element: Element) -> None: class Change(BaseEvent): - """Event raised on value change.""" + """Event raised on value change. + + This event is raised after the internal value of element state changes. + It's raised on both: changes caused by incoming messages and assigns of new values in code. + + It is recommended to use `Read` and `Write` events to communicate with + physical device. + """ def __init__(self, element: Element, old_value, new_value) -> None: super().__init__( @@ -173,3 +212,11 @@ def __init__(self, element: Element, old_value, new_value) -> None: ) self.new_value = new_value self.old_value = old_value + + def __eq__(self, __value: object) -> bool: + return ( + super().__eq__(__value) + and isinstance(__value, self.__class__) + and self.new_value == __value.new_value + and self.old_value == __value.old_value + ) diff --git a/tests/indi/device/events/__init__.py b/tests/indi/device/events/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/indi/device/events/test_change_event.py b/tests/indi/device/events/test_change_event.py new file mode 100644 index 0000000..23fb427 --- /dev/null +++ b/tests/indi/device/events/test_change_event.py @@ -0,0 +1,73 @@ +from typing import Optional +from unittest.mock import Mock + +from indi import message +from indi.device import Driver, properties +from indi.device.events import Change, on +from indi.message import one_parts +from indi.routing import Router + + +class DummyDevice(Driver): + name = "DEVICE" + + def __init__( + self, side_effect, name: Optional[str] = None, router: Optional[Router] = None + ): + super().__init__(name, router) + self.side_effect = side_effect + + main = properties.Group( + "MAIN", + vectors=dict( + text=properties.TextVector( + "TEXT", + elements=dict( + txt=properties.Text("TXT", default="lorem"), + ), + ), + ), + ) + + @on(main.text.txt, Change) + def on_write(self, event): + self.side_effect(event) + + +def test_device_emits_change_event_on_message(): + old_value = "lorem" + new_value = "ipsum" + msg = message.NewTextVector( + device="DEVICE", + name="TEXT", + children=(one_parts.OneText(name="TXT", value=new_value),), + ) + + side_effect = Mock() + + dev = DummyDevice(side_effect) + assert dev.main.text.txt.value == old_value + + dev.message_from_client(msg) + + expected_event = Change(dev.main.text.txt, old_value, new_value) + dev.side_effect.assert_called_once_with(expected_event) + + assert dev.main.text.txt.value == new_value + + +def test_device_emits_change_event_on_assign(): + old_value = "lorem" + new_value = "ipsum" + + side_effect = Mock() + + dev = DummyDevice(side_effect) + assert dev.main.text.txt.value == old_value + + dev.main.text.txt.value = new_value + + expected_event = Change(dev.main.text.txt, old_value, new_value) + dev.side_effect.assert_called_once_with(expected_event) + + assert dev.main.text.txt.value == new_value diff --git a/tests/indi/device/events/test_read_event.py b/tests/indi/device/events/test_read_event.py new file mode 100644 index 0000000..3bf4801 --- /dev/null +++ b/tests/indi/device/events/test_read_event.py @@ -0,0 +1,44 @@ +from typing import Optional +from unittest.mock import Mock + +from indi import message +from indi.device import Driver, properties +from indi.device.events import Read, on +from indi.message import one_parts +from indi.routing import Router + + +class DummyDevice(Driver): + name = "DEVICE" + + def __init__( + self, side_effect, name: Optional[str] = None, router: Optional[Router] = None + ): + super().__init__(name, router) + self.side_effect = side_effect + + main = properties.Group( + "MAIN", + vectors=dict( + text=properties.TextVector( + "TEXT", + elements=dict( + txt=properties.Text("TXT", default="lorem"), + ), + ), + ), + ) + + @on(main.text.txt, Read) + def on_write(self, event): + self.side_effect(event) + + +def test_device_emits_read_event(): + new_value = "ipsum" + + def side_effect(event): + event.element.reset_value(new_value) + + dev = DummyDevice(side_effect) + assert dev.main.text.txt.value == new_value diff --git a/tests/indi/device/events/test_write_event.py b/tests/indi/device/events/test_write_event.py new file mode 100644 index 0000000..2439988 --- /dev/null +++ b/tests/indi/device/events/test_write_event.py @@ -0,0 +1,75 @@ +from typing import Optional +from unittest.mock import Mock + +from indi import message +from indi.device import Driver, properties +from indi.device.events import Write, on +from indi.message import one_parts +from indi.routing import Router + + +class DummyDevice(Driver): + name = "DEVICE" + + def __init__( + self, side_effect, name: Optional[str] = None, router: Optional[Router] = None + ): + super().__init__(name, router) + self.side_effect = side_effect + + main = properties.Group( + "MAIN", + vectors=dict( + text=properties.TextVector( + "TEXT", + elements=dict( + txt=properties.Text("TXT", default="lorem"), + ), + ), + ), + ) + + @on(main.text.txt, Write) + def on_write(self, event): + self.side_effect(event) + + +def test_device_emits_write_event(): + new_value = "ipsum" + msg = message.NewTextVector( + device="DEVICE", + name="TEXT", + children=(one_parts.OneText(name="TXT", value=new_value),), + ) + + side_effect = Mock() + + dev = DummyDevice(side_effect) + assert dev.main.text.txt.value == "lorem" + + dev.message_from_client(msg) + + expected_event = Write(dev.main.text.txt, new_value) + dev.side_effect.assert_called_once_with(expected_event) + + assert dev.main.text.txt.value == new_value + + +def test_device_write_event_prevent_default(): + new_value = "ipsum" + msg = message.NewTextVector( + device="DEVICE", + name="TEXT", + children=(one_parts.OneText(name="TXT", value=new_value),), + ) + + def side_effect(event): + event.prevent_default = True + + dev = DummyDevice(side_effect) + assert dev.main.text.txt.value == "lorem" + + dev.message_from_client(msg) + + assert dev.main.text.txt.value == "lorem" + assert dev.main.text.txt.value != new_value diff --git a/tests/indi/device/test_device.py b/tests/indi/device/test_device.py index 6a2859e..566480a 100644 --- a/tests/indi/device/test_device.py +++ b/tests/indi/device/test_device.py @@ -112,7 +112,7 @@ def test_device_process_new_text_vector_message(): assert dev.main.text.txt.value == "ipsum" -def test_device_process_new_text_vector_message(): +def test_device_process_new_blob_vector_message(): msg = message.NewBLOBVector( device="DEVICE", name="BLOB", diff --git a/tests/indi/transport/__indi__.py b/tests/indi/transport/__indi__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/indi/transport/test_buffer.py b/tests/indi/transport/test_buffer.py new file mode 100644 index 0000000..ee3c0bb --- /dev/null +++ b/tests/indi/transport/test_buffer.py @@ -0,0 +1,156 @@ +from itertools import chain, combinations +from random import choice, randint, random + +import pytest + +from indi import message +from indi.message import const, one_parts +from indi.transport import Buffer + +NUM_RANDOM_TEST_CASES = 50 +NUM_MESSAGES_IN_RANDOM_TEST_CASE = 100 + + +def powerset(iterable): + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +noise_messages = ["", ""] + +raw_messages = [ + '', + '', + '', + '2.0', +] + +indi_messages = [ + message.GetProperties(version="1.0"), + message.GetProperties(version="1.0", device="Camera"), + message.GetProperties(version="1.0", device="Camera", name="EXPOSURE"), + message.SetTextVector( + device="CAMERA", + name="EXPOSE", + state=const.State.ALERT, + children=[one_parts.OneText(name="EXPOSE_TIME", value="2.0")], + ), +] + + +def random_test_case(size, with_noise=False): + input_strings = [] + output_messages = [] + + choices = list(range(min(len(raw_messages), len(indi_messages)))) + + for _ in range(size): + if with_noise and random() > 0.5: + for _ in range(randint(1, 5)): + input_strings.append(choice(noise_messages)) + + idx = choice(choices) + input_strings.append(raw_messages[idx]) + output_messages.append(indi_messages[idx]) + + if with_noise and random() > 0.5: + for _ in range(randint(1, 5)): + input_strings.append(choice(noise_messages)) + + return ( + tuple(input_strings), + tuple(output_messages), + ) + + +def random_test_cases(count, size, with_noise=False): + for _ in range(count): + yield random_test_case(size, with_noise=with_noise) + + +manual_test_cases = [] + + +@pytest.mark.parametrize( + "input_strings,expected_output_messages", + manual_test_cases + + list(random_test_cases(NUM_RANDOM_TEST_CASES, NUM_MESSAGES_IN_RANDOM_TEST_CASE)) + + list( + random_test_cases( + NUM_RANDOM_TEST_CASES, NUM_MESSAGES_IN_RANDOM_TEST_CASE, with_noise=True + ) + ), +) +def test_buffer_individual_messages(input_strings, expected_output_messages): + output_messages = [] + + buffer = Buffer() + + def callback(msg): + output_messages.append(msg) + + for b in input_strings: + buffer.append(b) + buffer.process(callback) + + assert len(expected_output_messages) == len(output_messages) + assert tuple(expected_output_messages) == tuple(output_messages) + + +@pytest.mark.parametrize( + "input_strings,expected_output_messages", + manual_test_cases + + list(random_test_cases(NUM_RANDOM_TEST_CASES, NUM_MESSAGES_IN_RANDOM_TEST_CASE)) + + list( + random_test_cases( + NUM_RANDOM_TEST_CASES, NUM_MESSAGES_IN_RANDOM_TEST_CASE, with_noise=True + ) + ), +) +def test_buffer_all_at_once(input_strings, expected_output_messages): + output_messages = [] + + buffer = Buffer() + + def callback(msg): + output_messages.append(msg) + + complete_input = "".join(input_strings) + buffer.append(complete_input) + + buffer.process(callback) + + assert len(expected_output_messages) == len(output_messages) + assert tuple(expected_output_messages) == tuple(output_messages) + + +@pytest.mark.parametrize( + "input_strings,expected_output_messages", + manual_test_cases + + list(random_test_cases(NUM_RANDOM_TEST_CASES, NUM_MESSAGES_IN_RANDOM_TEST_CASE)) + + list( + random_test_cases( + NUM_RANDOM_TEST_CASES, NUM_MESSAGES_IN_RANDOM_TEST_CASE, with_noise=True + ) + ), +) +def test_buffer_random_length_reads(input_strings, expected_output_messages): + output_messages = [] + + buffer = Buffer() + + def callback(msg): + output_messages.append(msg) + + complete_input = "".join(input_strings) + + while complete_input: + num_chars_to_append = randint(1, 1024) + chars_to_append = complete_input[:num_chars_to_append] + complete_input = complete_input[num_chars_to_append:] + + buffer.append(chars_to_append) + buffer.process(callback) + + assert len(expected_output_messages) == len(output_messages) + assert tuple(expected_output_messages) == tuple(output_messages) From 8e53a70001e1af9f8f1dff6e7a1cbd907ad0616f Mon Sep 17 00:00:00 2001 From: Wiktor Latanowicz Date: Sun, 26 Mar 2023 20:19:27 +0200 Subject: [PATCH 2/4] Fixes in buffer implementation --- indi/transport/buffer.py | 95 ++++++++++++++++++++--------- tests/indi/transport/test_buffer.py | 40 ++++++++++-- 2 files changed, 102 insertions(+), 33 deletions(-) diff --git a/indi/transport/buffer.py b/indi/transport/buffer.py index 2eaea7f..8b781f0 100644 --- a/indi/transport/buffer.py +++ b/indi/transport/buffer.py @@ -16,40 +16,79 @@ def append(self, data: str): self.data += data def _cleanup_buffer(self): - start = len(self.data) - 1 + start = None + + # find first occurrence of any known xml tag: for tag in self.allowed_tags: - start = min(start, self.data.find("<" + tag)) + lookup = "<" + tag + found_pos = self.data.find(lookup) + if found_pos >= 0: + start = min(start, found_pos) if start is not None else found_pos - if start >= 0: - self.data = self.data[start:] + if start == 0: + break - def process(self, callback: Callable[[IndiMessage], None]): + if start is not None: + if start > 0: + self.data = self.data[start:] + return + + # if no known tags found + # search for the last xml tag opening + # just in case it's the part of valid message + # and the rest will arrive soon + last_tag_pos = self.data.rfind("<") + if last_tag_pos >= 0: + start = last_tag_pos + + if start is not None: + if start > 0: + self.data = self.data[start:] + return + + # neither known tag nor xml opening found in the buffer + # we can safely assume everything is junk and discard it + self.data = "" + + def _cleanup_beginning(self): + self.data = self.data[1:] self._cleanup_buffer() + + def _find_message_in_buffer(self): end = 0 - while len(self.data) > 0 and end >= 0: + while end < len(self.data) - 1: end = self.data.find(">", end) + if end < 0: + return None, None + + end += 1 - if end > 0: - end += 1 - partial = self.data[0:end] + partial = self.data[:end] + try: + ET.fromstring(partial) + is_correct_xml = True + except ET.ParseError: + is_correct_xml = False + + if is_correct_xml: try: - ET.fromstring(partial) - is_correct_xml = True - except ET.ParseError: - is_correct_xml = False - - if is_correct_xml: - self.data = self.data[end:] - end = 0 - message = None - try: - message = IndiMessage.from_string(partial) - except Exception: - logger.warning("Buffer: Contents is not a valid message") - - if message: - try: - callback(message) - except Exception: - logger.exception("Error procesing message") + message = IndiMessage.from_string(partial) + return message, end + except Exception: + logger.warning("Buffer: Contents is not a valid message") + return None, None + + def process(self, callback: Callable[[IndiMessage], None]): + self._cleanup_buffer() + while self.data: + message, end = self._find_message_in_buffer() + if not message: + if len(self.data) > 1024: + self._cleanup_beginning() + continue + break + + self.data = self.data[end:] + self._cleanup_buffer() + callback(message) diff --git a/tests/indi/transport/test_buffer.py b/tests/indi/transport/test_buffer.py index ee3c0bb..9edfa47 100644 --- a/tests/indi/transport/test_buffer.py +++ b/tests/indi/transport/test_buffer.py @@ -1,3 +1,4 @@ +import string from itertools import chain, combinations from random import choice, randint, random @@ -16,7 +17,12 @@ def powerset(iterable): return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) -noise_messages = ["", ""] +noise_messages = [ + "", + "", + '', + "", +] raw_messages = [ '', @@ -37,6 +43,23 @@ def powerset(iterable): ), ] +manual_test_cases = [ + [ + ( + "junk", + "junk2", + ), + (), + ], + [ + ( + "junk", + raw_messages[0], + ), + (indi_messages[0],), + ], +] + def random_test_case(size, with_noise=False): input_strings = [] @@ -63,12 +86,18 @@ def random_test_case(size, with_noise=False): ) -def random_test_cases(count, size, with_noise=False): - for _ in range(count): - yield random_test_case(size, with_noise=with_noise) +def random_string(length): + chars = string.ascii_letters + string.digits + return "".join(choice(chars) for i in range(length)) -manual_test_cases = [] +def random_test_cases(count, size, with_noise=False): + for _ in range(count): + input_strings, output_messages = random_test_case(size, with_noise=with_noise) + if with_noise: + noise_to_append = random_string(1024) + input_strings = input_strings + (noise_to_append,) + yield input_strings, output_messages @pytest.mark.parametrize( @@ -124,6 +153,7 @@ def callback(msg): assert tuple(expected_output_messages) == tuple(output_messages) +# @pytest.mark.skip() @pytest.mark.parametrize( "input_strings,expected_output_messages", manual_test_cases From 9c117ae7101ad97bd39d84b6b8b7c2cd00e7ba9d Mon Sep 17 00:00:00 2001 From: Wiktor Latanowicz Date: Sun, 26 Mar 2023 21:18:54 +0200 Subject: [PATCH 3/4] Buffer performance improvements --- indi/message/base.py | 18 +++++++------ indi/message/defs.py | 5 ++++ indi/message/del_property.py | 1 + indi/message/enable_blob.py | 1 + indi/message/get_properties.py | 1 + indi/message/news.py | 4 +++ indi/message/one_light.py | 1 + indi/message/pings.py | 2 ++ indi/message/sets.py | 5 ++++ indi/transport/buffer.py | 42 ++++++++++++++++++++--------- tests/indi/transport/test_buffer.py | 3 +++ 11 files changed, 63 insertions(+), 20 deletions(-) diff --git a/indi/message/base.py b/indi/message/base.py index b3ecee9..b886cc7 100644 --- a/indi/message/base.py +++ b/indi/message/base.py @@ -11,6 +11,8 @@ class IndiMessage: from_device = False from_client = False + _message_classes: List[Type[IndiMessage]] = [] + def __init__(self, device=None, **junk): self.device = device @@ -19,20 +21,20 @@ def tag_name(cls): return cls.__name__[:1].lower() + cls.__name__[1:] @classmethod - def __all_subclasses__(cls) -> Tuple[Type[IndiMessage], ...]: - subclasses = [] - for subclass in cls.__subclasses__(): - subclasses.append(subclass) - for nested_subclass in subclass.__all_subclasses__(): - subclasses.append(nested_subclass) - return tuple(subclasses) + def register_message(cls, message_class): + cls._message_classes.append(message_class) + return message_class + + @classmethod + def all_message_classes(cls): + return cls._message_classes @classmethod def from_xml(cls, xml: ET.Element) -> IndiMessage: tag = xml.tag message_class = None - for subclass in cls.__all_subclasses__(): + for subclass in cls.all_message_classes(): if subclass.tag_name() == tag: message_class = subclass diff --git a/indi/message/defs.py b/indi/message/defs.py index e920f03..27c49d9 100644 --- a/indi/message/defs.py +++ b/indi/message/defs.py @@ -58,18 +58,22 @@ def __init__( self.timeout = timeout +@IndiMessage.register_message class DefBLOBVector(DefWritableVector): children_class = DefBLOB +@IndiMessage.register_message class DefLightVector(DefVector): children_class = DefLight +@IndiMessage.register_message class DefNumberVector(DefWritableVector): children_class = DefNumber +@IndiMessage.register_message class DefSwitchVector(DefWritableVector): children_class = DefSwitch @@ -78,5 +82,6 @@ def __init__(self, *args, rule: const.SwitchRuleType, **kwargs): self.rule = checks.dictionary(rule, const.SwitchRule) +@IndiMessage.register_message class DefTextVector(DefWritableVector): children_class = DefText diff --git a/indi/message/del_property.py b/indi/message/del_property.py index cac8981..f533539 100644 --- a/indi/message/del_property.py +++ b/indi/message/del_property.py @@ -8,6 +8,7 @@ from indi.message import TimestampType +@IndiMessage.register_message class DelProperty(IndiMessage): from_device = True diff --git a/indi/message/enable_blob.py b/indi/message/enable_blob.py index e5d3618..4e90f78 100644 --- a/indi/message/enable_blob.py +++ b/indi/message/enable_blob.py @@ -4,6 +4,7 @@ from indi.message.base import IndiMessage +@IndiMessage.register_message class EnableBLOB(IndiMessage): from_client = True diff --git a/indi/message/get_properties.py b/indi/message/get_properties.py index 806d5da..00ba47e 100644 --- a/indi/message/get_properties.py +++ b/indi/message/get_properties.py @@ -3,6 +3,7 @@ from indi.message.base import IndiMessage +@IndiMessage.register_message class GetProperties(IndiMessage): from_device = True from_client = True diff --git a/indi/message/news.py b/indi/message/news.py index 9e2c50b..a8c8d2b 100644 --- a/indi/message/news.py +++ b/indi/message/news.py @@ -36,17 +36,21 @@ def __init__( self.children = checks.children(children, self.children_class) +@IndiMessage.register_message class NewBLOBVector(NewVector): children_class = OneBLOB +@IndiMessage.register_message class NewNumberVector(NewVector): children_class = OneNumber +@IndiMessage.register_message class NewSwitchVector(NewVector): children_class = OneSwitch +@IndiMessage.register_message class NewTextVector(NewVector): children_class = OneText diff --git a/indi/message/one_light.py b/indi/message/one_light.py index 4623072..53e2139 100644 --- a/indi/message/one_light.py +++ b/indi/message/one_light.py @@ -1,6 +1,7 @@ from indi.message.base import IndiMessage +@IndiMessage.register_message class OneLight(IndiMessage): from_device = True diff --git a/indi/message/pings.py b/indi/message/pings.py index d6b5e14..69fa57a 100644 --- a/indi/message/pings.py +++ b/indi/message/pings.py @@ -1,6 +1,7 @@ from indi.message.base import IndiMessage +@IndiMessage.register_message class PingReply(IndiMessage): from_client = True @@ -9,6 +10,7 @@ def __init__(self, uid: str, **junk): self.uid = uid +@IndiMessage.register_message class PingRequest(IndiMessage): from_device = True diff --git a/indi/message/sets.py b/indi/message/sets.py index 584580d..958ec3a 100644 --- a/indi/message/sets.py +++ b/indi/message/sets.py @@ -43,21 +43,26 @@ def __init__( self.children = checks.children(children, self.child_class) +@IndiMessage.register_message class SetBLOBVector(SetVector): child_class = OneBLOB +@IndiMessage.register_message class SetLightVector(SetVector): child_class = OneLight +@IndiMessage.register_message class SetNumberVector(SetVector): child_class = OneNumber +@IndiMessage.register_message class SetSwitchVector(SetVector): child_class = OneSwitch +@IndiMessage.register_message class SetTextVector(SetVector): child_class = OneText diff --git a/indi/transport/buffer.py b/indi/transport/buffer.py index 8b781f0..64ad868 100644 --- a/indi/transport/buffer.py +++ b/indi/transport/buffer.py @@ -1,5 +1,6 @@ import logging import xml.etree.ElementTree as ET +from io import StringIO from typing import Callable from indi.message import IndiMessage @@ -9,19 +10,34 @@ class Buffer: def __init__(self) -> None: - self.data = "" - self.allowed_tags = [m.tag_name() for m in IndiMessage.__all_subclasses__()] + self.max_buffer_size_before_frontal_cleanup = 2048 + self.buffer = StringIO() + self.allowed_tags = [m.tag_name() for m in IndiMessage.all_message_classes()] def append(self, data: str): - self.data += data + self.buffer.write(data) + + @property + def data(self) -> str: + return self.buffer.getvalue() + + @data.setter + def data(self, value: str): + self.buffer = StringIO() + self.append(value) + + @property + def data_len(self): + return self.buffer.tell() def _cleanup_buffer(self): start = None + data = self.data # find first occurrence of any known xml tag: for tag in self.allowed_tags: lookup = "<" + tag - found_pos = self.data.find(lookup) + found_pos = data.find(lookup) if found_pos >= 0: start = min(start, found_pos) if start is not None else found_pos @@ -30,20 +46,20 @@ def _cleanup_buffer(self): if start is not None: if start > 0: - self.data = self.data[start:] + self.data = data[start:] return # if no known tags found # search for the last xml tag opening # just in case it's the part of valid message # and the rest will arrive soon - last_tag_pos = self.data.rfind("<") + last_tag_pos = data.rfind("<") if last_tag_pos >= 0: start = last_tag_pos if start is not None: if start > 0: - self.data = self.data[start:] + self.data = data[start:] return # neither known tag nor xml opening found in the buffer @@ -56,14 +72,15 @@ def _cleanup_beginning(self): def _find_message_in_buffer(self): end = 0 - while end < len(self.data) - 1: - end = self.data.find(">", end) + data = self.data + while end < len(data) - 1: + end = data.find(">", end) if end < 0: return None, None end += 1 - partial = self.data[:end] + partial = data[:end] try: ET.fromstring(partial) @@ -81,10 +98,11 @@ def _find_message_in_buffer(self): def process(self, callback: Callable[[IndiMessage], None]): self._cleanup_buffer() - while self.data: + while self.data_len: message, end = self._find_message_in_buffer() + if not message: - if len(self.data) > 1024: + if self.data_len > self.max_buffer_size_before_frontal_cleanup: self._cleanup_beginning() continue break diff --git a/tests/indi/transport/test_buffer.py b/tests/indi/transport/test_buffer.py index 9edfa47..2568fe0 100644 --- a/tests/indi/transport/test_buffer.py +++ b/tests/indi/transport/test_buffer.py @@ -114,6 +114,7 @@ def test_buffer_individual_messages(input_strings, expected_output_messages): output_messages = [] buffer = Buffer() + buffer.max_buffer_size_before_frontal_cleanup = 128 def callback(msg): output_messages.append(msg) @@ -140,6 +141,7 @@ def test_buffer_all_at_once(input_strings, expected_output_messages): output_messages = [] buffer = Buffer() + buffer.max_buffer_size_before_frontal_cleanup = 128 def callback(msg): output_messages.append(msg) @@ -168,6 +170,7 @@ def test_buffer_random_length_reads(input_strings, expected_output_messages): output_messages = [] buffer = Buffer() + buffer.max_buffer_size_before_frontal_cleanup = 128 def callback(msg): output_messages.append(msg) From 38d1646206d4317f85af767d0b47423bba00259a Mon Sep 17 00:00:00 2001 From: Wiktor Latanowicz Date: Mon, 27 Mar 2023 12:26:41 +0200 Subject: [PATCH 4/4] Disable cleanup at the beginning of the buffer for blob connections --- indi/client/client.py | 2 +- indi/transport/buffer.py | 6 +++--- indi/transport/client/tcp.py | 8 ++++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/indi/client/client.py b/indi/client/client.py index 05448b9..1a632f8 100644 --- a/indi/client/client.py +++ b/indi/client/client.py @@ -351,7 +351,7 @@ async def start(self): self.process_message ) self.blob_connection_handler = await self.blob_connection.connect( - self.process_message + self.process_message, for_blobs=True ) asyncio.get_running_loop().create_task( diff --git a/indi/transport/buffer.py b/indi/transport/buffer.py index 64ad868..022e977 100644 --- a/indi/transport/buffer.py +++ b/indi/transport/buffer.py @@ -1,7 +1,7 @@ import logging import xml.etree.ElementTree as ET from io import StringIO -from typing import Callable +from typing import Callable, Optional from indi.message import IndiMessage @@ -10,7 +10,7 @@ class Buffer: def __init__(self) -> None: - self.max_buffer_size_before_frontal_cleanup = 2048 + self.max_buffer_size_before_frontal_cleanup: Optional[int] = 2048 self.buffer = StringIO() self.allowed_tags = [m.tag_name() for m in IndiMessage.all_message_classes()] @@ -101,7 +101,7 @@ def process(self, callback: Callable[[IndiMessage], None]): while self.data_len: message, end = self._find_message_in_buffer() - if not message: + if not message and self.max_buffer_size_before_frontal_cleanup is not None: if self.data_len > self.max_buffer_size_before_frontal_cleanup: self._cleanup_beginning() continue diff --git a/indi/transport/client/tcp.py b/indi/transport/client/tcp.py index 145a574..9eff586 100644 --- a/indi/transport/client/tcp.py +++ b/indi/transport/client/tcp.py @@ -14,8 +14,12 @@ def __init__( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, callback: Callable[[IndiMessage], None], + for_blobs=False, ): self.buffer = Buffer() + if for_blobs: + self.buffer.max_buffer_size_before_frontal_cleanup = None + self.reader, self.writer = reader, writer self.callback = callback self.sender_lock = asyncio.Lock() @@ -53,7 +57,7 @@ def __init__(self, address: str = "127.0.0.1", port: int = 7624): self.address = address self.port = port - async def connect(self, callback: Callable[[IndiMessage], None]): + async def connect(self, callback: Callable[[IndiMessage], None], for_blobs=False): reader, writer = await asyncio.open_connection(self.address, self.port) - handler = ConnectionHandler(reader, writer, callback) + handler = ConnectionHandler(reader, writer, callback, for_blobs=for_blobs) return handler