Skip to content

fix: no blank children names #666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,21 @@ def from_field_info(
if is_union(annotation):
constraints = {}
children = []

# create a child for each of the possible union values
for arg in get_args(annotation):
# don't add the NoneType in an optional to the list of children
if arg is NoneType:
continue
child_field_info = FieldInfo.from_annotation(arg)
merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info)

children.append(
# recurse for each element of the union
cls.from_field_info(
field_name="",
# this is a fake field name, but it makes it possible to debug which type variant
# is the source of an exception downstream
field_name=field_name,
field_info=merged_field_info,
use_alias=use_alias,
),
Expand Down
2 changes: 2 additions & 0 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def from_type(
metadata = cls.get_constraints_metadata(annotation)
constraints = cls.parse_constraints(metadata)

# annotations can take many forms: Optional, an Annotated type, or anything with __args__
# in order to normalize the annotation, we need to unwrap the annotation.
if not annotated and (origin := get_origin(annotation)) and origin in TYPE_MAPPING:
container = TYPE_MAPPING[origin]
annotation = container[get_args(annotation)] # type: ignore[index]
Expand Down
9 changes: 7 additions & 2 deletions polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def unwrap_new_type(annotation: Any) -> Any:


def unwrap_union(annotation: Any, random: Random) -> Any:
"""Unwraps union types - recursively.
"""Unwraps union types recursively and picks a random type from each union.

:param annotation: A type annotation, possibly a type union.
:param random: An instance of random.Random.
Expand Down Expand Up @@ -91,7 +91,12 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any:


def flatten_annotation(annotation: Any) -> list[Any]:
"""Flattens an annotation.
"""Flattens an annotation into an array of possible types. For example:

* Union[str, int] → [str, int]
* Optional[str] → [str, None]
* Union[str, Optional[int]] → [str, int, None]
* NewType('UserId', int) → [int]

:param annotation: A type annotation.

Expand Down
2 changes: 1 addition & 1 deletion polyfactory/utils/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def is_new_type(annotation: Any) -> "TypeGuard[type[NewType]]":


def is_annotated(annotation: Any) -> bool:
"""Determine whether a given annotation is 'typing.Annotated'.
"""Determine whether a given annotation is 'typing.Annotated', or a similar typing annotation such as Optional.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? The function looks for annotated origin or similar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is why i left this comment, it's pretty tricky

In [12]: a
Out[12]: typing.Optional[int]

In [13]: a.__args__
Out[13]: (int, NoneType)

In [14]: b = int | None
Out[14]: int | None

In [15]: b.__args__
Out[15]: (int, NoneType)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this match _AnnotatedAlias or other conditions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not familiar with _AnnotatedAlias, not sure...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on testing in terminal think this handles this correctly so fine as is?

>>> from polyfactory.utils.predicates import is_annotated
>>> from typing import Optional
>>> is_annotated(Optional[int])
False
>>> is_annotated(int | None)
False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't think the functionality is broken here, it was just challenging to understand what was going on so I added the comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this checks optional type based on https://github.com/litestar-org/polyfactory/pull/666/files#r2014961152. Can this be reverted?


:param annotation: A type annotation.

Expand Down
47 changes: 46 additions & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from decimal import Decimal
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from typing import Callable, Dict, FrozenSet, List, Literal, Optional, Sequence, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, FrozenSet, List, Literal, Optional, Sequence, Set, Tuple, Type, Union
from uuid import UUID

import pytest
Expand Down Expand Up @@ -64,8 +64,10 @@
validator,
)

from polyfactory.exceptions import ParameterException
from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory
from polyfactory.field_meta import FieldMeta
from tests.models import Person, PetFactory

IS_PYDANTIC_V1 = _IS_PYDANTIC_V1
Expand Down Expand Up @@ -634,6 +636,49 @@ class A(BaseModel):
assert AFactory.build()


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires modern union types")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these be split out so this is tested on all python versions?

@pytest.mark.skipif(IS_PYDANTIC_V1, reason="pydantic 2 only test")
def test_optional_custom_type() -> None:
from pydantic_core import core_schema

class CustomType:
def __init__(self, _: Any) -> None:
pass

def __get_pydantic_core_schema__(self, _: Any) -> core_schema.StringSchema:
# for pydantic to stop complaining
return core_schema.str_schema()

class OptionalFormOne(BaseModel):
optional_custom_type: Optional[CustomType]

@classmethod
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
return False

class OptionalFormOneFactory(ModelFactory[OptionalFormOne]):
@classmethod
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
return False

class OptionalFormTwo(BaseModel):
# this is represented differently than `Optional[None]` internally
optional_custom_type_second_form: CustomType | None

class OptionalFormTwoFactory(ModelFactory[OptionalFormTwo]):
@classmethod
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
return False

# ensure the custom type field name and variant is in the error message

with pytest.raises(ParameterException, match=r"optional_custom_type"):
OptionalFormOneFactory.build()

with pytest.raises(ParameterException, match=r"optional_custom_type_second_form"):
OptionalFormTwoFactory.build()


def test_collection_unions_with_models() -> None:
class A(BaseModel):
a: int
Expand Down
Loading