Skip to content

Commit

Permalink
feat!: Polyfactory support
Browse files Browse the repository at this point in the history
  • Loading branch information
phha committed Jan 2, 2024
1 parent 655df9d commit cf94e4c
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 17 deletions.
56 changes: 56 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and extendable.

## Usage

### General usage

Use it as pytest fixture to ensure every test is run with a clean set of overrides:

```python
Expand Down Expand Up @@ -58,6 +60,8 @@ as an override.

It doesn't matter if your dependency is async or not. Overrider will do the right thing.

### Basic overrides

`override.value()` returns the override value:

```python
Expand Down Expand Up @@ -95,6 +99,8 @@ def test_get_item_drop_in(client: TestClient, override: Overrider) -> None:
assert Item(**response) == item
```

### Mocks and spies

Overrider can create mocks for you:

```python
Expand All @@ -121,6 +127,54 @@ def test_get_item_spy(client: TestClient, override: Overrider) -> None:
spy.assert_called_with(item_id=0)
```

### Auto-generated overrides

Overrider can auto-generate mock objects using [Polyfactory](https://polyfactory.litestar.dev/).

To enable this extra feature, use `pip install fastapi-overrider[polyfactory]`.

Overrider will automatically use a
[matching factory](https://polyfactory.litestar.dev/usage/library_factories/index.html)
for the given dependency.

Generate a single override value. You can provide optional keyword arguments to any of the
auto-generator methods in order to pin an attribute to a specific value, like `name` in
this example:

```python
def test_get_some_item(client: TestClient, override: Overrider) -> None:
item = override.some(lookup_item, name="Foo")

response = client.get(f"/item/{item.item_id}")

assert item.name == "Foo"
assert item == Item(**response.json())
```

You can also let Overrider generate multiple override values:

```python
def test_get_five_items(client: TestClient, override: Overrider) -> None:
items = override.batch(lookup_item, 5)

for item in items:
response = client.get(f"/item/{item.item_id}")
assert item == Item(**response.json())
```

Attempt to cover the full range of forms that a model can take:

```python
def test_cover_get_items(client: TestClient, override: Overrider) -> None:
items = override.cover(lookup_item)

for item in items:
response = client.get(f"/item/{item.item_id}")
assert item == Item(**response.json())
```

### Shortcuts

You can call Overrider directly and it will guess what you want to do:

If you pass in a callable, it will act like `override.function()`:
Expand Down Expand Up @@ -160,6 +214,8 @@ def test_get_item_call_mock(client: TestClient, override: Overrider) -> None:
assert Item(**response.json()) == item
```

### Advanced patterns

Reuse common overrides. They are composable, you can have multiple:

```python
Expand Down
89 changes: 88 additions & 1 deletion fastapi_overrider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import inspect
from collections import UserDict
from collections.abc import Callable
from collections.abc import Awaitable, Callable, Iterator
from contextlib import suppress
from functools import wraps
from typing import Any, ParamSpec, Self, TypeVar, overload
from unittest.mock import MagicMock, create_autospec, seal
Expand Down Expand Up @@ -112,6 +113,92 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002,ANN003,ANN401
self[key] = spy
return spy

with suppress(ModuleNotFoundError):
from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories import (
BaseFactory,
DataclassFactory,
TypedDictFactory,
)

# A list of factories used to generate a value.
# These will be tried in order until a matching factory has been found
factories: list[type[BaseFactory]] = []

with suppress(MissingDependencyException):
from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory

factories.append(BeanieDocumentFactory)
with suppress(MissingDependencyException):
from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory

factories.append(OdmanticModelFactory)
with suppress(MissingDependencyException):
from polyfactory.factories.msgspec_factory import MsgspecFactory

factories.append(MsgspecFactory)
with suppress(MissingDependencyException):
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

factories.append(SQLAlchemyFactory)
with suppress(MissingDependencyException):
from polyfactory.factories.pydantic_factory import ModelFactory

factories.append(ModelFactory)
factories.extend([DataclassFactory, TypedDictFactory])
with suppress(MissingDependencyException):
from polyfactory.factories.attrs_factory import AttrsFactory

factories.append(AttrsFactory)

def _get_factory(self, key: _DepType) -> BaseFactory:
return_type = inspect.get_annotations(key)["return"]
for factory in self.factories:
if factory.is_supported_type(return_type):
return factory.create_factory(model=return_type)()
message = f"Did not find a factory for type {return_type}"
raise (ValueError(message))

def some(
self,
key: Callable[_P, _T] | Callable[_P, Awaitable[_T]],
**kwargs: Any, # noqa: ANN401
) -> _T:
"""Override a dependency with a value generated by Polyfactory.
Additional keyword args are forwarded to the factory's build() method.
Returns the generated value"""
factory = self._get_factory(key)
override = factory.build(**kwargs)
return self.value(key, override)

def batch(
self,
key: Callable[_P, _T] | Callable[_P, Awaitable[_T]],
size: int,
**kwargs: Any, # noqa: ANN401
) -> Iterator[_T]:
"""Batch override a dependency with a given number of values generated by
Polyfactory.
Additional keyword args are forwarded to the factory's build() method.
Returns an Iterator of generated values"""
factory = self._get_factory(key)
overrides = factory.batch(size, **kwargs)
for override in overrides:
yield self.value(key, override)

def cover(
self,
key: Callable[_P, _T] | Callable[_P, Awaitable[_T]],
**kwargs: Any, # noqa: ANN401
) -> Iterator[_T]:
"""Override a dependency with values until full coverage has been reached.
Additional keyword args are forwarded to the factory's coverage() method.
Returns an Iterator of generated values"""
factory = self._get_factory(key)
overrides = factory.coverage(**kwargs)
for override in overrides:
yield self.value(key, override)

def __enter__(self: Self) -> Self:
self._restore_overrides = self._app.dependency_overrides
self._app.dependency_overrides = self._restore_overrides.copy()
Expand Down
68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fastapi-overrider"
version = "0.6.1"
version = "0.7.0"
description = "FastAPI Dependency overrides made easy."
authors = ["Philipp Hack <[email protected]>"]
license = "MIT"
Expand All @@ -18,16 +18,18 @@ classifiers = [
python = "^3.10"
fastapi = "^0.95.1"
pytest = "^7.3.1"

polyfactory = {version = "^2.13.0", optional = true}

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.21.0"
ruff = "^0.1.9"


[tool.poetry.group.test.dependencies]
httpx = "^0.26.0"

[tool.poetry.extras]
polyfactory = ["polyfactory"]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Expand Down
Loading

0 comments on commit cf94e4c

Please sign in to comment.