Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests to mypy targets; fix typing in tests (part of #116) #121

Merged
merged 4 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,23 @@ write_to = "src/validataclass/_version.py"
version_scheme = "post-release"

[tool.mypy]
files = "src/"
files = ["src/", "tests/"]
mypy_path = "src/"
explicit_package_bases = true

# Enable strict type checking
strict = true

# Ignore errors like `Module "validataclass.exceptions" does not explicitly export attribute "..."`
no_implicit_reexport = false

[[tool.mypy.overrides]]
module = 'tests.*'

# Don't enforce typed definitions in tests, this is a lot of unnecessary work (most parameters would be Any anyway).
allow_untyped_defs = true

# TODO: This is the main issue with mypy and validataclass right now.
# Defining dataclasses with validators using the @validataclass decorator, like `some_field: str = StringValidator()`,
# will cause "Incompatible types in assignment" errors. Until we find a way to solve this, ignore this error for now.
disable_error_code = "assignment"
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ testing =
coverage-conditional-plugin ~= 0.5
flake8 ~= 7.0
mypy ~= 1.9
types-python-dateutil
34 changes: 25 additions & 9 deletions tests/dataclasses/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,43 @@

import dataclasses
import sys
from typing import Any
from typing import Any, Dict, Type, TYPE_CHECKING

import pytest

from validataclass.dataclasses import Default
from validataclass.validators import T_Dataclass

# TODO: Replace type alias with dataclasses.Field[Any] when removing Python 3.9 support. (#15)
if TYPE_CHECKING:
T_DataclassField = dataclasses.Field[Any]
else:
T_DataclassField = dataclasses.Field


# Test helpers for dataclass tests

def assert_field_default(field: dataclasses.Field, default_value: Any):
def assert_field_default(field: T_DataclassField, default_value: Any) -> None:
"""
Asserts that a given (vali-)dataclass field has a specified default value.
"""
# Check regular dataclass defaults
assert (
(field.default == default_value and field.default_factory is dataclasses.MISSING)
or (field.default is dataclasses.MISSING and field.default_factory() == default_value)
)
# Check that the field has a regular dataclass default VALUE or default FACTORY, but not both
assert field.default is not dataclasses.MISSING or field.default_factory is not dataclasses.MISSING
assert field.default is dataclasses.MISSING or field.default_factory is dataclasses.MISSING

# Check regular dataclass default
if field.default_factory is not dataclasses.MISSING:
assert field.default_factory() == default_value
else:
assert field.default == default_value

# Check defaults in dataclass metadata
metadata_default = field.metadata.get('validator_default')
assert isinstance(metadata_default, Default)
assert metadata_default.get_value() == default_value


def assert_field_no_default(field: dataclasses.Field):
def assert_field_no_default(field: T_DataclassField) -> None:
"""
Asserts that a given (vali-)dataclass field has no default value.
"""
Expand All @@ -43,15 +54,20 @@ def assert_field_no_default(field: dataclasses.Field):

# For Python under 3.10, check that an exception raising default_factory is set
if sys.version_info < (3, 10):
assert field.default_factory is not dataclasses.MISSING
with pytest.raises(TypeError, match="required keyword-only argument"):
field.default_factory()
else:
assert field.default_factory is dataclasses.MISSING


def get_dataclass_fields(cls) -> dict:
def get_dataclass_fields(cls: Type[T_Dataclass]) -> Dict[str, T_DataclassField]:
"""
Returns a dictionary containing all fields of a given dataclass.
"""
# Make sure the class is really a dataclass
assert dataclasses.is_dataclass(cls) and isinstance(cls, type)

# Get fields and return them as a dictionary
fields_tuple = dataclasses.fields(cls)
return {field.name: field for field in fields_tuple}
3 changes: 2 additions & 1 deletion tests/dataclasses/defaults_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from copy import copy
from typing import Any, List

import pytest

Expand Down Expand Up @@ -42,7 +43,7 @@ def test_default_immutable_values(value, expected_repr):
@staticmethod
def test_default_list_deepcopied():
""" Test Default object with a list, make sure that it is deepcopied. """
default_list = []
default_list: List[Any] = []
default = Default(default_list)

# Check string representation and value
Expand Down
8 changes: 4 additions & 4 deletions tests/dataclasses/validataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import dataclasses
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import pytest

Expand All @@ -19,7 +19,7 @@
validataclass_field,
)
from validataclass.exceptions import DataclassValidatorFieldException
from validataclass.helpers import OptionalUnset, UnsetValue
from validataclass.helpers import OptionalUnset, UnsetValue, UnsetValueType
from validataclass.validators import (
DictValidator,
IntegerValidator,
Expand Down Expand Up @@ -279,7 +279,7 @@ class SubClass(BaseClass):
# Check type annotations
assert all(fields[field].type is int for field in ['required1', 'required2', 'optional2', 'optional4'])
assert all(fields[field].type is Optional[int] for field in ['required3', 'optional1'])
assert all(fields[field].type is OptionalUnset[int] for field in ['required4', 'optional3'])
assert all(fields[field].type is Union[int, UnsetValueType] for field in ['required4', 'optional3'])

# Check validators
assert all(type(field.metadata.get('validator')) is IntegerValidator for field in fields.values())
Expand Down Expand Up @@ -391,7 +391,7 @@ class BaseB:
field_both: str = StringValidator()

@validataclass
class SubClass(BaseB, BaseA):
class SubClass(BaseB, BaseA): # type: ignore[misc]
# Override the defaults to test that the decorator recognizes all fields of both base classes.
# If it does not, a "no validator for field X" error would be raised.
field_a: int = Default(42)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
"""

from decimal import Decimal
from typing import Any, List, Union
from typing import Any, List, Tuple, Union

from validataclass.validators import Validator


def unpack_params(*args) -> List[tuple]:
def unpack_params(*args: Any) -> List[Tuple[Any, ...]]:
"""
Returns a list containing tuples build from the arguments.

Expand Down Expand Up @@ -64,7 +64,7 @@ def unpack_params(*args) -> List[tuple]:
]
```
"""
unpacked = [tuple()]
unpacked: List[Tuple[Any, ...]] = [tuple()]

for arg in args:
if type(arg) is list:
Expand Down
37 changes: 19 additions & 18 deletions tests/validators/dataclass_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataclasses import dataclass, field
from decimal import Decimal
from typing import Optional, List
from typing import Any, Dict, List, Optional

import pytest

Expand Down Expand Up @@ -94,7 +94,7 @@ class UnitTestPostValidationDataclass:
start: int = IntegerValidator()
end: int = IntegerValidator()

def __post_validate__(self):
def __post_validate__(self) -> None:
if self.start > self.end:
raise ValidationError(code='invalid_range', reason='"start" must be smaller than or equal to "end".')

Expand All @@ -110,7 +110,7 @@ class UnitTestContextSensitiveDataclass:
name: str = UnitTestContextValidator()
value: Optional[int] = IntegerValidator(), Default(None)

def __post_validate__(self, *, value_required: bool = False):
def __post_validate__(self, *, value_required: bool = False) -> None:
if value_required and self.value is None:
raise DataclassPostValidationError(field_errors={
'value': RequiredValueError(reason='Value is required in this context.'),
Expand All @@ -124,7 +124,7 @@ class UnitTestContextSensitiveDataclassWithPosArgs(UnitTestContextSensitiveDatac
"""

# Same as UnitTestContextSensitiveDataclass, but with positional arguments
def __post_validate__(self, value_required: bool = False):
def __post_validate__(self, value_required: bool = False) -> None:
super().__post_validate__(value_required=value_required)


Expand All @@ -149,7 +149,7 @@ class UnitTestContextSensitiveDataclassWithVarKwargs:
ctx_b = None
extra_kwargs = None

def __post_validate__(self, *, ctx_a: str = '', ctx_b: str = '', **kwargs):
def __post_validate__(self, *, ctx_a: str = '', ctx_b: str = '', **kwargs: Any) -> None:
self.ctx_a = ctx_a
self.ctx_b = ctx_b
self.extra_kwargs = kwargs
Expand All @@ -166,7 +166,7 @@ class UnitTestPreValidateStaticMethodDataclass:
example_int: int = IntegerValidator()

@staticmethod
def __pre_validate__(input_data: dict) -> dict:
def __pre_validate__(input_data: Dict[Any, Any]) -> Dict[Any, Any]:
mapping = {
'exampleStr': 'example_str',
'exampleInt': 'example_int',
Expand All @@ -193,7 +193,7 @@ class UnitTestPreValidateClassMethodDataclass:
example_int: int = IntegerValidator()

@classmethod
def __pre_validate__(cls, input_data: dict) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any]) -> Dict[Any, Any]:
for from_key, to_key in cls.__key_mapping.items():
if from_key in input_data:
input_data[to_key] = input_data.pop(from_key)
Expand All @@ -212,7 +212,7 @@ class UnitTestPreValidateContextSensitiveDataclass:
target_field: int = IntegerValidator()

@classmethod
def __pre_validate__(cls, input_data: dict, *, source_field_name: str) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], *, source_field_name: str) -> Dict[Any, Any]:
if source_field_name in input_data:
return {'target_field': input_data[source_field_name]}
else:
Expand All @@ -236,7 +236,7 @@ class UnitTestPreValidateContextSensitiveVarKwargsDataclass:
example_int: int = IntegerValidator()

@classmethod
def __pre_validate__(cls, input_data: dict, **kwargs) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], **kwargs: Any) -> Dict[Any, Any]:
# Fill input_data with default values based on kwargs
for key, default_value in kwargs.items():
if key not in input_data:
Expand All @@ -252,7 +252,7 @@ class UnitTestInvalidPreValidateDataclass1:
""" Dataclass with invalid __pre_validate__ class method: Not enough arguments. """

@classmethod
def __pre_validate__(cls) -> dict:
def __pre_validate__(cls) -> Dict[Any, Any]:
return {}


Expand All @@ -261,7 +261,7 @@ class UnitTestInvalidPreValidateDataclass2:
""" Dataclass with invalid __pre_validate__ static method: Not enough arguments. """

@staticmethod
def __pre_validate__() -> dict:
def __pre_validate__() -> Dict[Any, Any]:
return {}


Expand All @@ -270,7 +270,7 @@ class UnitTestInvalidPreValidateDataclass3:
""" Dataclass with invalid __pre_validate__ class method: Too many positional arguments. """

@classmethod
def __pre_validate__(cls, input_data: dict, _extra_pos_argument) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], _extra_pos_argument: Any) -> Dict[Any, Any]:
return input_data


Expand All @@ -279,7 +279,7 @@ class UnitTestInvalidPreValidateDataclass4:
""" Dataclass with invalid __pre_validate__ static method: Too many positional arguments. """

@staticmethod
def __pre_validate__(input_data: dict, _extra_pos_argument) -> dict:
def __pre_validate__(input_data: Dict[Any, Any], _extra_pos_argument: Any) -> Dict[Any, Any]:
return input_data


Expand All @@ -288,7 +288,7 @@ class UnitTestInvalidPreValidateDataclass5:
""" Dataclass with invalid __pre_validate__ class method: Too many (variable) positional arguments. """

@classmethod
def __pre_validate__(cls, input_data: dict, *_args) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], *_args: Any) -> Dict[Any, Any]:
return input_data


Expand All @@ -297,7 +297,7 @@ class UnitTestInvalidPreValidateDataclass6:
""" Dataclass with invalid __pre_validate__ static method: Too many (variable) positional arguments. """

@staticmethod
def __pre_validate__(input_data: dict, *_args) -> dict:
def __pre_validate__(input_data: Dict[Any, Any], *_args: Any) -> Dict[Any, Any]:
return input_data


Expand Down Expand Up @@ -839,7 +839,8 @@ def test_dataclass_with_pre_validate_methods(
dataclass_cls,
):
""" Validate dataclasses with different __pre_validate__() methods (static and class methods). """
validator = DataclassValidator(dataclass_cls)
validator: DataclassValidator[Any] = DataclassValidator(dataclass_cls)

validated_data = validator.validate(input_data)

assert validated_data.example_str == expected_example_str
Expand Down Expand Up @@ -897,7 +898,7 @@ def test_dataclass_with_pre_validate_methods_invalid(
dataclass_cls,
):
""" Validate dataclasses with different __pre_validate__() methods and invalid input. """
validator = DataclassValidator(dataclass_cls)
validator: DataclassValidator[Any] = DataclassValidator(dataclass_cls)

with pytest.raises(DictFieldsValidationError) as exception_info:
validator.validate(input_data)
Expand Down Expand Up @@ -1048,7 +1049,7 @@ def test_dataclass_with_context_sensitive_pre_validate_with_var_kwargs_invalid(
)
def test_dataclass_with_invalid_forms_of_pre_validate(dataclass_cls):
""" Test error handling for dataclasses with __pre_validate__() methods with an invalid method signature. """
validator = DataclassValidator(dataclass_cls)
validator: DataclassValidator[Any] = DataclassValidator(dataclass_cls)

with pytest.raises(DataclassInvalidPreValidateSignatureException, match=PRE_VALIDATE_INVALID_SIGNATURE_ERROR):
validator.validate({})
Expand Down
12 changes: 8 additions & 4 deletions tests/validators/datetime_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,14 @@ def test_with_local_timezone_valid(input_string, local_timezone, expected_dateti
assert validated_dt == expected_datetime

# Check timezone of datetimes by comparing their offset to UTC
assert (
validated_dt.tzinfo == expected_datetime.tzinfo
or validated_dt.tzinfo.utcoffset(validated_dt) == expected_datetime.tzinfo.utcoffset(expected_datetime)
)
if expected_datetime.tzinfo is None:
assert validated_dt.tzinfo is None
else:
assert validated_dt.tzinfo is not None
assert (
validated_dt.tzinfo == expected_datetime.tzinfo
or validated_dt.tzinfo.utcoffset(validated_dt) == expected_datetime.tzinfo.utcoffset(expected_datetime)
)

# Test DateTimeValidator with target_timezone parameter

Expand Down
4 changes: 2 additions & 2 deletions tests/validators/dict_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ class UnitTestDictValidator(DictValidator):
'value': DecimalValidator(),
'optional_value': DecimalValidator(),
}
required_fields = ['name', 'value']
required_fields = {'name', 'value'}
binaryDiv marked this conversation as resolved.
Show resolved Hide resolved

validator = UnitTestDictValidator()
assert validator.validate(input_dict) == expected_output
Expand Down Expand Up @@ -608,7 +608,7 @@ class UnitTestDictValidator(DictValidator):
'value': DecimalValidator(),
'optional_value': DecimalValidator(),
}
required_fields = ['name', 'value']
required_fields = {'name', 'value'}

validator = UnitTestDictValidator()

Expand Down
4 changes: 3 additions & 1 deletion tests/validators/discard_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Use of this source code is governed by an MIT-style license that can be found in the LICENSE file.
"""

from typing import Any, List

import pytest

from validataclass.helpers import UnsetValue
Expand All @@ -15,7 +17,7 @@ class DiscardValidatorTest:
Unit tests for the DiscardValidator.
"""

example_input_data = [
example_input_data: List[Any] = [
None,
True,
False,
Expand Down
Loading
Loading