Pydantic validator (#1121)
* Add pydantic validation

* Add pydantic plugin

* Add tests for pydantic validation and plugin

* Add pydantic to dependencies

* Resolve issues from code review

* Make type hints backward compatible

* Remove `pydantic` constraint for `vaex`

Note that vaexio/vaex#2384 has been resolved

* Improve pydantic validator test import

* Add docstring to the pydantic check_output

* Add initial pydantic data quality docs

* Fix `pydantic support` title underline

* Fix pydantic strict mode link

* Fix spacing after `code-block`

* Add pydantic plugin details

* Fix double quotes for code references

* Remove name tags

* Add additional docstring example; tweak wording
cswartzvi authored Sep 17, 2024
1 parent 3ed61dd commit ee9e4ae
Showing 8 changed files with 477 additions and 7 deletions.
6 changes: 6 additions & 0 deletions docs/concepts/function-modifiers.rst
@@ -176,6 +176,12 @@ pandera support
Hamilton has a pandera plugin for data validation that you can install with ``pip install sf-hamilton[pandera]``. Then, you can pass a pandera schema (for DataFrame or Series) to ``@check_output(schema=...)``.


pydantic support
~~~~~~~~~~~~~~~~

Hamilton also supports data validation of pydantic models, which can be enabled with ``pip install sf-hamilton[pydantic]``. With pydantic installed, you can pass any subclass of the pydantic base model to ``@check_output(model=...)``. Pydantic validation is performed in strict mode, meaning that raw values will not be coerced to the model's types. For more information on strict mode see the `pydantic docs <https://docs.pydantic.dev/latest/concepts/strict_mode/>`_.
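
To make the strict-mode behavior concrete, here is a minimal sketch that uses pydantic directly rather than Hamilton. The ``Point`` model and the variable names are hypothetical and chosen only for illustration; ``TypeAdapter.validate_python(..., strict=True)`` is the call that the validator added in this commit relies on.

```python
from pydantic import BaseModel, TypeAdapter, ValidationError


class Point(BaseModel):  # hypothetical model, for illustration only
    x: int
    y: float


adapter = TypeAdapter(Point)
payload = {"x": "1", "y": 2.0}  # "x" is a string that *could* be coerced to an int

# Lax mode (pydantic's default): the string "1" is coerced to the int 1.
lax = adapter.validate_python(payload)

# Strict mode: the same payload is rejected because "1" is not an int.
try:
    adapter.validate_python(payload, strict=True)
    strict_ok = True
    strict_error_types = []
except ValidationError as err:
    strict_ok = False
    strict_error_types = [e["type"] for e in err.errors()]
```

In this sketch the lax call succeeds while the strict call raises, which is the behavior a Hamilton user should expect from ``@check_output(model=...)``.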


Split node output into *n* nodes
--------------------------------

10 changes: 7 additions & 3 deletions docs/reference/decorators/check_output.rst
@@ -27,9 +27,10 @@ Note that you can also specify custom decorators using the ``@check_output_custo
See `data_quality <https://github.com/dagworks-inc/hamilton/blob/main/data\_quality.md>`_ for more information on
available validators and how to build custom ones.

-Note we also have a plugin that allows you to use pandera. There are two ways to access it:
-1. `@check_output(schema=pandera_schema)`
-2. `@h_pandera.check_output()` on a function that declares a typed pandera dataframe as an output
+Note we also have plugins that allow for validation with the pandera and pydantic libraries. There are two ways to access these:
+
+1. ``@check_output(schema=pandera_schema)`` or ``@check_output(model=pydantic_model)``
+2. ``@h_pandera.check_output()`` or ``@h_pydantic.check_output()`` on the function that declares either a typed dataframe or a pydantic model.

----

@@ -43,3 +44,6 @@ Note we also have a plugin that allows you to use pandera. There are two ways to

.. autoclass:: hamilton.plugins.h_pandera.check_output
   :special-members: __init__

.. autoclass:: hamilton.plugins.h_pydantic.check_output
   :special-members: __init__
17 changes: 17 additions & 0 deletions hamilton/data_quality/default_validators.py
@@ -508,6 +508,23 @@ def _append_pandera_to_default_validators():
_append_pandera_to_default_validators()


def _append_pydantic_to_default_validators():
    """Utility method to append pydantic validators as needed"""
    try:
        import pydantic  # noqa: F401
    except ModuleNotFoundError:
        logger.debug(
            "Cannot import pydantic from pydantic_validators. Run pip install sf-hamilton[pydantic] if needed."
        )
        return
    from hamilton.data_quality import pydantic_validators

    AVAILABLE_DEFAULT_VALIDATORS.extend(pydantic_validators.PYDANTIC_VALIDATORS)


_append_pydantic_to_default_validators()


def resolve_default_validators(
    output_type: Type[Type],
    importance: str,
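
The hunk above registers the pydantic validators only when pydantic can be imported. The general pattern can be sketched with the standard library alone. Everything named here (``AVAILABLE``, ``append_optional``, the validator strings) is a hypothetical stand-in and not part of Hamilton's API.

```python
import importlib
import logging
from typing import List

logger = logging.getLogger(__name__)

# Hypothetical registry standing in for AVAILABLE_DEFAULT_VALIDATORS.
AVAILABLE: List[str] = ["builtin_validator"]


def append_optional(module_name: str, validators: List[str]) -> bool:
    """Extend the registry only when the optional dependency is importable."""
    try:
        importlib.import_module(module_name)
    except ModuleNotFoundError:
        # Missing optional dependency: log quietly and leave the registry untouched.
        logger.debug("Cannot import %s; skipping its validators.", module_name)
        return False
    AVAILABLE.extend(validators)
    return True


# "json" ships with Python, so this registration succeeds.
found = append_optional("json", ["json_validator"])
# This module name is assumed not to exist, so nothing is registered.
missing = append_optional("surely_not_installed_pkg", ["never_added"])
```

Logging at debug level rather than raising keeps the base package importable for users who never installed the extra.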
60 changes: 60 additions & 0 deletions hamilton/data_quality/pydantic_validators.py
@@ -0,0 +1,60 @@
from typing import Any, Type

from pydantic import BaseModel, TypeAdapter, ValidationError

from hamilton.data_quality import base
from hamilton.htypes import custom_subclass_check


class PydanticModelValidator(base.BaseDefaultValidator):
    """Pydantic model compatibility validator

    Note that this validator uses pydantic's strict mode, which does not allow for
    coercion of data. This means that if an object does not exactly match the reference
    type, it will fail validation, regardless of whether it could be coerced into the
    correct type.

    :param model: Pydantic model to validate against
    :param importance: Importance of the validator, possible values "warn" and "fail"
    """

    def __init__(self, model: Type[BaseModel], importance: str):
        super(PydanticModelValidator, self).__init__(importance)
        self.model = model
        self._model_adapter = TypeAdapter(model)

    @classmethod
    def applies_to(cls, datatype: Type[Type]) -> bool:
        # In addition to checking for a subclass of BaseModel, we also check for dict
        # as this is the standard 'de-serialized' format of pydantic models in python
        return custom_subclass_check(datatype, BaseModel) or custom_subclass_check(datatype, dict)

    def description(self) -> str:
        return "Validates that the returned object is compatible with the specified pydantic model"

    def validate(self, data: Any) -> base.ValidationResult:
        try:
            # Currently, validate cannot alter the output data, so we must use
            # strict=True. The downside to this is that data that could be coerced
            # into the correct type will fail validation.
            self._model_adapter.validate_python(data, strict=True)
        except ValidationError as e:
            return base.ValidationResult(
                passes=False, message=str(e), diagnostics={"model_errors": e.errors()}
            )
        return base.ValidationResult(
            passes=True,
            message=f"Data passes pydantic check for model {str(self.model)}",
        )

    @classmethod
    def arg(cls) -> str:
        return "model"

    @classmethod
    def name(cls) -> str:
        return "pydantic_validator"


PYDANTIC_VALIDATORS = [PydanticModelValidator]
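
The core of ``validate`` can be tried without Hamilton installed. The following is a rough sketch under stated assumptions: ``SketchResult`` is a hypothetical stand-in for ``base.ValidationResult``, and ``strict_check`` is an illustrative helper, not a function from this commit. It also shows a subtlety of strict mode: an already-constructed model instance is accepted as-is, because any coercion happened when the instance was built.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type

from pydantic import BaseModel, TypeAdapter, ValidationError


@dataclass
class SketchResult:  # hypothetical stand-in for hamilton's base.ValidationResult
    passes: bool
    message: str
    diagnostics: Dict[str, Any] = field(default_factory=dict)


class MyModel(BaseModel):
    a: int


def strict_check(model: Type[BaseModel], data: Any) -> SketchResult:
    """Validate ``data`` against ``model`` in strict mode (no coercion)."""
    try:
        TypeAdapter(model).validate_python(data, strict=True)
    except ValidationError as e:
        # Mirror the commit's approach: keep the structured errors as diagnostics.
        return SketchResult(False, str(e), {"model_errors": e.errors()})
    return SketchResult(True, f"Data passes pydantic check for model {model}")


ok_dict = strict_check(MyModel, {"a": 1})     # exact types: passes
bad_dict = strict_check(MyModel, {"a": "1"})  # string where an int is required: fails
# MyModel(a="1") is coerced to a == 1 at construction time (lax by default),
# so strict validation of the resulting instance passes.
instance = strict_check(MyModel, MyModel(a="1"))
```

The practical consequence is that strict mode catches mismatched raw dictionaries, while a function that returns a model instance has already been through pydantic's own (lax) construction.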
111 changes: 111 additions & 0 deletions hamilton/plugins/h_pydantic.py
@@ -0,0 +1,111 @@
from typing import List

from pydantic import BaseModel

from hamilton import node
from hamilton.data_quality import base as dq_base
from hamilton.function_modifiers import InvalidDecoratorException
from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import check_output as base_check_output
from hamilton.function_modifiers.validation import BaseDataValidationDecorator
from hamilton.htypes import custom_subclass_check


class check_output(BaseDataValidationDecorator):
    def __init__(
        self,
        importance: str = dq_base.DataValidationLevel.WARN.value,
        target: fm_base.TargetType = None,
    ):
        """Specific output-checker for pydantic models. This decorator utilizes the output type of
        the function, which can be any subclass of pydantic.BaseModel. The function output must
        be declared with a type hint.

        :param importance: Importance level (either "warn" or "fail") -- see documentation for check_output for more details.
        :param target: The target of the decorator -- see documentation for check_output for more details.

        Here is an example of how to use this decorator with a function that returns a pydantic model:

        .. code-block:: python

            from hamilton.plugins import h_pydantic
            from pydantic import BaseModel

            class MyModel(BaseModel):
                a: int
                b: float
                c: str

            @h_pydantic.check_output()
            def foo() -> MyModel:
                return MyModel(a=1, b=2.0, c="hello")

        Alternatively, you can return a dictionary from the function (type checkers will probably
        complain about this):

        .. code-block:: python

            from hamilton.plugins import h_pydantic
            from pydantic import BaseModel

            class MyModel(BaseModel):
                a: int
                b: float
                c: str

            @h_pydantic.check_output()
            def foo() -> MyModel:
                return {"a": 1, "b": 2.0, "c": "hello"}

        You can also use pydantic validation through ``function_modifiers.check_output`` by
        providing the model as an argument:

        .. code-block:: python

            from typing import Any

            from hamilton import function_modifiers
            from pydantic import BaseModel

            class MyModel(BaseModel):
                a: int
                b: float
                c: str

            @function_modifiers.check_output(model=MyModel)
            def foo() -> dict[str, Any]:
                return {"a": 1, "b": 2.0, "c": "hello"}

        Note that because we do not (yet) support modification of the output, the validation is
        performed in strict mode, meaning that no data coercion is performed. For example, the
        following function will *fail* validation:

        .. code-block:: python

            from hamilton.plugins import h_pydantic
            from pydantic import BaseModel

            class MyModel(BaseModel):
                a: int  # Defined as an int

            @h_pydantic.check_output()  # This will fail validation!
            def foo() -> MyModel:
                return {"a": "1"}  # Returned as a string; MyModel(a="1") would already coerce it to 1

        For more information about strict mode see the pydantic docs: https://docs.pydantic.dev/latest/concepts/strict_mode/
        """
        super(check_output, self).__init__(target)
        self.importance = importance
        self.target = target

    def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]:
        output_type = node_to_validate.type
        if not custom_subclass_check(output_type, BaseModel):
            raise InvalidDecoratorException(
                f"Output of function {node_to_validate.name} must be a Pydantic model"
            )
        return base_check_output(
            importance=self.importance, model=output_type, target_=self.target
        ).get_validators(node_to_validate)
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -73,6 +73,7 @@ docs = [
"pillow",
"polars",
"pyarrow >= 1.0.0",
"pydantic >=2.0",
"pyspark",
"openlineage-python",
"PyYAML",
@@ -99,6 +100,7 @@ packaging = [
"build",
]
pandera = ["pandera"]
pydantic = ["pydantic>=2.0"]
pyspark = [
# we have to run these dependencies because Spark does not check to ensure the right target was called
"pyspark[pandas_on_spark,sql]"
@@ -129,6 +131,7 @@ test = [
"plotly",
"polars",
"pyarrow",
"pydantic >=2.0",
"pyreadstat", # for SPSS data loader
"pytest",
"pytest-asyncio",
@@ -144,10 +147,7 @@ test = [
]
tqdm = ["tqdm"]
ui = ["sf-hamilton-ui"]
-vaex = [
-"pydantic<2.0", # because of https://github.com/vaexio/vaex/issues/2384
-"vaex"
-]
+vaex = ["vaex"]
visualization = ["graphviz", "networkx"]

[project.entry-points.console_scripts]
1 change: 1 addition & 0 deletions tests/integrations/pydantic/requirements.txt
@@ -0,0 +1 @@
# Additional requirements on top of hamilton...pydantic