Skip to content

Commit

Permalink
Merge pull request #121 from binary-butterfly/116-mypy-unit-test-fixes
Browse files Browse the repository at this point in the history
Add unit tests to mypy targets; fix typing in tests (part of #116)
  • Loading branch information
binaryDiv authored May 6, 2024
2 parents d0f0d3a + f8ab400 commit 6fa48d5
Show file tree
Hide file tree
Showing 16 changed files with 102 additions and 63 deletions.
15 changes: 14 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,23 @@ write_to = "src/validataclass/_version.py"
version_scheme = "post-release"

[tool.mypy]
files = "src/"
files = ["src/", "tests/"]
mypy_path = "src/"
explicit_package_bases = true

# Enable strict type checking
strict = true

# Ignore errors like `Module "validataclass.exceptions" does not explicitly export attribute "..."`
no_implicit_reexport = false

[[tool.mypy.overrides]]
module = 'tests.*'

# Don't enforce typed definitions in tests, this is a lot of unnecessary work (most parameters would be Any anyway).
allow_untyped_defs = true

# TODO: This is the main issue with mypy and validataclass right now.
# Defining dataclasses with validators using the @validataclass decorator, like `some_field: str = StringValidator()`,
# will cause "Incompatible types in assignment" errors. Until we find a way to solve this, ignore this error for now.
disable_error_code = "assignment"
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ testing =
coverage-conditional-plugin ~= 0.5
flake8 ~= 7.0
mypy ~= 1.9
types-python-dateutil
34 changes: 25 additions & 9 deletions tests/dataclasses/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,43 @@

import dataclasses
import sys
from typing import Any
from typing import Any, Dict, Type, TYPE_CHECKING

import pytest

from validataclass.dataclasses import Default
from validataclass.validators import T_Dataclass

# TODO: Replace type alias with dataclasses.Field[Any] when removing Python 3.9 support. (#15)
if TYPE_CHECKING:
T_DataclassField = dataclasses.Field[Any]
else:
T_DataclassField = dataclasses.Field


# Test helpers for dataclass tests

def assert_field_default(field: dataclasses.Field, default_value: Any):
def assert_field_default(field: T_DataclassField, default_value: Any) -> None:
"""
Asserts that a given (vali-)dataclass field has a specified default value.
"""
# Check regular dataclass defaults
assert (
(field.default == default_value and field.default_factory is dataclasses.MISSING)
or (field.default is dataclasses.MISSING and field.default_factory() == default_value)
)
# Check that the field has a regular dataclass default VALUE or default FACTORY, but not both
assert field.default is not dataclasses.MISSING or field.default_factory is not dataclasses.MISSING
assert field.default is dataclasses.MISSING or field.default_factory is dataclasses.MISSING

# Check regular dataclass default
if field.default_factory is not dataclasses.MISSING:
assert field.default_factory() == default_value
else:
assert field.default == default_value

# Check defaults in dataclass metadata
metadata_default = field.metadata.get('validator_default')
assert isinstance(metadata_default, Default)
assert metadata_default.get_value() == default_value


def assert_field_no_default(field: dataclasses.Field):
def assert_field_no_default(field: T_DataclassField) -> None:
"""
Asserts that a given (vali-)dataclass field has no default value.
"""
Expand All @@ -43,15 +54,20 @@ def assert_field_no_default(field: dataclasses.Field):

# For Python under 3.10, check that an exception raising default_factory is set
if sys.version_info < (3, 10):
assert field.default_factory is not dataclasses.MISSING
with pytest.raises(TypeError, match="required keyword-only argument"):
field.default_factory()
else:
assert field.default_factory is dataclasses.MISSING


def get_dataclass_fields(cls) -> dict:
def get_dataclass_fields(cls: Type[T_Dataclass]) -> Dict[str, T_DataclassField]:
"""
Returns a dictionary containing all fields of a given dataclass.
"""
# Make sure the class is really a dataclass
assert dataclasses.is_dataclass(cls) and isinstance(cls, type)

# Get fields and return them as a dictionary
fields_tuple = dataclasses.fields(cls)
return {field.name: field for field in fields_tuple}
3 changes: 2 additions & 1 deletion tests/dataclasses/defaults_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from copy import copy
from typing import Any, List

import pytest

Expand Down Expand Up @@ -42,7 +43,7 @@ def test_default_immutable_values(value, expected_repr):
@staticmethod
def test_default_list_deepcopied():
""" Test Default object with a list, make sure that it is deepcopied. """
default_list = []
default_list: List[Any] = []
default = Default(default_list)

# Check string representation and value
Expand Down
8 changes: 4 additions & 4 deletions tests/dataclasses/validataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import dataclasses
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import pytest

Expand All @@ -19,7 +19,7 @@
validataclass_field,
)
from validataclass.exceptions import DataclassValidatorFieldException
from validataclass.helpers import OptionalUnset, UnsetValue
from validataclass.helpers import OptionalUnset, UnsetValue, UnsetValueType
from validataclass.validators import (
DictValidator,
IntegerValidator,
Expand Down Expand Up @@ -279,7 +279,7 @@ class SubClass(BaseClass):
# Check type annotations
assert all(fields[field].type is int for field in ['required1', 'required2', 'optional2', 'optional4'])
assert all(fields[field].type is Optional[int] for field in ['required3', 'optional1'])
assert all(fields[field].type is OptionalUnset[int] for field in ['required4', 'optional3'])
assert all(fields[field].type is Union[int, UnsetValueType] for field in ['required4', 'optional3'])

# Check validators
assert all(type(field.metadata.get('validator')) is IntegerValidator for field in fields.values())
Expand Down Expand Up @@ -391,7 +391,7 @@ class BaseB:
field_both: str = StringValidator()

@validataclass
class SubClass(BaseB, BaseA):
class SubClass(BaseB, BaseA): # type: ignore[misc]
# Override the defaults to test that the decorator recognizes all fields of both base classes.
# If it does not, a "no validator for field X" error would be raised.
field_a: int = Default(42)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
"""

from decimal import Decimal
from typing import Any, List, Union
from typing import Any, List, Tuple, Union

from validataclass.validators import Validator


def unpack_params(*args) -> List[tuple]:
def unpack_params(*args: Any) -> List[Tuple[Any, ...]]:
"""
Returns a list containing tuples build from the arguments.
Expand Down Expand Up @@ -64,7 +64,7 @@ def unpack_params(*args) -> List[tuple]:
]
```
"""
unpacked = [tuple()]
unpacked: List[Tuple[Any, ...]] = [tuple()]

for arg in args:
if type(arg) is list:
Expand Down
37 changes: 19 additions & 18 deletions tests/validators/dataclass_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataclasses import dataclass, field
from decimal import Decimal
from typing import Optional, List
from typing import Any, Dict, List, Optional

import pytest

Expand Down Expand Up @@ -94,7 +94,7 @@ class UnitTestPostValidationDataclass:
start: int = IntegerValidator()
end: int = IntegerValidator()

def __post_validate__(self):
def __post_validate__(self) -> None:
if self.start > self.end:
raise ValidationError(code='invalid_range', reason='"start" must be smaller than or equal to "end".')

Expand All @@ -110,7 +110,7 @@ class UnitTestContextSensitiveDataclass:
name: str = UnitTestContextValidator()
value: Optional[int] = IntegerValidator(), Default(None)

def __post_validate__(self, *, value_required: bool = False):
def __post_validate__(self, *, value_required: bool = False) -> None:
if value_required and self.value is None:
raise DataclassPostValidationError(field_errors={
'value': RequiredValueError(reason='Value is required in this context.'),
Expand All @@ -124,7 +124,7 @@ class UnitTestContextSensitiveDataclassWithPosArgs(UnitTestContextSensitiveDatac
"""

# Same as UnitTestContextSensitiveDataclass, but with positional arguments
def __post_validate__(self, value_required: bool = False):
def __post_validate__(self, value_required: bool = False) -> None:
super().__post_validate__(value_required=value_required)


Expand All @@ -149,7 +149,7 @@ class UnitTestContextSensitiveDataclassWithVarKwargs:
ctx_b = None
extra_kwargs = None

def __post_validate__(self, *, ctx_a: str = '', ctx_b: str = '', **kwargs):
def __post_validate__(self, *, ctx_a: str = '', ctx_b: str = '', **kwargs: Any) -> None:
self.ctx_a = ctx_a
self.ctx_b = ctx_b
self.extra_kwargs = kwargs
Expand All @@ -166,7 +166,7 @@ class UnitTestPreValidateStaticMethodDataclass:
example_int: int = IntegerValidator()

@staticmethod
def __pre_validate__(input_data: dict) -> dict:
def __pre_validate__(input_data: Dict[Any, Any]) -> Dict[Any, Any]:
mapping = {
'exampleStr': 'example_str',
'exampleInt': 'example_int',
Expand All @@ -193,7 +193,7 @@ class UnitTestPreValidateClassMethodDataclass:
example_int: int = IntegerValidator()

@classmethod
def __pre_validate__(cls, input_data: dict) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any]) -> Dict[Any, Any]:
for from_key, to_key in cls.__key_mapping.items():
if from_key in input_data:
input_data[to_key] = input_data.pop(from_key)
Expand All @@ -212,7 +212,7 @@ class UnitTestPreValidateContextSensitiveDataclass:
target_field: int = IntegerValidator()

@classmethod
def __pre_validate__(cls, input_data: dict, *, source_field_name: str) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], *, source_field_name: str) -> Dict[Any, Any]:
if source_field_name in input_data:
return {'target_field': input_data[source_field_name]}
else:
Expand All @@ -236,7 +236,7 @@ class UnitTestPreValidateContextSensitiveVarKwargsDataclass:
example_int: int = IntegerValidator()

@classmethod
def __pre_validate__(cls, input_data: dict, **kwargs) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], **kwargs: Any) -> Dict[Any, Any]:
# Fill input_data with default values based on kwargs
for key, default_value in kwargs.items():
if key not in input_data:
Expand All @@ -252,7 +252,7 @@ class UnitTestInvalidPreValidateDataclass1:
""" Dataclass with invalid __pre_validate__ class method: Not enough arguments. """

@classmethod
def __pre_validate__(cls) -> dict:
def __pre_validate__(cls) -> Dict[Any, Any]:
return {}


Expand All @@ -261,7 +261,7 @@ class UnitTestInvalidPreValidateDataclass2:
""" Dataclass with invalid __pre_validate__ static method: Not enough arguments. """

@staticmethod
def __pre_validate__() -> dict:
def __pre_validate__() -> Dict[Any, Any]:
return {}


Expand All @@ -270,7 +270,7 @@ class UnitTestInvalidPreValidateDataclass3:
""" Dataclass with invalid __pre_validate__ class method: Too many positional arguments. """

@classmethod
def __pre_validate__(cls, input_data: dict, _extra_pos_argument) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], _extra_pos_argument: Any) -> Dict[Any, Any]:
return input_data


Expand All @@ -279,7 +279,7 @@ class UnitTestInvalidPreValidateDataclass4:
""" Dataclass with invalid __pre_validate__ static method: Too many positional arguments. """

@staticmethod
def __pre_validate__(input_data: dict, _extra_pos_argument) -> dict:
def __pre_validate__(input_data: Dict[Any, Any], _extra_pos_argument: Any) -> Dict[Any, Any]:
return input_data


Expand All @@ -288,7 +288,7 @@ class UnitTestInvalidPreValidateDataclass5:
""" Dataclass with invalid __pre_validate__ class method: Too many (variable) positional arguments. """

@classmethod
def __pre_validate__(cls, input_data: dict, *_args) -> dict:
def __pre_validate__(cls, input_data: Dict[Any, Any], *_args: Any) -> Dict[Any, Any]:
return input_data


Expand All @@ -297,7 +297,7 @@ class UnitTestInvalidPreValidateDataclass6:
""" Dataclass with invalid __pre_validate__ static method: Too many (variable) positional arguments. """

@staticmethod
def __pre_validate__(input_data: dict, *_args) -> dict:
def __pre_validate__(input_data: Dict[Any, Any], *_args: Any) -> Dict[Any, Any]:
return input_data


Expand Down Expand Up @@ -839,7 +839,8 @@ def test_dataclass_with_pre_validate_methods(
dataclass_cls,
):
""" Validate dataclasses with different __pre_validate__() methods (static and class methods). """
validator = DataclassValidator(dataclass_cls)
validator: DataclassValidator[Any] = DataclassValidator(dataclass_cls)

validated_data = validator.validate(input_data)

assert validated_data.example_str == expected_example_str
Expand Down Expand Up @@ -897,7 +898,7 @@ def test_dataclass_with_pre_validate_methods_invalid(
dataclass_cls,
):
""" Validate dataclasses with different __pre_validate__() methods and invalid input. """
validator = DataclassValidator(dataclass_cls)
validator: DataclassValidator[Any] = DataclassValidator(dataclass_cls)

with pytest.raises(DictFieldsValidationError) as exception_info:
validator.validate(input_data)
Expand Down Expand Up @@ -1048,7 +1049,7 @@ def test_dataclass_with_context_sensitive_pre_validate_with_var_kwargs_invalid(
)
def test_dataclass_with_invalid_forms_of_pre_validate(dataclass_cls):
""" Test error handling for dataclasses with __pre_validate__() methods with an invalid method signature. """
validator = DataclassValidator(dataclass_cls)
validator: DataclassValidator[Any] = DataclassValidator(dataclass_cls)

with pytest.raises(DataclassInvalidPreValidateSignatureException, match=PRE_VALIDATE_INVALID_SIGNATURE_ERROR):
validator.validate({})
Expand Down
12 changes: 8 additions & 4 deletions tests/validators/datetime_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,14 @@ def test_with_local_timezone_valid(input_string, local_timezone, expected_dateti
assert validated_dt == expected_datetime

# Check timezone of datetimes by comparing their offset to UTC
assert (
validated_dt.tzinfo == expected_datetime.tzinfo
or validated_dt.tzinfo.utcoffset(validated_dt) == expected_datetime.tzinfo.utcoffset(expected_datetime)
)
if expected_datetime.tzinfo is None:
assert validated_dt.tzinfo is None
else:
assert validated_dt.tzinfo is not None
assert (
validated_dt.tzinfo == expected_datetime.tzinfo
or validated_dt.tzinfo.utcoffset(validated_dt) == expected_datetime.tzinfo.utcoffset(expected_datetime)
)

# Test DateTimeValidator with target_timezone parameter

Expand Down
4 changes: 2 additions & 2 deletions tests/validators/dict_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ class UnitTestDictValidator(DictValidator):
'value': DecimalValidator(),
'optional_value': DecimalValidator(),
}
required_fields = ['name', 'value']
required_fields = {'name', 'value'}

validator = UnitTestDictValidator()
assert validator.validate(input_dict) == expected_output
Expand Down Expand Up @@ -608,7 +608,7 @@ class UnitTestDictValidator(DictValidator):
'value': DecimalValidator(),
'optional_value': DecimalValidator(),
}
required_fields = ['name', 'value']
required_fields = {'name', 'value'}

validator = UnitTestDictValidator()

Expand Down
4 changes: 3 additions & 1 deletion tests/validators/discard_validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Use of this source code is governed by an MIT-style license that can be found in the LICENSE file.
"""

from typing import Any, List

import pytest

from validataclass.helpers import UnsetValue
Expand All @@ -15,7 +17,7 @@ class DiscardValidatorTest:
Unit tests for the DiscardValidator.
"""

example_input_data = [
example_input_data: List[Any] = [
None,
True,
False,
Expand Down
Loading

0 comments on commit 6fa48d5

Please sign in to comment.