From 22924d7db8517cf4112b5a63079f149812554128 Mon Sep 17 00:00:00 2001 From: SamDanielThangarajan <12202554+SamDanielThangarajan@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:52:18 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20bidirectional=20protocols=20migh?= =?UTF-8?q?t=20use=20the=20same=20indicator=20for=20both=20out/in?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/nasdaq_protocols/ouch/core.py | 20 ++++++++++---------- tests/test_ouch_core.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/nasdaq_protocols/ouch/core.py b/src/nasdaq_protocols/ouch/core.py index 449f5fe..25690df 100644 --- a/src/nasdaq_protocols/ouch/core.py +++ b/src/nasdaq_protocols/ouch/core.py @@ -1,3 +1,5 @@ +from abc import ABC + import attrs from nasdaq_protocols.common import Serializable, Byte, CommonMessage, logable @@ -10,19 +12,19 @@ @attrs.define(auto_attribs=True, hash=True) -class OuchMessageId(Serializable): +class OuchMessageId(Serializable, ABC): indicator: int - direction: str = attrs.field(default="", eq=False, hash=False) + direction: str = 'outgoing' + + def to_bytes(self) -> tuple[int, bytes]: + return Byte.to_bytes(self.indicator) @classmethod def from_bytes(cls, bytes_: bytes) -> tuple[int, 'OuchMessageId']: return 1, OuchMessageId(Byte.from_bytes(bytes_)[1]) - def to_bytes(self) -> tuple[int, bytes]: - return Byte.to_bytes(self.indicator) - def __str__(self): - return f'indicator={self.indicator}' + return f'indicator={self.indicator}, direction={self.direction}' @attrs.define @@ -42,8 +44,6 @@ def __init_subclass__(cls, *args, **kwargs): if all(k in kwargs for k in ['direction', 'indicator']): kwargs['msg_id'] = OuchMessageId(kwargs['indicator'], kwargs['direction']) - if kwargs['direction'] == 'incoming': - Message.IncomingMsgClasses.append(cls) - elif kwargs['direction'] == 'outgoing': - Message.OutgoingMsgsClasses.append(cls) + container = cls.IncomingMsgClasses if kwargs['direction'] == 'incoming' else cls.OutgoingMsgsClasses + container.append(cls) super().__init_subclass__(**kwargs) diff --git a/tests/test_ouch_core.py b/tests/test_ouch_core.py index 871bd60..4dd24f6 100644 --- a/tests/test_ouch_core.py +++ b/tests/test_ouch_core.py @@ -54,7 +54,7 @@ def get(key): return msg -class TestOuchApp1MessageIn(App1OuchMessage, direction='incoming', indicator=3): +class TestOuchApp1MessageIn(App1OuchMessage, direction='incoming', indicator=1): __test__ = False class BodyRecord(Record): @@ -95,6 +95,17 @@ def test__from_bytes__different_indicator__decodes_correct_message(): assert decoded_app1_msg2_out[1] == app1_msg2_out +def test__from_bytes__same_indicator__always_decodes_only_outgoing_message(): + # When same indicator is used for both incoming and outgoing, + # from client side we always want to decode from_bytes what + # the server is sending us... + # so messages marked as "outgoing" should be decoded. + app1_msg_out = TestOuchApp1Message1Out.get(123456789) + + decoded = App1OuchMessage.from_bytes(app1_msg_out.to_bytes()[1]) + assert isinstance(decoded[1], TestOuchApp1Message1Out) + + def test__from_bytes__same_identifier_different_apps_returns_correct_message(): app1_msg = TestOuchApp1Message1Out.get(123456789) app2_msg = TestOuchApp2MessageOut.get('AB')