Skip to content

Commit 1c0d1ec

Browse files
committed
rework union type logic to preserve original type name for class when possible
1 parent 6a2c525 commit 1c0d1ec

15 files changed

+276
-114
lines changed

.changeset/nullable_enum_fix.md

-7
This file was deleted.

.changeset/union_fixes.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
default: patch
3+
---
4+
5+
# Fix class generation for some union types
6+
7+
Fixed issue #1120, where certain combinations of types-- such as a `oneOf` between a model or an enum and null, or the OpenAPI 3.0 equivalent of using `nullable: true`-- could cause unnecessary suffixes like "Type0" to be added to the class name, and/or could cause extra copies of the class to be generated.

end_to_end_tests/golden-record/my_test_api_client/models/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454
from .model_with_additional_properties_refed import ModelWithAdditionalPropertiesRefed
5555
from .model_with_any_json_properties import ModelWithAnyJsonProperties
56-
from .model_with_any_json_properties_additional_property_type_0 import ModelWithAnyJsonPropertiesAdditionalPropertyType0
56+
from .model_with_any_json_properties_additional_property import ModelWithAnyJsonPropertiesAdditionalProperty
5757
from .model_with_backslash_in_description import ModelWithBackslashInDescription
5858
from .model_with_circular_ref_a import ModelWithCircularRefA
5959
from .model_with_circular_ref_b import ModelWithCircularRefB
@@ -81,8 +81,8 @@
8181
from .post_naming_property_conflict_with_import_body import PostNamingPropertyConflictWithImportBody
8282
from .post_naming_property_conflict_with_import_response_200 import PostNamingPropertyConflictWithImportResponse200
8383
from .post_responses_unions_simple_before_complex_response_200 import PostResponsesUnionsSimpleBeforeComplexResponse200
84-
from .post_responses_unions_simple_before_complex_response_200a_type_1 import (
85-
PostResponsesUnionsSimpleBeforeComplexResponse200AType1,
84+
from .post_responses_unions_simple_before_complex_response_200a import (
85+
PostResponsesUnionsSimpleBeforeComplexResponse200A,
8686
)
8787
from .test_inline_objects_body import TestInlineObjectsBody
8888
from .test_inline_objects_response_200 import TestInlineObjectsResponse200
@@ -134,7 +134,7 @@
134134
"ModelWithAdditionalPropertiesInlinedAdditionalProperty",
135135
"ModelWithAdditionalPropertiesRefed",
136136
"ModelWithAnyJsonProperties",
137-
"ModelWithAnyJsonPropertiesAdditionalPropertyType0",
137+
"ModelWithAnyJsonPropertiesAdditionalProperty",
138138
"ModelWithBackslashInDescription",
139139
"ModelWithCircularRefA",
140140
"ModelWithCircularRefB",
@@ -162,7 +162,7 @@
162162
"PostNamingPropertyConflictWithImportBody",
163163
"PostNamingPropertyConflictWithImportResponse200",
164164
"PostResponsesUnionsSimpleBeforeComplexResponse200",
165-
"PostResponsesUnionsSimpleBeforeComplexResponse200AType1",
165+
"PostResponsesUnionsSimpleBeforeComplexResponse200A",
166166
"TestInlineObjectsBody",
167167
"TestInlineObjectsResponse200",
168168
"ValidationError",

end_to_end_tests/golden-record/my_test_api_client/models/a_model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ def _parse_nullable_model(data: object) -> Union["ModelWithUnionProperty", None]
353353
try:
354354
if not isinstance(data, dict):
355355
raise TypeError()
356-
nullable_model_type_1 = ModelWithUnionProperty.from_dict(data)
356+
nullable_model = ModelWithUnionProperty.from_dict(data)
357357

358-
return nullable_model_type_1
358+
return nullable_model
359359
except: # noqa: E722
360360
pass
361361
return cast(Union["ModelWithUnionProperty", None], data)
@@ -498,9 +498,9 @@ def _parse_not_required_nullable_model(data: object) -> Union["ModelWithUnionPro
498498
try:
499499
if not isinstance(data, dict):
500500
raise TypeError()
501-
not_required_nullable_model_type_1 = ModelWithUnionProperty.from_dict(data)
501+
not_required_nullable_model = ModelWithUnionProperty.from_dict(data)
502502

503-
return not_required_nullable_model_type_1
503+
return not_required_nullable_model
504504
except: # noqa: E722
505505
pass
506506
return cast(Union["ModelWithUnionProperty", None, Unset], data)

end_to_end_tests/golden-record/my_test_api_client/models/extended.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,9 @@ def _parse_nullable_model(data: object) -> Union["ModelWithUnionProperty", None]
361361
try:
362362
if not isinstance(data, dict):
363363
raise TypeError()
364-
nullable_model_type_1 = ModelWithUnionProperty.from_dict(data)
364+
nullable_model = ModelWithUnionProperty.from_dict(data)
365365

366-
return nullable_model_type_1
366+
return nullable_model
367367
except: # noqa: E722
368368
pass
369369
return cast(Union["ModelWithUnionProperty", None], data)
@@ -506,9 +506,9 @@ def _parse_not_required_nullable_model(data: object) -> Union["ModelWithUnionPro
506506
try:
507507
if not isinstance(data, dict):
508508
raise TypeError()
509-
not_required_nullable_model_type_1 = ModelWithUnionProperty.from_dict(data)
509+
not_required_nullable_model = ModelWithUnionProperty.from_dict(data)
510510

511-
return not_required_nullable_model_type_1
511+
return not_required_nullable_model
512512
except: # noqa: E722
513513
pass
514514
return cast(Union["ModelWithUnionProperty", None, Unset], data)

end_to_end_tests/golden-record/my_test_api_client/models/model_with_any_json_properties.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
from attrs import field as _attrs_field
55

66
if TYPE_CHECKING:
7-
from ..models.model_with_any_json_properties_additional_property_type_0 import (
8-
ModelWithAnyJsonPropertiesAdditionalPropertyType0,
9-
)
7+
from ..models.model_with_any_json_properties_additional_property import ModelWithAnyJsonPropertiesAdditionalProperty
108

119

1210
T = TypeVar("T", bound="ModelWithAnyJsonProperties")
@@ -17,17 +15,17 @@ class ModelWithAnyJsonProperties:
1715
""" """
1816

1917
additional_properties: Dict[
20-
str, Union["ModelWithAnyJsonPropertiesAdditionalPropertyType0", List[str], bool, float, int, str]
18+
str, Union["ModelWithAnyJsonPropertiesAdditionalProperty", List[str], bool, float, int, str]
2119
] = _attrs_field(init=False, factory=dict)
2220

2321
def to_dict(self) -> Dict[str, Any]:
24-
from ..models.model_with_any_json_properties_additional_property_type_0 import (
25-
ModelWithAnyJsonPropertiesAdditionalPropertyType0,
22+
from ..models.model_with_any_json_properties_additional_property import (
23+
ModelWithAnyJsonPropertiesAdditionalProperty,
2624
)
2725

2826
field_dict: Dict[str, Any] = {}
2927
for prop_name, prop in self.additional_properties.items():
30-
if isinstance(prop, ModelWithAnyJsonPropertiesAdditionalPropertyType0):
28+
if isinstance(prop, ModelWithAnyJsonPropertiesAdditionalProperty):
3129
field_dict[prop_name] = prop.to_dict()
3230
elif isinstance(prop, list):
3331
field_dict[prop_name] = prop
@@ -39,8 +37,8 @@ def to_dict(self) -> Dict[str, Any]:
3937

4038
@classmethod
4139
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
42-
from ..models.model_with_any_json_properties_additional_property_type_0 import (
43-
ModelWithAnyJsonPropertiesAdditionalPropertyType0,
40+
from ..models.model_with_any_json_properties_additional_property import (
41+
ModelWithAnyJsonPropertiesAdditionalProperty,
4442
)
4543

4644
d = src_dict.copy()
@@ -51,13 +49,13 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
5149

5250
def _parse_additional_property(
5351
data: object,
54-
) -> Union["ModelWithAnyJsonPropertiesAdditionalPropertyType0", List[str], bool, float, int, str]:
52+
) -> Union["ModelWithAnyJsonPropertiesAdditionalProperty", List[str], bool, float, int, str]:
5553
try:
5654
if not isinstance(data, dict):
5755
raise TypeError()
58-
additional_property_type_0 = ModelWithAnyJsonPropertiesAdditionalPropertyType0.from_dict(data)
56+
additional_property = ModelWithAnyJsonPropertiesAdditionalProperty.from_dict(data)
5957

60-
return additional_property_type_0
58+
return additional_property
6159
except: # noqa: E722
6260
pass
6361
try:
@@ -69,7 +67,7 @@ def _parse_additional_property(
6967
except: # noqa: E722
7068
pass
7169
return cast(
72-
Union["ModelWithAnyJsonPropertiesAdditionalPropertyType0", List[str], bool, float, int, str], data
70+
Union["ModelWithAnyJsonPropertiesAdditionalProperty", List[str], bool, float, int, str], data
7371
)
7472

7573
additional_property = _parse_additional_property(prop_dict)
@@ -85,13 +83,11 @@ def additional_keys(self) -> List[str]:
8583

8684
def __getitem__(
8785
self, key: str
88-
) -> Union["ModelWithAnyJsonPropertiesAdditionalPropertyType0", List[str], bool, float, int, str]:
86+
) -> Union["ModelWithAnyJsonPropertiesAdditionalProperty", List[str], bool, float, int, str]:
8987
return self.additional_properties[key]
9088

9189
def __setitem__(
92-
self,
93-
key: str,
94-
value: Union["ModelWithAnyJsonPropertiesAdditionalPropertyType0", List[str], bool, float, int, str],
90+
self, key: str, value: Union["ModelWithAnyJsonPropertiesAdditionalProperty", List[str], bool, float, int, str]
9591
) -> None:
9692
self.additional_properties[key] = value
9793

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
T = TypeVar("T", bound="ModelWithAnyJsonPropertiesAdditionalPropertyType0")
6+
T = TypeVar("T", bound="ModelWithAnyJsonPropertiesAdditionalProperty")
77

88

99
@_attrs_define
10-
class ModelWithAnyJsonPropertiesAdditionalPropertyType0:
10+
class ModelWithAnyJsonPropertiesAdditionalProperty:
1111
""" """
1212

1313
additional_properties: Dict[str, str] = _attrs_field(init=False, factory=dict)
@@ -21,10 +21,10 @@ def to_dict(self) -> Dict[str, Any]:
2121
@classmethod
2222
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
2323
d = src_dict.copy()
24-
model_with_any_json_properties_additional_property_type_0 = cls()
24+
model_with_any_json_properties_additional_property = cls()
2525

26-
model_with_any_json_properties_additional_property_type_0.additional_properties = d
27-
return model_with_any_json_properties_additional_property_type_0
26+
model_with_any_json_properties_additional_property.additional_properties = d
27+
return model_with_any_json_properties_additional_property
2828

2929
@property
3030
def additional_keys(self) -> List[str]:

end_to_end_tests/golden-record/my_test_api_client/models/post_responses_unions_simple_before_complex_response_200.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from attrs import field as _attrs_field
55

66
if TYPE_CHECKING:
7-
from ..models.post_responses_unions_simple_before_complex_response_200a_type_1 import (
8-
PostResponsesUnionsSimpleBeforeComplexResponse200AType1,
7+
from ..models.post_responses_unions_simple_before_complex_response_200a import (
8+
PostResponsesUnionsSimpleBeforeComplexResponse200A,
99
)
1010

1111

@@ -16,19 +16,19 @@
1616
class PostResponsesUnionsSimpleBeforeComplexResponse200:
1717
"""
1818
Attributes:
19-
a (Union['PostResponsesUnionsSimpleBeforeComplexResponse200AType1', str]):
19+
a (Union['PostResponsesUnionsSimpleBeforeComplexResponse200A', str]):
2020
"""
2121

22-
a: Union["PostResponsesUnionsSimpleBeforeComplexResponse200AType1", str]
22+
a: Union["PostResponsesUnionsSimpleBeforeComplexResponse200A", str]
2323
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2424

2525
def to_dict(self) -> Dict[str, Any]:
26-
from ..models.post_responses_unions_simple_before_complex_response_200a_type_1 import (
27-
PostResponsesUnionsSimpleBeforeComplexResponse200AType1,
26+
from ..models.post_responses_unions_simple_before_complex_response_200a import (
27+
PostResponsesUnionsSimpleBeforeComplexResponse200A,
2828
)
2929

3030
a: Union[Dict[str, Any], str]
31-
if isinstance(self.a, PostResponsesUnionsSimpleBeforeComplexResponse200AType1):
31+
if isinstance(self.a, PostResponsesUnionsSimpleBeforeComplexResponse200A):
3232
a = self.a.to_dict()
3333
else:
3434
a = self.a
@@ -45,22 +45,22 @@ def to_dict(self) -> Dict[str, Any]:
4545

4646
@classmethod
4747
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
48-
from ..models.post_responses_unions_simple_before_complex_response_200a_type_1 import (
49-
PostResponsesUnionsSimpleBeforeComplexResponse200AType1,
48+
from ..models.post_responses_unions_simple_before_complex_response_200a import (
49+
PostResponsesUnionsSimpleBeforeComplexResponse200A,
5050
)
5151

5252
d = src_dict.copy()
5353

54-
def _parse_a(data: object) -> Union["PostResponsesUnionsSimpleBeforeComplexResponse200AType1", str]:
54+
def _parse_a(data: object) -> Union["PostResponsesUnionsSimpleBeforeComplexResponse200A", str]:
5555
try:
5656
if not isinstance(data, dict):
5757
raise TypeError()
58-
a_type_1 = PostResponsesUnionsSimpleBeforeComplexResponse200AType1.from_dict(data)
58+
a = PostResponsesUnionsSimpleBeforeComplexResponse200A.from_dict(data)
5959

60-
return a_type_1
60+
return a
6161
except: # noqa: E722
6262
pass
63-
return cast(Union["PostResponsesUnionsSimpleBeforeComplexResponse200AType1", str], data)
63+
return cast(Union["PostResponsesUnionsSimpleBeforeComplexResponse200A", str], data)
6464

6565
a = _parse_a(d.pop("a"))
6666

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
T = TypeVar("T", bound="PostResponsesUnionsSimpleBeforeComplexResponse200AType1")
6+
T = TypeVar("T", bound="PostResponsesUnionsSimpleBeforeComplexResponse200A")
77

88

99
@_attrs_define
10-
class PostResponsesUnionsSimpleBeforeComplexResponse200AType1:
10+
class PostResponsesUnionsSimpleBeforeComplexResponse200A:
1111
""" """
1212

1313
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
@@ -21,10 +21,10 @@ def to_dict(self) -> Dict[str, Any]:
2121
@classmethod
2222
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
2323
d = src_dict.copy()
24-
post_responses_unions_simple_before_complex_response_200a_type_1 = cls()
24+
post_responses_unions_simple_before_complex_response_200a = cls()
2525

26-
post_responses_unions_simple_before_complex_response_200a_type_1.additional_properties = d
27-
return post_responses_unions_simple_before_complex_response_200a_type_1
26+
post_responses_unions_simple_before_complex_response_200a.additional_properties = d
27+
return post_responses_unions_simple_before_complex_response_200a
2828

2929
@property
3030
def additional_keys(self) -> List[str]:

openapi_python_client/parser/properties/enum_property.py

+21-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from openapi_python_client.parser.properties.has_named_class import HasNamedClass
4+
from openapi_python_client.schema.data_type import DataType
5+
36
__all__ = ["EnumProperty", "ValueType"]
47

58
from typing import Any, ClassVar, List, Union, cast
@@ -19,7 +22,7 @@
1922

2023

2124
@define
22-
class EnumProperty(PropertyProtocol):
25+
class EnumProperty(PropertyProtocol, HasNamedClass):
2326
"""A property that should use an enum"""
2427

2528
name: str
@@ -42,7 +45,7 @@ class EnumProperty(PropertyProtocol):
4245
}
4346

4447
@classmethod
45-
def build(
48+
def build( # noqa: PLR0911
4649
cls,
4750
*,
4851
data: oai.Schema,
@@ -102,6 +105,21 @@ def build(
102105
Union[List[int], List[str]], unchecked_value_list
103106
) # We checked this with all the value_types stuff
104107

108+
if allow_null: # Only one of the values was None, that becomes a union
109+
data.oneOf = [
110+
oai.Schema(type=DataType.NULL),
111+
data.model_copy(update={"enum": value_list, "default": data.default}),
112+
]
113+
data.enum = None
114+
return UnionProperty.build(
115+
data=data,
116+
name=name,
117+
required=required,
118+
schemas=schemas,
119+
parent_name=parent_name,
120+
config=config,
121+
)
122+
105123
class_name = data.title or name
106124
if parent_name:
107125
class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}"
@@ -135,38 +153,8 @@ def build(
135153
return checked_default, schemas
136154
prop = evolve(prop, default=checked_default)
137155

138-
# Now, if one of the values was None, wrap the type we just made in a union with None. We're
139-
# constructing the UnionProperty directly instead of using UnionProperty.build() because we
140-
# do *not* want the usual union behavior of creating ThingType1, ThingType2, etc.
141-
returned_prop: EnumProperty | UnionProperty | PropertyError
142-
if allow_null:
143-
none_prop = NoneProperty.build(
144-
name=name,
145-
required=required,
146-
default=None,
147-
python_name=prop.python_name,
148-
description=None,
149-
example=None,
150-
)
151-
assert not isinstance(none_prop, PropertyError)
152-
union_prop = UnionProperty(
153-
name=name,
154-
required=required,
155-
default=checked_default,
156-
python_name=prop.python_name,
157-
description=data.description,
158-
example=data.example,
159-
inner_properties=[
160-
none_prop,
161-
prop,
162-
],
163-
)
164-
returned_prop = union_prop
165-
else:
166-
returned_prop = prop
167-
168156
schemas = evolve(schemas, classes_by_name={**schemas.classes_by_name, class_info.name: prop})
169-
return returned_prop, schemas
157+
return prop, schemas
170158

171159
def convert_value(self, value: Any) -> Value | PropertyError | None:
172160
if value is None or isinstance(value, Value):

0 commit comments

Comments
 (0)