diff --git a/tda/streaming.py b/tda/streaming.py index f3d6772..4736a49 100644 --- a/tda/streaming.py +++ b/tda/streaming.py @@ -13,6 +13,7 @@ import urllib.parse import websockets.legacy.client as ws_client +from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory from .utils import EnumEnforcer @@ -229,11 +230,15 @@ async def _init_from_principals(self, principals): # Initialize socket wss_url = 'wss://{}/ws'.format( principals['streamerInfo']['streamerSocketUrl']) + + ws_connect_args = { + 'extensions': [ClientPerMessageDeflateFactory()], + } + if self._ssl_context: - self._socket = await ws_client.connect( - wss_url, ssl=self._ssl_context) - else: - self._socket = await ws_client.connect(wss_url) + ws_connect_args['ssl'] = self._ssl_context + + self._socket = await ws_client.connect(wss_url, **ws_connect_args) # Initialize miscellaneous parameters self._source = principals['streamerInfo']['appId'] diff --git a/tests/streaming_test.py b/tests/streaming_test.py index ba624d7..4b0486a 100644 --- a/tests/streaming_test.py +++ b/tests/streaming_test.py @@ -369,7 +369,7 @@ async def test_login_no_ssl_context(self, ws_connect): await self.client.login() - ws_connect.assert_awaited_once_with(ANY) + ws_connect.assert_awaited_once_with(ANY, extensions=[ANY]) @no_duplicates @@ -387,7 +387,8 @@ async def test_login_ssl_context(self, ws_connect): await self.client.login() - ws_connect.assert_awaited_once_with(ANY, ssl='ssl_context') + ws_connect.assert_awaited_once_with( + ANY, ssl='ssl_context', extensions=[ANY]) @no_duplicates