Skip to content

Commit

Permalink
Merge pull request #196 from mirumee/pydantic_reserved_field_names
Browse files Browse the repository at this point in the history
Handle field names reserved by pydantic
  • Loading branch information
mat-sop authored Aug 22, 2023
2 parents 05cc3a5 + 0e025c1 commit b24d10c
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- Changed generated client and models to use pydantic v2.
- Changed custom scalars implementation to utilize pydantic's `BeforeValidator` and `PlainSerializer`. Added `scalars_module_name` option. Replaced `generate_scalars_parse_dict` and `generate_scalars_serialize_dict` with `generate_scalar_annotation` and `generate_scalar_imports` plugin hooks.
- Fixed generating default values of input types from remote schemas.
- Changed generating of input and result field names to add `_` to names reserved by pydantic.


## 0.7.1 (2023-06-06)
Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _parse_input_definition(
plugin_manager=self.plugin_manager,
node=field,
trim_leading_underscore=True,
handle_pydantic_resrved_field_names=True,
)
annotation, field_type = parse_input_field_type(
field.type, custom_scalars=self.custom_scalars
Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def _process_field_name(self, name: str, field: FieldNode) -> str:
plugin_manager=self.plugin_manager,
node=field,
trim_leading_underscore=True,
handle_pydantic_resrved_field_names=True,
)

def _get_field_from_schema(self, type_name: str, field_name: str) -> GraphQLField:
Expand Down
11 changes: 11 additions & 0 deletions ariadne_codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
from autoflake import fix_code # type: ignore
from black import Mode, format_str
from graphql import Node
from pydantic import BaseModel

from .plugins.manager import PluginManager

PYDANTIC_RESERVED_FIELD_NAMES = [
name for name in dir(BaseModel) if not name.startswith("_")
]


def ast_to_str(
ast_obj: ast.AST,
Expand Down Expand Up @@ -83,6 +88,7 @@ def process_name(
plugin_manager: Optional[PluginManager] = None,
node: Optional[Node] = None,
trim_leading_underscore: bool = False,
handle_pydantic_resrved_field_names: bool = False,
) -> str:
"""Processes the GraphQL name to remove keywords
and optionally convert to snake_case."""
Expand All @@ -91,6 +97,11 @@ def process_name(
processed_name = str_to_snake_case(processed_name)
if iskeyword(processed_name):
processed_name += "_"
if (
handle_pydantic_resrved_field_names
and processed_name in PYDANTIC_RESERVED_FIELD_NAMES
):
processed_name += "_"
if trim_leading_underscore:
processed_name = processed_name.lstrip("_")
if plugin_manager:
Expand Down
2 changes: 2 additions & 0 deletions tests/client_generators/input_types_generator/test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def test_generate_returns_module_with_valid_field_names(
_Bar: String!
____baz_: String!
_: String!
schema: String!
}
"""

Expand All @@ -186,4 +187,5 @@ def test_generate_returns_module_with_valid_field_names(
"bar",
"baz_",
"underscore_named_field_",
"schema_",
}
1 change: 1 addition & 0 deletions tests/client_generators/result_types_generator/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_Field5: String!
scalarField: SCALARA
_: String!
schema: String!
}
type CustomType1 {
Expand Down
9 changes: 8 additions & 1 deletion tests/client_generators/result_types_generator/test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_generate_returns_module_with_valid_field_names():
_field4
_Field5
_
schema
}
}
"""
Expand All @@ -140,4 +141,10 @@ def test_generate_returns_module_with_valid_field_names():
) # Round trip because invalid identifiers get picked up in parse
class_def = get_class_def(parsed, name_filter="CustomQueryCamelCaseQuery")
field_names = get_assignment_target_names(class_def)
assert field_names == {"in_", "field4", "field5", "underscore_named_field_"}
assert field_names == {
"in_",
"field4",
"field5",
"underscore_named_field_",
"schema_",
}

0 comments on commit b24d10c

Please sign in to comment.