Skip to content

Commit

Permalink
Merge pull request #123 from mirumee/fix_decode_and_parse_application
Browse files Browse the repository at this point in the history
Change BaseModel to apply parse and serialize methods on every list element
  • Loading branch information
mat-sop authored Apr 5, 2023
2 parents 65b3410 + 292a2d4 commit 44f9740
Show file tree
Hide file tree
Showing 11 changed files with 503 additions and 81 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Unlocked versions of black, isort, autoflake and dev dependencies
- Added `remote_schema_verify_ssl` option.
- Changed how default values for inputs are generated to handle potential cycles.
- Fixed `BaseModel` incorrectly calling `parse` and `serialize` methods on entire list instead of its items for `List[Scalar]`.


## 0.4.0 (2023-03-20)
Expand Down
38 changes: 29 additions & 9 deletions ariadne_codegen/client_generators/dependencies/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Type, Union, get_args, get_origin

from pydantic import BaseModel as PydanticBaseModel
from pydantic.class_validators import validator
Expand All @@ -15,16 +15,36 @@ class Config:

# pylint: disable=no-self-argument
@validator("*", pre=True)
def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any:
decode = SCALARS_PARSE_FUNCTIONS.get(field.type_)
if decode and callable(decode):
def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any:
return cls._parse_custom_scalar_value(value, field.annotation)

@classmethod
def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any:
origin = get_origin(type_)
args = get_args(type_)
if origin is list and isinstance(value, list):
return [cls._parse_custom_scalar_value(item, args[0]) for item in value]

if origin is Union and type(None) in args:
sub_type: Any = list(filter(None, args))[0]
return cls._parse_custom_scalar_value(value, sub_type)

decode = SCALARS_PARSE_FUNCTIONS.get(type_)
if value and decode and callable(decode):
return decode(value)

return value

def dict(self, **kwargs: Any) -> Dict[str, Any]:
dict_ = super().dict(**kwargs)
for key, value in dict_.items():
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
if serialize and callable(serialize):
dict_[key] = serialize(value)
return dict_
return {key: self._serialize_value(value) for key, value in dict_.items()}

def _serialize_value(self, value: Any) -> Any:
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
if serialize and callable(serialize):
return serialize(value)

if isinstance(value, list):
return [self._serialize_value(item) for item in value]

return value
241 changes: 241 additions & 0 deletions tests/client_generators/dependencies/test_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from typing import List, Optional

import pytest

from ariadne_codegen.client_generators.dependencies.base_model import BaseModel


@pytest.mark.parametrize(
"annotation, value, expected_args",
[
(str, "a", {"a"}),
(Optional[str], "a", {"a"}),
(Optional[str], None, set()),
(List[str], ["a", "b"], {"a", "b"}),
(List[Optional[str]], ["a", None], {"a"}),
(Optional[List[str]], ["a", "b"], {"a", "b"}),
(Optional[List[str]], None, set()),
(Optional[List[Optional[str]]], ["a", None], {"a"}),
(Optional[List[Optional[str]]], None, set()),
(List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
(Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
(Optional[List[List[str]]], None, set()),
(
Optional[List[Optional[List[str]]]],
[["a", "b"], ["c", "d"]],
{"a", "b", "c", "d"},
),
(Optional[List[Optional[List[str]]]], None, set()),
(Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}),
(
Optional[List[Optional[List[Optional[str]]]]],
[["a", "b"], ["c", "d"]],
{"a", "b", "c", "d"},
),
(Optional[List[Optional[List[Optional[str]]]]], None, set()),
(Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}),
(
Optional[List[Optional[List[Optional[str]]]]],
[["a", None], ["b", None]],
{"a", "b"},
),
],
)
def test_parse_obj_applies_parse_on_every_list_element(
annotation, value, expected_args, mocker
):
mocked_parse = mocker.MagicMock(side_effect=lambda s: s)
mocker.patch(
"ariadne_codegen.client_generators.dependencies.base_model."
"SCALARS_PARSE_FUNCTIONS",
{str: mocked_parse},
)

class TestModel(BaseModel):
field: annotation

TestModel.parse_obj({"field": value})

assert mocked_parse.call_count == len(expected_args)
assert {c.args[0] for c in mocked_parse.call_args_list} == expected_args


def test_parse_obj_doesnt_apply_parse_on_not_matching_type(mocker):
mocked_parse = mocker.MagicMock(side_effect=lambda s: s)
mocker.patch(
"ariadne_codegen.client_generators.dependencies.base_model."
"SCALARS_PARSE_FUNCTIONS",
{str: mocked_parse},
)

class TestModel(BaseModel):
field_a: int
field_b: Optional[int]
field_c: Optional[int]
field_d: List[int]
field_e: Optional[List[int]]
field_f: Optional[List[int]]
field_g: Optional[List[Optional[int]]]
field_h: Optional[List[Optional[int]]]
field_i: Optional[List[Optional[int]]]

TestModel.parse_obj(
{
"field_a": 1,
"field_b": 2,
"field_c": None,
"field_d": [3, 4],
"field_e": [5, 6],
"field_f": None,
"field_g": [7, 8],
"field_h": [9, None],
"field_i": None,
}
)

assert not mocked_parse.called


def test_parse_obj_applies_parse_only_once_for_every_element(mocker):
mocked_parse = mocker.MagicMock(side_effect=lambda s: s)
mocker.patch(
"ariadne_codegen.client_generators.dependencies.base_model."
"SCALARS_PARSE_FUNCTIONS",
{str: mocked_parse},
)

