From a0d38314da7cf222f7bd034ba5ba53590cc511f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20L=C3=B3pez?= Date: Thu, 12 Oct 2023 21:17:23 -0600 Subject: [PATCH] add union types validation with pydantic logic --- examples/tasks/task_example.py | 13 ++-- fast_agave/exc.py | 5 ++ fast_agave/tasks/sqs_tasks.py | 15 ++--- fast_agave/version.py | 2 +- tests/tasks/test_sqs_tasks.py | 116 +++++++++++++++++++++++++++------ 5 files changed, 118 insertions(+), 33 deletions(-) diff --git a/examples/tasks/task_example.py b/examples/tasks/task_example.py index bb3e63f..1a84c95 100644 --- a/examples/tasks/task_example.py +++ b/examples/tasks/task_example.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from fast_agave.tasks.sqs_tasks import task @@ -10,17 +10,22 @@ QUEUE2_URL = 'http://127.0.0.1:4000/123456789012/validator.fifo' -class ValidatorModel(BaseModel): +class User(BaseModel): name: str age: int nick_name: Optional[str] +class Company(BaseModel): + legal_name: str + rfc: str + + @task(queue_url=QUEUE_URL, region_name='us-east-1') async def dummy_task(message) -> None: print(message) -@task(queue_url=QUEUE2_URL, region_name='us-east-1', validator=ValidatorModel) -async def task_validator(message: ValidatorModel) -> None: +@task(queue_url=QUEUE2_URL, region_name='us-east-1') +async def task_validator(message: Union[User, Company]) -> None: print(message.dict()) diff --git a/fast_agave/exc.py b/fast_agave/exc.py index be1ee74..978d594 100644 --- a/fast_agave/exc.py +++ b/fast_agave/exc.py @@ -56,3 +56,8 @@ class FastAgaveViewError(FastAgaveError): @dataclass class RetryTask(Exception): countdown: Optional[int] = None + + +@dataclass +class TaskDefinitionError(Exception): + error: str diff --git a/fast_agave/tasks/sqs_tasks.py b/fast_agave/tasks/sqs_tasks.py index 4ea5102..8440fa2 100644 --- a/fast_agave/tasks/sqs_tasks.py +++ b/fast_agave/tasks/sqs_tasks.py @@ -4,11 +4,11 @@ from functools import wraps from itertools import count from json import JSONDecodeError -from typing import AsyncGenerator, Callable, Coroutine, Optional, Type +from typing import AsyncGenerator, Callable, Coroutine from aiobotocore.httpsession import HTTPClientError from aiobotocore.session import get_session -from pydantic import BaseModel +from pydantic import BaseModel, validate_arguments from ..exc import RetryTask @@ -25,12 +25,10 @@ async def run_task( receipt_handle: str, message_receive_count: int, max_retries: int, - validator: Optional[Type[BaseModel]] = None, ) -> None: delete_message = True try: - data = validator(**body) if validator else body - await task_func(data) + await task_func(body) except RetryTask as retry: delete_message = message_receive_count >= max_retries + 1 if not delete_message and retry.countdown and retry.countdown > 0: @@ -88,7 +86,6 @@ def task( visibility_timeout: int = 3600, max_retries: int = 1, max_concurrent_tasks: int = 5, - validator: Optional[Type[BaseModel]] = None, ): def task_builder(task_func: Callable): @wraps(task_func) @@ -108,6 +105,9 @@ async def concurrency_controller(coro: Coroutine) -> None: can_read.set() session = get_session() + + task_with_validators = validate_arguments(task_func) + async with session.create_client('sqs', region_name) as sqs: async for message in message_consumer( queue_url, @@ -127,14 +127,13 @@ async def concurrency_controller(coro: Coroutine) -> None: bg_task = asyncio.create_task( concurrency_controller( run_task( - task_func, + task_with_validators, body, sqs, queue_url, message['ReceiptHandle'], message_receive_count, max_retries, - validator, ), ), name='fast-agave-task', diff --git a/fast_agave/version.py b/fast_agave/version.py index 2d7893e..c1afd70 100644 --- a/fast_agave/version.py +++ b/fast_agave/version.py @@ -1 +1 @@ -__version__ = '0.13.0' +__version__ = '0.14.0.dev0' diff --git a/tests/tasks/test_sqs_tasks.py b/tests/tasks/test_sqs_tasks.py index 559ace0..cb4ea66 100644 --- a/tests/tasks/test_sqs_tasks.py +++ b/tests/tasks/test_sqs_tasks.py @@ -2,6 +2,7 @@ import datetime as dt import json import uuid +from typing import Dict, Union from unittest.mock import AsyncMock, call, patch import aiobotocore.client @@ -32,13 +33,17 @@ async def test_execute_tasks(sqs_client) -> None: MessageGroupId='1234', ) - async_mock_function = AsyncMock(return_value=None) + async_mock_function = AsyncMock() + + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_called_with(test_message) assert async_mock_function.call_count == 1 @@ -48,26 +53,28 @@ async def test_execute_tasks(sqs_client) -> None: @pytest.mark.asyncio -async def test_execute_tasks_validator(sqs_client) -> None: - async_mock_function = AsyncMock(return_value=None) - +async def test_execute_tasks_with_validator(sqs_client) -> None: class Validator(BaseModel): id: str name: str + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Validator) -> None: + await async_mock_function(data) + task_params = dict( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - validator=Validator, ) # Invalid body, not execute function await sqs_client.send_message( MessageBody=json.dumps(dict(foo='bar')), MessageGroupId='4321', ) - await task(**task_params)(async_mock_function)() + await task(**task_params)(my_task)() assert async_mock_function.call_count == 0 resp = await sqs_client.receive_message() assert 'Messages' not in resp @@ -78,7 +85,59 @@ class Validator(BaseModel): MessageBody=test_message.json(), MessageGroupId='1234', ) - await task(**task_params)(async_mock_function)() + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + +@pytest.mark.asyncio +async def test_execute_tasks_with_union_validator(sqs_client) -> None: + class User(BaseModel): + id: str + name: str + + class Company(BaseModel): + id: str + legal_name: str + rfc: str + + async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Union[User, Company]) -> None: + await async_mock_function(data) + + task_params = dict( + queue_url=sqs_client.queue_url, + region_name=CORE_QUEUE_REGION, + wait_time_seconds=1, + visibility_timeout=1, + ) + # Invalid body, not execute function + test_message = dict(id='ID123', name='Sor Juana Inés de la Cruz') + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='4321', + ) + await task(**task_params)(my_task)() + async_mock_function.assert_called_with(test_message) + assert async_mock_function.call_count == 1 + + resp = await sqs_client.receive_message() + assert 'Messages' not in resp + assert len(BACKGROUND_TASKS) == 0 + + async_mock_function.reset_mock() + test_message = dict(id='ID123', legal_name='FastAgave', rfc='FA') + + await sqs_client.send_message( + MessageBody=json.dumps(test_message), + MessageGroupId='54321', + ) + await task(**task_params)(my_task)() async_mock_function.assert_called_with(test_message) assert async_mock_function.call_count == 1 @@ -92,14 +151,18 @@ async def test_not_execute_tasks(sqs_client) -> None: """ Este caso es cuando el queue está vacío. No hay nada que ejecutar """ - async_mock_function = AsyncMock(return_value=None) + async_mock_function = AsyncMock() + + async def my_task(data: Dict) -> None: + await async_mock_function(data) + # No escribimos un mensaje en el queue await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_not_called() resp = await sqs_client.receive_message() assert 'Messages' not in resp @@ -141,6 +204,10 @@ async def mock_create_client(*args, **kwargs): return client async_mock_function = AsyncMock(return_value=None) + + async def my_task(data: Dict) -> None: + await async_mock_function(data) + with patch( 'aiobotocore.client.AioClientCreator.create_client', mock_create_client ): @@ -150,7 +217,7 @@ async def mock_create_client(*args, **kwargs): wait_time_seconds=1, visibility_timeout=3, max_retries=1, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_called_once() @@ -175,12 +242,15 @@ async def test_retry_tasks_default_max_retries(sqs_client) -> None: async_mock_function = AsyncMock(side_effect=RetryTask) + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, - )(async_mock_function)() + )(my_task)() expected_calls = [call(test_message)] * 2 async_mock_function.assert_has_calls(expected_calls) @@ -207,13 +277,16 @@ async def test_retry_tasks_custom_max_retries(sqs_client) -> None: async_mock_function = AsyncMock(side_effect=RetryTask) + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, max_retries=3, - )(async_mock_function)() + )(my_task)() expected_calls = [call(test_message)] * 4 async_mock_function.assert_has_calls(expected_calls) @@ -243,13 +316,16 @@ async def test_does_not_retry_on_unhandled_exceptions(sqs_client) -> None: side_effect=Exception('something went wrong :(') ) + async def my_task(data: Dict) -> None: + await async_mock_function(data) + await task( queue_url=sqs_client.queue_url, region_name=CORE_QUEUE_REGION, wait_time_seconds=1, visibility_timeout=1, max_retries=3, - )(async_mock_function)() + )(my_task)() async_mock_function.assert_called_with(test_message) assert async_mock_function.call_count == 1 @@ -281,11 +357,10 @@ async def test_retry_tasks_with_countdown(sqs_client) -> None: MessageGroupId='1234', ) - call_times = [] + async_mock_function = AsyncMock(side_effect=RetryTask(countdown=2)) - async def countdown_tester(_): - call_times.append(dt.datetime.now()) - raise RetryTask(countdown=2) + async def countdown_tester(data: Dict): + await async_mock_function(data, dt.datetime.now()) await task( queue_url=sqs_client.queue_url, @@ -294,7 +369,8 @@ async def countdown_tester(_): visibility_timeout=1, )(countdown_tester)() - assert len(call_times) == 2 + call_times = [arg[1] for arg, _ in async_mock_function.call_args_list] + assert async_mock_function.call_count == 2 assert call_times[1] - call_times[0] >= dt.timedelta(seconds=2) resp = await sqs_client.receive_message() assert 'Messages' not in resp @@ -314,7 +390,7 @@ async def test_concurrency_controller( max_running_tasks = 0 - async def task_counter(_) -> None: + async def task_counter(_: Dict) -> None: nonlocal max_running_tasks await asyncio.sleep(1) running_tasks = len(await get_running_fast_agave_tasks())