diff --git a/src/nasdaq_protocols/ouch/core.py b/src/nasdaq_protocols/ouch/core.py index 449f5fe..25690df 100644 --- a/src/nasdaq_protocols/ouch/core.py +++ b/src/nasdaq_protocols/ouch/core.py @@ -1,3 +1,5 @@ +from abc import ABC + import attrs from nasdaq_protocols.common import Serializable, Byte, CommonMessage, logable @@ -10,19 +12,19 @@ @attrs.define(auto_attribs=True, hash=True) -class OuchMessageId(Serializable): +class OuchMessageId(Serializable, ABC): indicator: int - direction: str = attrs.field(default="", eq=False, hash=False) + direction: str = 'outgoing' + + def to_bytes(self) -> tuple[int, bytes]: + return Byte.to_bytes(self.indicator) @classmethod def from_bytes(cls, bytes_: bytes) -> tuple[int, 'OuchMessageId']: return 1, OuchMessageId(Byte.from_bytes(bytes_)[1]) - def to_bytes(self) -> tuple[int, bytes]: - return Byte.to_bytes(self.indicator) - def __str__(self): - return f'indicator={self.indicator}' + return f'indicator={self.indicator}, direction={self.direction}' @attrs.define @@ -42,8 +44,6 @@ def __init_subclass__(cls, *args, **kwargs): if all(k in kwargs for k in ['direction', 'indicator']): kwargs['msg_id'] = OuchMessageId(kwargs['indicator'], kwargs['direction']) - if kwargs['direction'] == 'incoming': - Message.IncomingMsgClasses.append(cls) - elif kwargs['direction'] == 'outgoing': - Message.OutgoingMsgsClasses.append(cls) + container = cls.IncomingMsgClasses if kwargs['direction'] == 'incoming' else cls.OutgoingMsgsClasses + container.append(cls) super().__init_subclass__(**kwargs) diff --git a/tests/test_ouch_core.py b/tests/test_ouch_core.py index 871bd60..4dd24f6 100644 --- a/tests/test_ouch_core.py +++ b/tests/test_ouch_core.py @@ -54,7 +54,7 @@ def get(key): return msg -class TestOuchApp1MessageIn(App1OuchMessage, direction='incoming', indicator=3): +class TestOuchApp1MessageIn(App1OuchMessage, direction='incoming', indicator=1): __test__ = False class BodyRecord(Record): @@ -95,6 +95,17 @@ def test__from_bytes__different_indicator__decodes_correct_message(): assert decoded_app1_msg2_out[1] == app1_msg2_out +def test__from_bytes__same_indicator__always_decodes_only_outgoing_message(): + # When same indicator is used for both incoming and outgoing, + # from client side we always want to decode from_bytes what + # the server is sending us... + # so messages marked as "outgoing" should be decoded. + app1_msg_out = TestOuchApp1Message1Out.get(123456789) + + decoded = App1OuchMessage.from_bytes(app1_msg_out.to_bytes()[1]) + assert isinstance(decoded[1], TestOuchApp1Message1Out) + + def test__from_bytes__same_identifier_different_apps_returns_correct_message(): app1_msg = TestOuchApp1Message1Out.get(123456789) app2_msg = TestOuchApp2MessageOut.get('AB')