class TestModelC(BaseModel):
value: str

class TestModelB(BaseModel):
value: str
field_c: TestModelC

class TestModelA(BaseModel):
value: str
field_b: TestModelB

TestModelA.parse_obj(
{"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}}
)

assert mocked_parse.call_count == 3
assert {c.args[0] for c in mocked_parse.call_args_list} == {"a", "b", "c"}


@pytest.mark.parametrize(
"annotation, value, expected_args",
[
(str, "a", {"a"}),
(Optional[str], "a", {"a"}),
(Optional[str], None, set()),
(List[str], ["a", "b"], {"a", "b"}),
(List[Optional[str]], ["a", None], {"a"}),
(Optional[List[str]], ["a", "b"], {"a", "b"}),
(Optional[List[str]], None, set()),
(Optional[List[Optional[str]]], ["a", None], {"a"}),
(Optional[List[Optional[str]]], None, set()),
(List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
(Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
(Optional[List[List[str]]], None, set()),
(
Optional[List[Optional[List[str]]]],
[["a", "b"], ["c", "d"]],
{"a", "b", "c", "d"},
),
(Optional[List[Optional[List[str]]]], None, set()),
(Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}),
(
Optional[List[Optional[List[Optional[str]]]]],
[["a", "b"], ["c", "d"]],
{"a", "b", "c", "d"},
),
(Optional[List[Optional[List[Optional[str]]]]], None, set()),
(Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}),
(
Optional[List[Optional[List[Optional[str]]]]],
[["a", None], ["b", None]],
{"a", "b"},
),
],
)
def test_dict_applies_serialize_on_every_list_element(
annotation, value, expected_args, mocker
):
mocked_serialize = mocker.MagicMock(side_effect=lambda s: s)
mocker.patch(
"ariadne_codegen.client_generators.dependencies.base_model."
"SCALARS_SERIALIZE_FUNCTIONS",
{str: mocked_serialize},
)

class TestModel(BaseModel):
field: annotation

TestModel.parse_obj({"field": value}).dict()

assert mocked_serialize.call_count == len(expected_args)
assert {c.args[0] for c in mocked_serialize.call_args_list} == expected_args


def test_dict_doesnt_apply_serialize_on_not_matching_type(mocker):
mocked_serialize = mocker.MagicMock(side_effect=lambda s: s)
mocker.patch(
"ariadne_codegen.client_generators.dependencies.base_model."
"SCALARS_SERIALIZE_FUNCTIONS",
{str: mocked_serialize},
)

class TestModel(BaseModel):
field_a: int
field_b: Optional[int]
field_c: Optional[int]
field_d: List[int]
field_e: Optional[List[int]]
field_f: Optional[List[int]]
field_g: Optional[List[Optional[int]]]
field_h: Optional[List[Optional[int]]]
field_i: Optional[List[Optional[int]]]

TestModel.parse_obj(
{
"field_a": 1,
"field_b": 2,
"field_c": None,
"field_d": [3, 4],
"field_e": [5, 6],
"field_f": None,
"field_g": [7, 8],
"field_h": [9, None],
"field_i": None,
}
).dict()

assert not mocked_serialize.called


def test_dict_applies_serialize_only_once_for_every_element(mocker):
mocked_serialize = mocker.MagicMock(side_effect=lambda s: s)
mocker.patch(
"ariadne_codegen.client_generators.dependencies.base_model."
"SCALARS_SERIALIZE_FUNCTIONS",
{str: mocked_serialize},
)

class TestModelC(BaseModel):
value: str

class TestModelB(BaseModel):
value: str
field_c: TestModelC

class TestModelA(BaseModel):
value: str
field_b: TestModelB

TestModelA.parse_obj(
{"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}}
).dict()

assert mocked_serialize.call_count == 3
assert {c.args[0] for c in mocked_serialize.call_args_list} == {"a", "b", "c"}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Type, Union, get_args, get_origin

from pydantic import BaseModel as PydanticBaseModel
from pydantic.class_validators import validator
Expand All @@ -15,16 +15,36 @@ class Config:

# pylint: disable=no-self-argument
@validator("*", pre=True)
def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any:
decode = SCALARS_PARSE_FUNCTIONS.get(field.type_)
if decode and callable(decode):
def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any:
return cls._parse_custom_scalar_value(value, field.annotation)

@classmethod
def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any:
origin = get_origin(type_)
args = get_args(type_)
if origin is list and isinstance(value, list):
return [cls._parse_custom_scalar_value(item, args[0]) for item in value]

if origin is Union and type(None) in args:
sub_type: Any = list(filter(None, args))[0]
return cls._parse_custom_scalar_value(value, sub_type)

decode = SCALARS_PARSE_FUNCTIONS.get(type_)
if value and decode and callable(decode):
return decode(value)

return value

def dict(self, **kwargs: Any) -> Dict[str, Any]:
dict_ = super().dict(**kwargs)
for key, value in dict_.items():
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
if serialize and callable(serialize):
dict_[key] = serialize(value)
return dict_
return {key: self._serialize_value(value) for key, value in dict_.items()}

def _serialize_value(self, value: Any) -> Any:
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
if serialize and callable(serialize):
return serialize(value)

if isinstance(value, list):
return [self._serialize_value(item) for item in value]

return value
Loading

0 comments on commit 44f9740

Please sign in to comment.