diff --git a/README.md b/README.md index dc1e039..ad8b6d5 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,32 @@ class RealClient(RequestsClient): @post("todos") def create_todo(self, body: Todo) -> Todo: - """Создаем Todo""" + pass +``` + +You can use Callable ```(...) -> str``` as the url source, +all parameters passed to the client method can be obtained inside the Callable + +```python +from requests import Session +from dataclass_rest import get +from dataclass_rest.http.requests import RequestsClient + +def url_generator(todo_id: int) -> str: + return f"/todos/{todo_id}/" + + +class RealClient(RequestsClient): + def __init__(self): + super().__init__("https://dummyjson.com/", Session()) + + @get(url_generator) + def todo(self, todo_id: int) -> Todo: + pass + + +client = RealClient() +client.todo(5) ``` ## Asyncio diff --git a/src/dataclass_rest/__init__.py b/src/dataclass_rest/__init__.py index 064320e..e472fa8 100644 --- a/src/dataclass_rest/__init__.py +++ b/src/dataclass_rest/__init__.py @@ -1,7 +1,11 @@ __all__ = [ "File", "rest", - "get", "put", "post", "patch", "delete", + "get", + "put", + "post", + "patch", + "delete", ] from .http_request import File diff --git a/src/dataclass_rest/boundmethod.py b/src/dataclass_rest/boundmethod.py index 866f670..c57a222 100644 --- a/src/dataclass_rest/boundmethod.py +++ b/src/dataclass_rest/boundmethod.py @@ -13,11 +13,11 @@ class BoundMethod(ClientMethodProtocol, ABC): def __init__( - self, - name: str, - method_spec: MethodSpec, - client: ClientProtocol, - on_error: Optional[Callable[[Any], Any]], + self, + name: str, + method_spec: MethodSpec, + client: ClientProtocol, + on_error: Optional[Callable[[Any], Any]], ): self.name = name self.method_spec = method_spec @@ -26,21 +26,31 @@ def __init__( def _apply_args(self, *args, **kwargs) -> Dict: return getcallargs( - self.method_spec.func, self.client, *args, **kwargs, + self.method_spec.func, + self.client, + *args, + **kwargs, ) def _get_url(self, args) -> str: - return self.method_spec.url_template.format(**args) + args = { + arg: value + for arg, value in args.items() + if arg in self.method_spec.url_params + } + return self.method_spec.url_template(**args) def _get_body(self, args) -> Any: python_body = args.get(self.method_spec.body_param_name) return self.client.request_body_factory.dump( - python_body, self.method_spec.body_type, + python_body, + self.method_spec.body_type, ) def _get_query_params(self, args) -> Any: return self.client.request_args_factory.dump( - args, self.method_spec.query_params_type, + args, + self.method_spec.query_params_type, ) def _get_files(self, args) -> Dict[str, File]: @@ -51,11 +61,11 @@ def _get_files(self, args) -> Dict[str, File]: } def _create_request( - self, - url: str, - query_params: Any, - files: Dict[str, File], - data: Any, + self, + url: str, + query_params: Any, + files: Dict[str, File], + data: Any, ) -> HttpRequest: return HttpRequest( method=self.method_spec.http_method, diff --git a/src/dataclass_rest/client_protocol.py b/src/dataclass_rest/client_protocol.py index 1ed30d5..454b28f 100644 --- a/src/dataclass_rest/client_protocol.py +++ b/src/dataclass_rest/client_protocol.py @@ -25,7 +25,9 @@ def load(self, data: Any, class_: Type[TypeT]) -> TypeT: raise NotImplementedError def dump( - self, data: TypeT, class_: Optional[Type[TypeT]] = None, + self, + data: TypeT, + class_: Optional[Type[TypeT]] = None, ) -> Any: raise NotImplementedError @@ -37,6 +39,7 @@ class ClientProtocol(Protocol): method_class: Optional[Callable] def do_request( - self, request: HttpRequest, + self, + request: HttpRequest, ) -> Any: raise NotImplementedError diff --git a/src/dataclass_rest/http/aiohttp.py b/src/dataclass_rest/http/aiohttp.py index e8e400f..a8e2d60 100644 --- a/src/dataclass_rest/http/aiohttp.py +++ b/src/dataclass_rest/http/aiohttp.py @@ -48,9 +48,9 @@ class AiohttpClient(BaseClient): method_class = AiohttpMethod def __init__( - self, - base_url: str, - session: Optional[ClientSession] = None, + self, + base_url: str, + session: Optional[ClientSession] = None, ): super().__init__() self.session = session or ClientSession() @@ -68,7 +68,8 @@ async def do_request(self, request: HttpRequest) -> Any: for name, file in request.files.items(): data.add_field( name, - filename=file.filename, content_type=file.content_type, + filename=file.filename, + content_type=file.content_type, value=file.contents, ) try: diff --git a/src/dataclass_rest/http/requests.py b/src/dataclass_rest/http/requests.py index 6b17484..f8c58a8 100644 --- a/src/dataclass_rest/http/requests.py +++ b/src/dataclass_rest/http/requests.py @@ -16,7 +16,6 @@ class RequestsMethod(SyncMethod): - def _on_error_default(self, response: Response) -> Any: if 400 <= response.status_code < 500: raise ClientError(response.status_code) @@ -39,9 +38,9 @@ class RequestsClient(BaseClient): method_class = RequestsMethod def __init__( - self, - base_url: str, - session: Optional[Session] = None, + self, + base_url: str, + session: Optional[Session] = None, ): super().__init__() self.session = session or Session() diff --git a/src/dataclass_rest/method.py b/src/dataclass_rest/method.py index 23c11fa..b924995 100644 --- a/src/dataclass_rest/method.py +++ b/src/dataclass_rest/method.py @@ -7,9 +7,9 @@ class Method: def __init__( - self, - method_spec: MethodSpec, - method_class: Optional[Callable[..., BoundMethod]] = None, + self, + method_spec: MethodSpec, + method_class: Optional[Callable[..., BoundMethod]] = None, ): self.name = method_spec.func.__name__ self.method_spec = method_spec @@ -29,7 +29,9 @@ def __set_name__(self, owner, name): ) def __get__( - self, instance: Optional[ClientProtocol], objtype=None, + self, + instance: Optional[ClientProtocol], + objtype=None, ) -> BoundMethod: return self.method_class( name=self.name, diff --git a/src/dataclass_rest/methodspec.py b/src/dataclass_rest/methodspec.py index 0cfbd04..e8032c4 100644 --- a/src/dataclass_rest/methodspec.py +++ b/src/dataclass_rest/methodspec.py @@ -3,21 +3,22 @@ class MethodSpec: def __init__( - self, - func: Callable, - *, - url_template: str, - http_method: str, - response_type: Type, - body_param_name: str, - body_type: Type, - is_json_request: bool, - query_params_type: Type, - file_param_names: List[str], - additional_params: Dict[str, Any], + self, + func: Callable, + url_template: Callable[..., str], + url_params: List[str], + http_method: str, + response_type: Type, + body_param_name: str, + body_type: Type, + is_json_request: bool, # noqa: FBT001 + query_params_type: Type, + file_param_names: List[str], + additional_params: Dict[str, Any], ): self.func = func self.url_template = url_template + self.url_params = url_params self.http_method = http_method self.response_type = response_type self.body_param_name = body_param_name diff --git a/src/dataclass_rest/parse_func.py b/src/dataclass_rest/parse_func.py index 0bd54c6..acbffab 100644 --- a/src/dataclass_rest/parse_func.py +++ b/src/dataclass_rest/parse_func.py @@ -1,22 +1,33 @@ import string from inspect import FullArgSpec, getfullargspec, isclass -from typing import Any, Callable, Dict, List, Sequence, Type, TypedDict +from typing import ( + Any, + Callable, + Dict, + List, + Sequence, + Type, + TypeAlias, + TypedDict, + Union, +) from .http_request import File from .methodspec import MethodSpec DEFAULT_BODY_PARAM = "body" +UrlTemplate: TypeAlias = Union[str, Callable[..., str]] -def get_url_params(url_template: str) -> List[str]: +def get_url_params_from_string(url_template: str) -> List[str]: parsed_format = string.Formatter().parse(url_template) - return [x[1] for x in parsed_format] + return [x[1] for x in parsed_format if x[1]] def create_query_params_type( - spec: FullArgSpec, - func: Callable, - skipped: Sequence[str], + spec: FullArgSpec, + func: Callable, + skipped: Sequence[str], ) -> Type: fields = {} self_processed = False @@ -31,14 +42,14 @@ def create_query_params_type( def create_body_type( - spec: FullArgSpec, - body_param_name: str, + spec: FullArgSpec, + body_param_name: str, ) -> Type: return spec.annotations.get(body_param_name, Any) def create_response_type( - spec: FullArgSpec, + spec: FullArgSpec, ) -> Type: return spec.annotations.get("return", Any) @@ -51,23 +62,42 @@ def get_file_params(spec): ] +def get_url_params_from_callable( + url_template: Callable[..., str], +) -> List[str]: + url_template_func_arg_spec = getfullargspec(url_template) + return url_template_func_arg_spec.args + + def parse_func( - func: Callable, - *, - method: str, - url_template: str, - additional_params: Dict[str, Any], - is_json_request: bool, - body_param_name: str, + func: Callable, + method: str, + url_template: UrlTemplate, + additional_params: Dict[str, Any], + is_json_request: bool, # noqa: FBT001 + body_param_name: str, ) -> MethodSpec: spec = getfullargspec(func) - url_params = get_url_params(url_template) file_params = get_file_params(spec) + + is_string_url_template = isinstance(url_template, str) + url_template_callable = ( + url_template.format if is_string_url_template else url_template + ) + + url_params = ( + get_url_params_from_string(url_template) + if is_string_url_template + else get_url_params_from_callable(url_template) + ) + skipped_params = url_params + file_params + [body_param_name] + return MethodSpec( func=func, http_method=method, - url_template=url_template, + url_template=url_template_callable, + url_params=url_params, query_params_type=create_query_params_type(spec, func, skipped_params), body_type=create_body_type(spec, body_param_name), response_type=create_response_type(spec), diff --git a/src/dataclass_rest/rest.py b/src/dataclass_rest/rest.py index 21e03b0..61e29d1 100644 --- a/src/dataclass_rest/rest.py +++ b/src/dataclass_rest/rest.py @@ -2,13 +2,13 @@ from .boundmethod import BoundMethod from .method import Method -from .parse_func import DEFAULT_BODY_PARAM, parse_func +from .parse_func import DEFAULT_BODY_PARAM, UrlTemplate, parse_func _Func = TypeVar("_Func", bound=Callable[..., Any]) def rest( - url_template: str, + url_template: UrlTemplate, *, method: str, body_name: str = DEFAULT_BODY_PARAM, diff --git a/tests/requests/conftest.py b/tests/requests/conftest.py index 849f945..2950ca8 100644 --- a/tests/requests/conftest.py +++ b/tests/requests/conftest.py @@ -12,6 +12,7 @@ def session(): @pytest.fixture def mocker(session): with requests_mock.Mocker( - session=session, case_sensitive=True, + session=session, + case_sensitive=True, ) as session_mock: yield session_mock diff --git a/tests/requests/test_callable_url.py b/tests/requests/test_callable_url.py new file mode 100644 index 0000000..1fa4e54 --- /dev/null +++ b/tests/requests/test_callable_url.py @@ -0,0 +1,94 @@ +from typing import List, Optional + +import pytest +import requests +import requests_mock + +from dataclass_rest import get +from dataclass_rest.http.requests import RequestsClient + + +def static_url() -> str: + return "/get" + + +def param_url(entry_id: int) -> str: + return f"/get/{entry_id}" + + +def kwonly_param_url(entry_id: Optional[int] = None) -> str: + if entry_id: + return f"/get/{entry_id}" + return "/get/random" + + +def test_simple(session: requests.Session, mocker: requests_mock.Mocker): + class Api(RequestsClient): + @get(static_url) + def get_x(self) -> List[int]: + raise NotImplementedError + + mocker.get("http://example.com/get", text="[1,2]", complete_qs=True) + client = Api(base_url="http://example.com", session=session) + assert client.get_x() == [1, 2] + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + ( + 1, + 1, + ), + ( + 2, + 2, + ), + ], +) +def test_with_param( + session: requests.Session, + mocker: requests_mock.Mocker, + value: int, + expected: int, +): + class Api(RequestsClient): + @get(param_url) + def get_entry(self, entry_id: int) -> int: + raise NotImplementedError + + url = f"http://example.com/get/{expected}" + mocker.get(url, text=str(expected), complete_qs=True) + + client = Api(base_url="http://example.com", session=session) + assert client.get_entry(value) == expected + + +def test_excess_param(session: requests.Session, mocker: requests_mock.Mocker): + class Api(RequestsClient): + @get(param_url) + def get_entry( + self, entry_id: int, some_param: Optional[int] = None, + ) -> int: + raise NotImplementedError + + mocker.get( + "http://example.com/get/1?some_param=2", text="1", complete_qs=True, + ) + + client = Api(base_url="http://example.com", session=session) + assert client.get_entry(1, 2) == 1 + + +def test_kwonly_param(session: requests.Session, mocker: requests_mock.Mocker): + class Api(RequestsClient): + @get(kwonly_param_url) + def get_entry(self, *, entry_id: Optional[int] = None) -> int: + raise NotImplementedError + + mocker.get("http://example.com/get/1", text="1", complete_qs=True) + mocker.get("http://example.com/get/random", text="2", complete_qs=True) + + client = Api(base_url="http://example.com", session=session) + assert client.get_entry(entry_id=1) == 1 + assert client.get_entry() == 2 diff --git a/tests/requests/test_factory.py b/tests/requests/test_factory.py index d6a5ebc..72ae49f 100644 --- a/tests/requests/test_factory.py +++ b/tests/requests/test_factory.py @@ -27,19 +27,25 @@ class ResponseBody: def test_body(session, mocker): class Api(RequestsClient): def _init_request_body_factory(self) -> Retort: - return Retort(recipe=[ - name_mapping(name_style=NameStyle.CAMEL), - ]) + return Retort( + recipe=[ + name_mapping(name_style=NameStyle.CAMEL), + ], + ) def _init_request_args_factory(self) -> Retort: - return Retort(recipe=[ - name_mapping(name_style=NameStyle.UPPER_DOT), - ]) + return Retort( + recipe=[ + name_mapping(name_style=NameStyle.UPPER_DOT), + ], + ) def _init_response_body_factory(self) -> Retort: - return Retort(recipe=[ - name_mapping(name_style=NameStyle.LOWER_KEBAB), - ]) + return Retort( + recipe=[ + name_mapping(name_style=NameStyle.LOWER_KEBAB), + ], + ) @patch("/post/") def post_x(self, long_param: str, body: RequestBody) -> ResponseBody: diff --git a/tests/requests/test_params.py b/tests/requests/test_params.py index bd8405b..d91fc56 100644 --- a/tests/requests/test_params.py +++ b/tests/requests/test_params.py @@ -46,15 +46,18 @@ def post_x(self, id: str, param: Optional[int]) -> List[int]: mocker.post( url="http://example.com/post/x?", - text="[0]", complete_qs=True, + text="[0]", + complete_qs=True, ) mocker.post( url="http://example.com/post/x?param=1", - text="[1]", complete_qs=True, + text="[1]", + complete_qs=True, ) mocker.post( url="http://example.com/post/x?param=2", - text="[1,2]", complete_qs=True, + text="[1,2]", + complete_qs=True, ) client = Api(base_url="http://example.com", session=session) assert client.post_x("x", None) == [0] @@ -76,7 +79,8 @@ def post_x(self, body: RequestBody) -> None: mocker.post( url="http://example.com/post/", - text="null", complete_qs=True, + text="null", + complete_qs=True, ) client = Api(base_url="http://example.com", session=session) assert client.post_x(RequestBody(x=1, y="test")) is None diff --git a/tests/test_init.py b/tests/test_init.py index e2226e8..431302f 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -23,9 +23,11 @@ def __init__(self): ) def _init_request_body_factory(self) -> Retort: - return Retort(recipe=[ - name_mapping(name_style=NameStyle.CAMEL), - ]) + return Retort( + recipe=[ + name_mapping(name_style=NameStyle.CAMEL), + ], + ) @get("todos/{id}") def get_todo(self, id: str) -> Todo: @@ -41,9 +43,11 @@ def __init__(self): super().__init__("https://jsonplaceholder.typicode.com/") def _init_request_body_factory(self) -> Retort: - return Retort(recipe=[ - name_mapping(name_style=NameStyle.CAMEL), - ]) + return Retort( + recipe=[ + name_mapping(name_style=NameStyle.CAMEL), + ], + ) @get("todos/{id}") async def get_todo(self, id: str) -> Todo: