diff --git a/src/dataclass_rest/parse_func.py b/src/dataclass_rest/parse_func.py index 43a99ee..b5eadcd 100644 --- a/src/dataclass_rest/parse_func.py +++ b/src/dataclass_rest/parse_func.py @@ -10,6 +10,7 @@ TypeAlias, TypedDict, Union, + get_type_hints, ) from .http_request import File @@ -25,9 +26,9 @@ def get_url_params_from_string(url_template: str) -> List[str]: def create_query_params_type( - spec: FullArgSpec, - func: Callable, - skipped: Sequence[str], + spec: FullArgSpec, + func: Callable, + skipped: Sequence[str], ) -> Type: fields = {} self_processed = False @@ -37,48 +38,50 @@ def create_query_params_type( continue if x in skipped: continue - fields[x] = spec.annotations.get(x, Any) + fields[x] = get_type_hints(func).get(x, Any) return TypedDict(f"{func.__name__}_Params", fields) def create_body_type( - spec: FullArgSpec, - body_param_name: str, + spec: FullArgSpec, + func: Callable, + body_param_name: str, ) -> Type: - return spec.annotations.get(body_param_name, Any) + return get_type_hints(func).get(body_param_name, Any) def create_response_type( - spec: FullArgSpec, + func: Callable, ) -> Type: - return spec.annotations.get("return", Any) + return get_type_hints(func).get("return", Any) -def get_file_params(spec): +def get_file_params(func: Callable) -> List[str]: + type_hints = get_type_hints(func) return [ field - for field, field_type in spec.annotations.items() + for field, field_type in type_hints.items() if isclass(field_type) and issubclass(field_type, File) ] def get_url_params_from_callable( - url_template: Callable[..., str], + url_template: Callable[..., str], ) -> List[str]: url_template_func_arg_spec = getfullargspec(url_template) return url_template_func_arg_spec.args def parse_func( - func: Callable, - method: str, - url_template: UrlTemplate, - additional_params: Dict[str, Any], - is_json_request: bool, # noqa: FBT001 - body_param_name: str, + func: Callable, + method: str, + url_template: UrlTemplate, + additional_params: Dict[str, Any], + is_json_request: bool, # noqa: FBT001 + body_param_name: str, ) -> MethodSpec: spec = getfullargspec(func) - file_params = get_file_params(spec) + file_params = get_file_params(func) is_string_url_template = isinstance(url_template, str) url_template_callable = ( @@ -99,8 +102,8 @@ def parse_func( url_template=url_template_callable, url_params=url_params, query_params_type=create_query_params_type(spec, func, skipped_params), - body_type=create_body_type(spec, body_param_name), - response_type=create_response_type(spec), + body_type=create_body_type(spec, func, body_param_name), + response_type=create_response_type(func), body_param_name=body_param_name, additional_params=additional_params, is_json_request=is_json_request, diff --git a/tests/requests/test_parse_func.py b/tests/requests/test_parse_func.py new file mode 100644 index 0000000..ae16373 --- /dev/null +++ b/tests/requests/test_parse_func.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import List, Optional + +import requests +import requests_mock + +from dataclass_rest import get, post +from dataclass_rest.http.requests import RequestsClient + + +@dataclass +class TestBody: + value: int + + +def test_string_hints(session: requests.Session, mocker: requests_mock.Mocker): + class Api(RequestsClient): + @get("/items/{item_id}") + def get_item(self, item_id: "str") -> "List[int]": + raise NotImplementedError + + @post("/items") + def create_item(self, body: "TestBody") -> "Optional[int]": + raise NotImplementedError + + mocker.get( + "http://example.com/items/1", + text="[1, 2, 3]", + complete_qs=True, + ) + mocker.post("http://example.com/items", text="1", complete_qs=True) + + client = Api(base_url="http://example.com", session=session) + + assert client.get_item("1") == [1, 2, 3] + assert client.create_item(TestBody(value=5)) == 1