Skip to content

Commit

Permalink
Improve KeyTransform initializer and types
Browse files Browse the repository at this point in the history
  • Loading branch information
adamchainz committed Aug 28, 2022
1 parent 5019529 commit fcc0e0c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
30 changes: 16 additions & 14 deletions src/django_mysql/models/fields/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,7 @@ class KeyTransform(Transform):

SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys()))

TYPE_MAP: dict[str, type[Field] | Field] = {
"BINARY": DynamicField,
TYPE_MAP: dict[str, Field[Any, Any]] = {
"CHAR": TextField(),
"DATE": DateField(),
"DATETIME": DateTimeField(),
Expand All @@ -322,23 +321,26 @@ def __init__(
self,
key_name: str,
data_type: str,
*args: Any,
*expressions: Any,
subspec: SpecDict | None = None,
**kwargs: Any,
output_field: Field[Any, Any] | None = None,
**extra: Any,
) -> None:
super().__init__(*args, **kwargs)
self.key_name = key_name
self.data_type = data_type

try:
output_field = self.TYPE_MAP[data_type]
except KeyError: # pragma: no cover
raise ValueError(f"Invalid data_type '{data_type}'")
if output_field is not None:
raise ValueError("Cannot set output_field for KeyTransform")

if data_type == "BINARY":
self.output_field = output_field(spec=subspec)
output_field = DynamicField(spec=subspec)
else:
self.output_field = output_field
try:
output_field = self.TYPE_MAP[data_type]
except KeyError:
raise ValueError(f"Invalid data_type {data_type!r}")

super().__init__(*expressions, output_field=output_field, **extra)

self.key_name = key_name
self.data_type = data_type

def as_sql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
Expand Down
13 changes: 13 additions & 0 deletions tests/testapp/test_dynamicfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.test.utils import isolate_apps

from django_mysql.models import DynamicField
from django_mysql.models.fields.dynamic import KeyTransform
from tests.testapp.models import DynamicModel, SpeclessDynamicModel


Expand Down Expand Up @@ -159,6 +160,18 @@ def test_non_existent_transform(self):
def test_has_key(self):
assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3]

def test_key_transform_initialize_output_field(self):
with pytest.raises(ValueError) as excinfo:
KeyTransform("x", "y", output_field=CharField())

assert str(excinfo.value) == "Cannot set output_field for KeyTransform"

def test_key_transform_initialize_bad_type(self):
with pytest.raises(ValueError) as excinfo:
KeyTransform("x", "unknown")

assert str(excinfo.value) == "Invalid data_type 'unknown'"

def test_key_transform_datey(self):
assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [
self.objs[4]
Expand Down

0 comments on commit fcc0e0c

Please sign in to comment.