Skip to content

Modernize event loop handling in aiohttp tests #455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions tests/ext/aiohttp/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging

import pytest
Expand All @@ -17,22 +18,23 @@


@pytest.fixture(scope='function')
def recorder(loop):
def recorder(event_loop):
"""
Initiate a recorder and clear it up once has been used.
"""
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=event_loop))
xray_recorder.clear_trace_entities()
yield recorder
xray_recorder.clear_trace_entities()


async def test_ok(loop, recorder):
async def test_ok(recorder):
xray_recorder.begin_segment('name')
trace_config = aws_xray_trace_config()
status_code = 200
url = 'http://{}/status/{}?foo=bar'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.get(url):
pass

Expand All @@ -46,25 +48,27 @@ async def test_ok(loop, recorder):
assert http_meta['response']['status'] == status_code


async def test_ok_name(loop, recorder):
async def test_ok_name(recorder):
xray_recorder.begin_segment('name')
trace_config = aws_xray_trace_config(name='test')
status_code = 200
url = 'http://{}/status/{}?foo=bar'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.get(url):
pass

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.name == 'test'


async def test_error(loop, recorder):
async def test_error(recorder):
xray_recorder.begin_segment('name')
trace_config = aws_xray_trace_config()
status_code = 400
url = 'http://{}/status/{}'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.post(url):
pass

Expand All @@ -78,12 +82,13 @@ async def test_error(loop, recorder):
assert http_meta['response']['status'] == status_code


async def test_throttle(loop, recorder):
async def test_throttle(recorder):
xray_recorder.begin_segment('name')
trace_config = aws_xray_trace_config()
status_code = 429
url = 'http://{}/status/{}'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.head(url):
pass

Expand All @@ -98,12 +103,13 @@ async def test_throttle(loop, recorder):
assert http_meta['response']['status'] == status_code


async def test_fault(loop, recorder):
async def test_fault(recorder):
xray_recorder.begin_segment('name')
trace_config = aws_xray_trace_config()
status_code = 500
url = 'http://{}/status/{}'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.put(url):
pass

Expand All @@ -117,10 +123,11 @@ async def test_fault(loop, recorder):
assert http_meta['response']['status'] == status_code


async def test_invalid_url(loop, recorder):
async def test_invalid_url(recorder):
xray_recorder.begin_segment('name')
trace_config = aws_xray_trace_config()
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
try:
async with session.get('http://doesnt.exist'):
pass
Expand All @@ -136,24 +143,26 @@ async def test_invalid_url(loop, recorder):
assert exception.type == 'ClientConnectorError'


async def test_no_segment_raise(loop, recorder):
async def test_no_segment_raise(recorder):
xray_recorder.configure(context_missing='RUNTIME_ERROR')
trace_config = aws_xray_trace_config()
status_code = 200
url = 'http://{}/status/{}?foo=bar'.format(BASE_URL, status_code)
event_loop = asyncio.get_running_loop()
with pytest.raises(SegmentNotFoundException):
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.get(url):
pass


async def test_no_segment_log_error(loop, recorder, caplog):
async def test_no_segment_log_error(recorder, caplog):
caplog.set_level(logging.ERROR)
xray_recorder.configure(context_missing='LOG_ERROR')
trace_config = aws_xray_trace_config()
status_code = 200
url = 'http://{}/status/{}?foo=bar'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.get(url) as resp:
status_received = resp.status

Expand All @@ -162,13 +171,14 @@ async def test_no_segment_log_error(loop, recorder, caplog):
assert MISSING_SEGMENT_MSG in [rec.message for rec in caplog.records]


async def test_no_segment_ignore_error(loop, recorder, caplog):
async def test_no_segment_ignore_error(recorder, caplog):
caplog.set_level(logging.ERROR)
xray_recorder.configure(context_missing='IGNORE_ERROR')
trace_config = aws_xray_trace_config()
status_code = 200
url = 'http://{}/status/{}?foo=bar'.format(BASE_URL, status_code)
async with ClientSession(loop=loop, trace_configs=[trace_config]) as session:
event_loop = asyncio.get_running_loop()
async with ClientSession(loop=event_loop, trace_configs=[trace_config]) as session:
async with session.get(url) as resp:
status_received = resp.status

Expand Down
61 changes: 31 additions & 30 deletions tests/ext/aiohttp/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ def app(cls, loop=None) -> web.Application:


