diff --git a/agave/blueprints/rest_api.py b/agave/blueprints/rest_api.py index be7a8f97..688a4819 100644 --- a/agave/blueprints/rest_api.py +++ b/agave/blueprints/rest_api.py @@ -1,5 +1,5 @@ import mimetypes -from typing import Optional, Type, cast +from typing import Any, Optional, Type, cast from urllib.parse import urlencode from chalice import Blueprint, NotFoundError, Response @@ -27,6 +27,10 @@ def delete(self, path: str, **kwargs): def current_user_id(self): return self.current_request.user_id + @property + def current_platform_id(self): + return self.current_request.platform_id + def user_id_filter_required(self): """ This method is required to be implemented with your own business logic. @@ -36,6 +40,36 @@ def user_id_filter_required(self): 'this method should be override' ) # pragma: no cover + def platform_id_filter_required(self): + """ + This method is required to be implemented with your own business logic. + You are responsible of determining when `user_id` filter is required. + """ + raise NotImplementedError( + 'this method should be override' + ) # pragma: no cover + + def retrieve_object(self, resource_class: Any, resource_id: str) -> Any: + resource_id = ( + self.current_user_id if resource_id == 'me' else resource_id + ) + query = Q(id=resource_id) + if self.platform_id_filter_required() and hasattr( + resource_class.model, 'platform_id' + ): + query = query & Q(platform_id=self.current_platform_id) + + if self.user_id_filter_required() and hasattr( + resource_class.model, 'user_id' + ): + query = query & Q(user_id=self.current_user_id) + + try: + data = resource_class.model.objects.get(query) + except DoesNotExist: + raise NotFoundError('Not valid id') + return data + def validate(self, validation_type: Type[BaseModel]): """This decorator validate the request body using a custom pydantyc model If validation fails return a BadRequest response with details @@ -103,12 +137,8 @@ def wrapper_resource_class(cls): @copy_attributes(cls) def delete(id: str): - try: - model = cls.model.objects.get(id=id) - except DoesNotExist: - raise NotFoundError('Not valid id') - else: - return cls.delete(model) + model = self.retrieve_object(cls, id) + return cls.delete(model) route(delete) @@ -125,13 +155,11 @@ def update(id: str): params = self.current_request.json_body or dict() try: data = cls.update_validator(**params) - model = cls.model.objects.get(id=id) except ValidationError as e: return Response(e.json(), status_code=400) - except DoesNotExist: - raise NotFoundError('Not valid id') - else: - return cls.update(model, data) + + model = self.retrieve_object(cls, id) + return cls.update(model, data) route(update) @@ -149,18 +177,12 @@ def retrieve(id: str): The most of times this implementation is enough and is not necessary define a custom "retrieve" method """ - try: - id_query = Q(id=id) - if self.user_id_filter_required(): - id_query = id_query & Q(user_id=self.current_user_id) - data = cls.model.objects.get(id_query) - except DoesNotExist: - raise NotFoundError('Not valid id') + obj = self.retrieve_object(cls, id) # This case is when the return is not an application/$ # but can be some type of file such as image, xml, zip or pdf if hasattr(cls, 'download'): - file = cls.download(data) + file = cls.download(obj) mimetype = cast( str, self.current_request.headers.get('accept') ) @@ -177,9 +199,9 @@ def retrieve(id: str): status_code=200, ) elif hasattr(cls, 'retrieve'): - result = cls.retrieve(data) + result = cls.retrieve(obj) else: - result = data.to_dict() + result = obj.to_dict() return result @@ -209,9 +231,17 @@ def query(): query_params = cls.query_validator(**params) except ValidationError as e: return Response(e.json(), status_code=400) - # Set user_id request as query param - if self.user_id_filter_required(): + + if self.platform_id_filter_required() and hasattr( + cls.model, 'platform_id' + ): + query_params.platform_id = self.current_platform_id + + if self.user_id_filter_required() and hasattr( + cls.model, 'user_id' + ): query_params.user_id = self.current_user_id + filters = cls.get_query_filter(query_params) if ( hasattr(query_params, 'active') @@ -257,6 +287,8 @@ def _all(query: QueryParams, filters: Q): params = query.dict() if self.user_id_filter_required(): params.pop('user_id') + if self.platform_id_filter_required(): + params.pop('platform_id') next_page_uri = f'{path}?{urlencode(params)}' return dict(items=item_dicts, next_page_uri=next_page_uri) diff --git a/agave/version.py b/agave/version.py index c3bb2961..7fd229a3 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '0.1.8' +__version__ = '0.2.0' diff --git a/examples/chalicelib/blueprints/authed.py b/examples/chalicelib/blueprints/authed.py index 98af8ca1..187398cd 100644 --- a/examples/chalicelib/blueprints/authed.py +++ b/examples/chalicelib/blueprints/authed.py @@ -3,6 +3,8 @@ from chalice import Blueprint +from ...config import TEST_DEFAULT_PLATFORM_ID, TEST_DEFAULT_USER_ID + class AuthedBlueprint(Blueprint): """ @@ -28,7 +30,8 @@ def decorator(user_handler: Callable): def authed_handler(*args, **kwargs): # your authentication logic goes here # before execute `user_handler` function. - self.current_request.user_id = 'US123456789' + self.current_request.user_id = TEST_DEFAULT_USER_ID + self.current_request.platform_id = TEST_DEFAULT_PLATFORM_ID return user_handler(*args, **kwargs) self._register_handler( # type: ignore @@ -61,3 +64,10 @@ def user_id_filter_required(self): :return: """ return False + + def platform_id_filter_required(self): + """ + It overrides `RestApiBlueprint.platform_id_filter_required()` method. + :return: + """ + return False diff --git a/examples/chalicelib/models/__init__.py b/examples/chalicelib/models/__init__.py index fa6ee1ce..2435a9f1 100644 --- a/examples/chalicelib/models/__init__.py +++ b/examples/chalicelib/models/__init__.py @@ -1,6 +1,8 @@ -__all__ = ['Account', 'Card', 'Transaction', 'File'] +__all__ = ['Account', 'Biller', 'Card', 'Transaction', 'File', 'User'] from .accounts import Account +from .billers import Biller from .cards import Card from .files import File from .transactions import Transaction +from .users import User diff --git a/examples/chalicelib/models/accounts.py b/examples/chalicelib/models/accounts.py index fd6617c2..ea20fe59 100644 --- a/examples/chalicelib/models/accounts.py +++ b/examples/chalicelib/models/accounts.py @@ -8,5 +8,6 @@ class Account(BaseModel, Document): id = StringField(primary_key=True, default=uuid_field('AC')) name = StringField(required=True) user_id = StringField(required=True) + platform_id = StringField(required=True) created_at = DateTimeField() deactivated_at = DateTimeField() diff --git a/examples/chalicelib/models/billers.py b/examples/chalicelib/models/billers.py new file mode 100644 index 00000000..ded9bc08 --- /dev/null +++ b/examples/chalicelib/models/billers.py @@ -0,0 +1,12 @@ +import datetime as dt + +from mongoengine import DateTimeField, Document, StringField + +from agave.models import BaseModel +from agave.models.helpers import uuid_field + + +class Biller(BaseModel, Document): + id = StringField(primary_key=True, default=uuid_field('BL')) + created_at = DateTimeField(default=dt.datetime.utcnow) + name = StringField(required=True) diff --git a/examples/chalicelib/models/users.py b/examples/chalicelib/models/users.py new file mode 100644 index 00000000..f0305069 --- /dev/null +++ b/examples/chalicelib/models/users.py @@ -0,0 +1,13 @@ +import datetime as dt + +from mongoengine import DateTimeField, Document, StringField + +from agave.models import BaseModel +from agave.models.helpers import uuid_field + + +class User(BaseModel, Document): + id = StringField(primary_key=True, default=uuid_field('US')) + created_at = DateTimeField(default=dt.datetime.utcnow) + name = StringField(required=True) + platform_id = StringField(required=True) diff --git a/examples/chalicelib/resources/__init__.py b/examples/chalicelib/resources/__init__.py index 5e7cd57e..ccdd1452 100644 --- a/examples/chalicelib/resources/__init__.py +++ b/examples/chalicelib/resources/__init__.py @@ -1,7 +1,9 @@ -__all__ = ['app', 'Account', 'Card', 'File', 'Transaction'] +__all__ = ['app', 'Account', 'Biller', 'Card', 'File', 'Transaction', 'User'] from .accounts import Account from .base import app +from .billers import Biller from .cards import Card from .files import File from .transactions import Transaction +from .users import User diff --git a/examples/chalicelib/resources/accounts.py b/examples/chalicelib/resources/accounts.py index 627d4d8f..cdbb11c1 100644 --- a/examples/chalicelib/resources/accounts.py +++ b/examples/chalicelib/resources/accounts.py @@ -1,7 +1,6 @@ import datetime as dt -from chalice import NotFoundError, Response -from mongoengine import DoesNotExist +from chalice import Response from agave.filters import generic_query @@ -23,6 +22,7 @@ def create(request: AccountRequest) -> Response: account = AccountModel( name=request.name, user_id=app.current_user_id, + platform_id=app.current_platform_id, ) account.save() return Response(account.to_dict(), status_code=201) diff --git a/examples/chalicelib/resources/billers.py b/examples/chalicelib/resources/billers.py new file mode 100644 index 00000000..03713fe2 --- /dev/null +++ b/examples/chalicelib/resources/billers.py @@ -0,0 +1,12 @@ +from agave.filters import generic_query + +from ..models import Biller as BillerModel +from ..validators import BillerQuery +from .base import app + + +@app.resource('/billers') +class Biller: + model = BillerModel + query_validator = BillerQuery + get_query_filter = generic_query diff --git a/examples/chalicelib/resources/users.py b/examples/chalicelib/resources/users.py new file mode 100644 index 00000000..1c32e80b --- /dev/null +++ b/examples/chalicelib/resources/users.py @@ -0,0 +1,12 @@ +from agave.filters import generic_query + +from ..models import User as UserModel +from ..validators import UserQuery +from .base import app + + +@app.resource('/users') +class User: + model = UserModel + query_validator = UserQuery + get_query_filter = generic_query diff --git a/examples/chalicelib/validators.py b/examples/chalicelib/validators.py index 80055022..ab444d31 100644 --- a/examples/chalicelib/validators.py +++ b/examples/chalicelib/validators.py @@ -7,6 +7,7 @@ class AccountQuery(QueryParams): name: Optional[str] = None user_id: Optional[str] = None + platform_id: Optional[str] = None active: Optional[bool] = None @@ -14,6 +15,14 @@ class TransactionQuery(QueryParams): user_id: Optional[str] = None +class BillerQuery(QueryParams): + name: str + + +class UserQuery(QueryParams): + platform_id: str + + class AccountRequest(BaseModel): name: str diff --git a/examples/config.py b/examples/config.py new file mode 100644 index 00000000..84c674d8 --- /dev/null +++ b/examples/config.py @@ -0,0 +1,4 @@ +TEST_DEFAULT_USER_ID = 'US123456789' +TEST_DEFAULT_PLATFORM_ID = 'PT123456' +TEST_SECOND_USER_ID = 'US987654321' +TEST_SECOND_PLATFORM_ID = 'PT987654321' diff --git a/tests/blueprint/test_blueprint.py b/tests/blueprint/test_blueprint.py index e08f547c..9d4d39bc 100644 --- a/tests/blueprint/test_blueprint.py +++ b/tests/blueprint/test_blueprint.py @@ -1,4 +1,5 @@ import datetime as dt +from typing import List from urllib.parse import urlencode import pytest @@ -6,6 +7,16 @@ from mock import MagicMock, patch from examples.chalicelib.models import Account, Card, File +from examples.config import ( + TEST_DEFAULT_PLATFORM_ID, + TEST_DEFAULT_USER_ID, + TEST_SECOND_PLATFORM_ID, +) + +PLATFORM_ID_FILTER_REQUIRED = ( + 'examples.chalicelib.blueprints.authed.' + 'AuthedBlueprint.platform_id_filter_required' +) USER_ID_FILTER_REQUIRED = ( 'examples.chalicelib.blueprints.authed.' @@ -34,6 +45,14 @@ def test_retrieve_resource(client: Client, account: Account) -> None: assert resp.json_body == account.to_dict() +@patch(PLATFORM_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +def test_retrieve_resource_platform_id_filter_required( + client: Client, other_account: Account +) -> None: + resp = client.http.get(f'/accounts/{other_account.id}') + assert resp.status_code == 404 + + @patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) def test_retrieve_resource_user_id_filter_required( client: Client, other_account: Account @@ -42,6 +61,15 @@ def test_retrieve_resource_user_id_filter_required( assert resp.status_code == 404 +@patch(PLATFORM_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +@patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +def test_retrieve_resource_user_id_and_platform_id_filter_required( + client: Client, other_account: Account +) -> None: + resp = client.http.get(f'/accounts/{other_account.id}') + assert resp.status_code == 404 + + def test_retrieve_resource_not_found(client: Client) -> None: resp = client.http.get('/accounts/unknown_id') assert resp.status_code == 404 @@ -115,24 +143,31 @@ def test_query_all_with_limit(client: Client) -> None: @pytest.mark.usefixtures('accounts') -def test_query_all_resource(client: Client) -> None: - query_params = dict(page_size=2) - resp = client.http.get(f'/accounts?{urlencode(query_params)}') - assert resp.status_code == 200 - assert len(resp.json_body['items']) == 2 +def test_query_all_resource(client: Client, accounts: List[Account]) -> None: + accounts = list(reversed(accounts)) - resp = client.http.get(resp.json_body['next_page_uri']) - assert resp.status_code == 200 - assert len(resp.json_body['items']) == 2 + items = [] + page_uri = f'/accounts?{urlencode(dict(page_size=2))}' + + while page_uri: + resp = client.http.get(page_uri) + assert resp.status_code == 200 + items.extend(resp.json_body['items']) + page_uri = resp.json_body['next_page_uri'] + assert len(items) == len(accounts) + assert all(a.to_dict() == b for a, b in zip(accounts, items)) -def test_query_all_filter_active(client: Client, account: Account) -> None: + +def test_query_all_filter_active( + client: Client, account: Account, accounts: List[Account] +) -> None: query_params = dict(active=True) # Query active items resp = client.http.get(f'/accounts?{urlencode(query_params)}') assert resp.status_code == 200 items = resp.json_body['items'] - assert len(items) == 4 + assert len(items) == len(accounts) assert all(item['deactivated_at'] is None for item in items) # Deactivate Item @@ -141,7 +176,7 @@ def test_query_all_filter_active(client: Client, account: Account) -> None: resp = client.http.get(f'/accounts?{urlencode(query_params)}') assert resp.status_code == 200 items = resp.json_body['items'] - assert len(items) == 3 + assert len(items) == len(accounts) - 1 # Query deactivated items query_params = dict(active=False) @@ -152,23 +187,62 @@ def test_query_all_filter_active(client: Client, account: Account) -> None: assert items[0]['deactivated_at'] is not None -@pytest.mark.usefixtures('accounts') -@patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) -def test_query_user_id_filter_required(client: Client) -> None: - query_params = dict(page_size=2) +def test_query_all_created_after( + client: Client, accounts: List[Account] +) -> None: + created_at = dt.datetime(2020, 2, 1) + expected_length = len([a for a in accounts if a.created_at > created_at]) + + query_params = dict(created_after=created_at.isoformat()) resp = client.http.get(f'/accounts?{urlencode(query_params)}') + assert resp.status_code == 200 - assert len(resp.json_body['items']) == 2 - assert all( - item['user_id'] == 'US123456789' for item in resp.json_body['items'] + assert len(resp.json_body['items']) == expected_length + + +@patch(PLATFORM_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +def test_query_platform_id_filter_required( + client: Client, accounts: List[Account] +) -> None: + accounts = list( + reversed( + [a for a in accounts if a.platform_id == TEST_DEFAULT_PLATFORM_ID] + ) ) - resp = client.http.get(resp.json_body['next_page_uri']) - assert resp.status_code == 200 - assert len(resp.json_body['items']) == 1 - assert all( - item['user_id'] == 'US123456789' for item in resp.json_body['items'] + items = [] + page_uri = f'/accounts?{urlencode(dict(page_size=2))}' + + while page_uri: + resp = client.http.get(page_uri) + assert resp.status_code == 200 + json_body = resp.json_body + items.extend(json_body['items']) + page_uri = json_body['next_page_uri'] + + assert len(items) == len(accounts) + assert all(a.to_dict() == b for a, b in zip(accounts, items)) + + +@patch(USER_ID_FILTER_REQUIRED, MagicMock(return_value=True)) +def test_query_user_id_filter_required( + client: Client, accounts: List[Account] +) -> None: + accounts = list( + reversed([a for a in accounts if a.user_id == TEST_DEFAULT_USER_ID]) ) + items = [] + page_uri = f'/accounts?{urlencode(dict(page_size=2))}' + + while page_uri: + resp = client.http.get(page_uri) + assert resp.status_code == 200 + json_body = resp.json_body + items.extend(json_body['items']) + page_uri = json_body['next_page_uri'] + + assert len(items) == len(accounts) + assert all(a.to_dict() == b for a, b in zip(accounts, items)) def test_query_resource_with_invalid_params(client: Client) -> None: @@ -211,3 +285,28 @@ def test_download_resource(client: Client, file: File) -> None: resp = client.http.get(f'/files/{file.id}', headers={'Accept': mimetype}) assert resp.status_code == 200 assert resp.headers.get('Content-Type') == mimetype + + +@pytest.mark.usefixtures('users') +def test_filter_no_user_id_query(client: Client) -> None: + resp = client.http.get(f'/users?platform_id={TEST_DEFAULT_PLATFORM_ID}') + resp_json = resp.json_body + assert resp.status_code == 200 + assert len(resp_json['items']) == 1 + user1 = resp_json['items'][0] + resp = client.http.get(f'/users?platform_id={TEST_SECOND_PLATFORM_ID}') + resp_json = resp.json_body + assert resp.status_code == 200 + assert len(resp_json['items']) == 1 + user2 = resp_json['items'][0] + assert user1['id'] != user2['id'] + + +@pytest.mark.usefixtures('billers') +def test_filter_no_user_id_and_no_platform_id_query( + client: Client, +) -> None: + resp = client.http.get('/billers?name=ATT') + resp_json = resp.json_body + assert resp.status_code == 200 + assert len(resp_json['items']) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 8ae292ed..927b75f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,38 @@ import datetime as dt -from typing import Generator, List +import functools +from typing import Callable, Generator, List import pytest from chalice.test import Client +from mongoengine import Document -from examples.chalicelib.models import Account, Card, File +from examples.chalicelib.models import Account, Biller, Card, File, User +from examples.config import ( + TEST_DEFAULT_PLATFORM_ID, + TEST_DEFAULT_USER_ID, + TEST_SECOND_PLATFORM_ID, + TEST_SECOND_USER_ID, +) from .helpers import accept_json +FuncDecorator = Callable[..., Generator] + + +def collection_fixture(model: Document) -> Callable[..., FuncDecorator]: + def collection_decorator(func: Callable) -> FuncDecorator: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Generator[List, None, None]: + items = func(*args, **kwargs) + for item in items: + item.save() + yield items + model.objects.delete() + + return wrapper + + return collection_decorator + @pytest.fixture() def client() -> Generator[Client, None, None]: @@ -28,37 +53,47 @@ def client() -> Generator[Client, None, None]: @pytest.fixture -def accounts() -> Generator[List[Account], None, None]: - user_id = 'US123456789' - accs = [ +@collection_fixture(Account) +def accounts() -> List[Account]: + return [ Account( name='Frida Kahlo', - user_id=user_id, - created_at=dt.datetime(2020, 1, 1), + user_id=TEST_DEFAULT_USER_ID, + platform_id=TEST_DEFAULT_PLATFORM_ID, + created_at=dt.datetime(2020, 1, 1, 0), ), Account( name='Sor Juana Inés', - user_id=user_id, - created_at=dt.datetime(2020, 2, 1), + user_id=TEST_DEFAULT_USER_ID, + platform_id=TEST_DEFAULT_PLATFORM_ID, + created_at=dt.datetime(2020, 2, 1, 0), + ), + Account( + name='Eulalia Guzmán', + user_id='US222222', + platform_id=TEST_DEFAULT_PLATFORM_ID, + created_at=dt.datetime(2020, 2, 1, 1), + ), + Account( + name='Matilde Montoya', + user_id='US222222', + platform_id=TEST_DEFAULT_PLATFORM_ID, + created_at=dt.datetime(2020, 2, 1, 2), ), Account( name='Leona Vicario', - user_id=user_id, - created_at=dt.datetime(2020, 3, 1), + user_id=TEST_DEFAULT_USER_ID, + platform_id=TEST_DEFAULT_PLATFORM_ID, + created_at=dt.datetime(2020, 3, 1, 0), ), Account( name='Remedios Varo', - user_id='US987654321', - created_at=dt.datetime(2020, 4, 1), + user_id=TEST_SECOND_USER_ID, + platform_id=TEST_SECOND_PLATFORM_ID, + created_at=dt.datetime(2020, 4, 1, 0), ), ] - for acc in accs: - acc.save() - yield accs - for acc in accs: - acc.delete() - @pytest.fixture def account(accounts: List[Account]) -> Generator[Account, None, None]: @@ -71,21 +106,15 @@ def other_account(accounts: List[Account]) -> Generator[Account, None, None]: @pytest.fixture -def files() -> Generator[List[File], None, None]: - user_id = 'US123456789' - accs = [ +@collection_fixture(File) +def files() -> List[File]: + return [ File( name='Frida Kahlo', - user_id=user_id, + user_id=TEST_DEFAULT_USER_ID, ), ] - for acc in accs: - acc.save() - yield accs - for acc in accs: - acc.delete() - @pytest.fixture def file(files: List[File]) -> Generator[File, None, None]: @@ -93,38 +122,50 @@ def file(files: List[File]) -> Generator[File, None, None]: @pytest.fixture -def cards() -> Generator[List[Card], None, None]: - user_id = 'US123456789' - cards = [ +@collection_fixture(Card) +def cards() -> List[Card]: + return [ Card( number='5434000000000001', - user_id=user_id, + user_id=TEST_DEFAULT_USER_ID, created_at=dt.datetime(2020, 1, 1), ), Card( number='5434000000000002', - user_id=user_id, + user_id=TEST_DEFAULT_USER_ID, created_at=dt.datetime(2020, 2, 1), ), Card( number='5434000000000003', - user_id=user_id, + user_id=TEST_DEFAULT_USER_ID, created_at=dt.datetime(2020, 3, 1), ), Card( number='5434000000000004', - user_id='US987654321', + user_id=TEST_SECOND_USER_ID, created_at=dt.datetime(2020, 4, 1), ), ] - for card in cards: - card.save() - yield cards - for card in cards: - card.delete() - @pytest.fixture def card(cards: List[Card]) -> Generator[Card, None, None]: yield cards[0] + + +@pytest.fixture +@collection_fixture(User) +def users() -> List[User]: + return [ + User(name='User1', platform_id=TEST_DEFAULT_PLATFORM_ID), + User(name='User2', platform_id=TEST_SECOND_PLATFORM_ID), + ] + + +@pytest.fixture +@collection_fixture(Biller) +def billers() -> List[Biller]: + return [ + Biller(name='Telcel'), + Biller(name='ATT'), + ]