diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index 287b7ae90..6318bbd33 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -308,8 +308,7 @@ class KeyTransform(Transform): SPEC_MAP_NAMES = ", ".join(sorted(x.__name__ for x in SPEC_MAP.keys())) - TYPE_MAP: dict[str, type[Field] | Field] = { - "BINARY": DynamicField, + TYPE_MAP: dict[str, Field[Any, Any]] = { "CHAR": TextField(), "DATE": DateField(), "DATETIME": DateTimeField(), @@ -322,23 +321,26 @@ def __init__( self, key_name: str, data_type: str, - *args: Any, + *expressions: Any, subspec: SpecDict | None = None, - **kwargs: Any, + output_field: Field[Any, Any] | None = None, + **extra: Any, ) -> None: - super().__init__(*args, **kwargs) - self.key_name = key_name - self.data_type = data_type - - try: - output_field = self.TYPE_MAP[data_type] - except KeyError: # pragma: no cover - raise ValueError(f"Invalid data_type '{data_type}'") + if output_field is not None: + raise ValueError("Cannot set output_field for KeyTransform") if data_type == "BINARY": - self.output_field = output_field(spec=subspec) + output_field = DynamicField(spec=subspec) else: - self.output_field = output_field + try: + output_field = self.TYPE_MAP[data_type] + except KeyError: + raise ValueError(f"Invalid data_type {data_type!r}") + + super().__init__(*expressions, output_field=output_field, **extra) + + self.key_name = key_name + self.data_type = data_type def as_sql( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index 95c885b93..d1bdc6a45 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -15,6 +15,7 @@ from django.test.utils import isolate_apps from django_mysql.models import DynamicField +from django_mysql.models.fields.dynamic import KeyTransform from tests.testapp.models import DynamicModel, SpeclessDynamicModel @@ -159,6 +160,18 @@ def test_non_existent_transform(self): def test_has_key(self): assert list(DynamicModel.objects.filter(attrs__has_key="c")) == self.objs[1:3] + def test_key_transform_initialize_output_field(self): + with pytest.raises(ValueError) as excinfo: + KeyTransform("x", "y", output_field=CharField()) + + assert str(excinfo.value) == "Cannot set output_field for KeyTransform" + + def test_key_transform_initialize_bad_type(self): + with pytest.raises(ValueError) as excinfo: + KeyTransform("x", "unknown") + + assert str(excinfo.value) == "Invalid data_type 'unknown'" + def test_key_transform_datey(self): assert list(DynamicModel.objects.filter(attrs__datey=dt.date(2001, 1, 4))) == [ self.objs[4]