@pytest.fixture(scope='function')
def recorder(loop):
def recorder(event_loop):
"""
Clean up context storage before and after each test run
"""
xray_recorder = get_new_stubbed_recorder()
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=event_loop))

patcher = patch('aws_xray_sdk.ext.aiohttp.middleware.xray_recorder', xray_recorder)
patcher.start()
Expand All @@ -120,15 +120,15 @@ def recorder(loop):
patcher.stop()


async def test_ok(aiohttp_client, loop, recorder):
async def test_ok(aiohttp_client, recorder):
"""
Test a normal response

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

resp = await client.get('/')
assert resp.status == 200
Expand All @@ -144,15 +144,15 @@ async def test_ok(aiohttp_client, loop, recorder):
assert response['status'] == 200


async def test_ok_x_forwarded_for(aiohttp_client, loop, recorder):
async def test_ok_x_forwarded_for(aiohttp_client, recorder):
"""
Test a normal response with x_forwarded_for headers

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

resp = await client.get('/', headers={'X-Forwarded-For': 'foo'})
assert resp.status == 200
Expand All @@ -162,15 +162,15 @@ async def test_ok_x_forwarded_for(aiohttp_client, loop, recorder):
assert segment.http['request']['x_forwarded_for']


async def test_ok_content_length(aiohttp_client, loop, recorder):
async def test_ok_content_length(aiohttp_client, recorder):
"""
Test a normal response with content length as response header

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

resp = await client.get('/?content_length=100')
assert resp.status == 200
Expand All @@ -179,15 +179,15 @@ async def test_ok_content_length(aiohttp_client, loop, recorder):
assert segment.http['response']['content_length'] == 100


async def test_error(aiohttp_client, loop, recorder):
async def test_error(aiohttp_client, recorder):
"""
Test a 4XX response

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

resp = await client.get('/error')
assert resp.status == 404
Expand All @@ -204,15 +204,15 @@ async def test_error(aiohttp_client, loop, recorder):
assert response['status'] == 404


async def test_exception(aiohttp_client, loop, recorder):
async def test_exception(aiohttp_client, recorder):
"""
Test handling an exception

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

with pytest.raises(Exception):
await client.get('/exception')
Expand All @@ -231,15 +231,15 @@ async def test_exception(aiohttp_client, loop, recorder):
assert exception.type == 'CancelledError'


async def test_unhauthorized(aiohttp_client, loop, recorder):
async def test_unhauthorized(aiohttp_client, recorder):
"""
Test a 401 response

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

resp = await client.get('/unauthorized')
assert resp.status == 401
Expand All @@ -256,8 +256,9 @@ async def test_unhauthorized(aiohttp_client, loop, recorder):
assert response['status'] == 401


async def test_response_trace_header(aiohttp_client, loop, recorder):
client = await aiohttp_client(ServerTest.app(loop=loop))
async def test_response_trace_header(aiohttp_client, recorder):
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))
resp = await client.get('/')
xray_header = resp.headers[http.XRAY_HEADER]
segment = recorder.emitter.pop()
Expand All @@ -266,15 +267,15 @@ async def test_response_trace_header(aiohttp_client, loop, recorder):
assert expected in xray_header


async def test_concurrent(aiohttp_client, loop, recorder):
async def test_concurrent(aiohttp_client, recorder):
"""
Test multiple concurrent requests

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

recorder.emitter = CustomStubbedEmitter()

Expand All @@ -283,25 +284,25 @@ async def get_delay():
assert resp.status == 200

if sys.version_info >= (3, 8):
await asyncio.wait([loop.create_task(get_delay()) for i in range(9)])
await asyncio.wait([event_loop.create_task(get_delay()) for i in range(9)])
else:
await asyncio.wait([loop.create_task(get_delay()) for i in range(9)], loop=loop)
await asyncio.wait([event_loop.create_task(get_delay()) for i in range(9)], loop=event_loop)

# Ensure all ID's are different
ids = [item.id for item in recorder.emitter.local]
assert len(ids) == len(set(ids))


async def test_disabled_sdk(aiohttp_client, loop, recorder):
async def test_disabled_sdk(aiohttp_client, recorder):
"""
Test a normal response when the SDK is disabled.

:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
global_sdk_config.set_sdk_enabled(False)
client = await aiohttp_client(ServerTest.app(loop=loop))
event_loop = asyncio.get_running_loop()
client = await aiohttp_client(ServerTest.app(loop=event_loop))

resp = await client.get('/')
assert resp.status == 200
Expand Down