From 1f3f315fd073b82789a60105b754f278240b9f1f Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 15 Aug 2022 10:14:13 +0100 Subject: [PATCH 01/29] Use Mypy's strict mode --- .pre-commit-config.yaml | 4 +++ pyproject.toml | 11 ++++--- src/django_mysql/compat.py | 2 ++ src/django_mysql/models/__init__.py | 30 ++++++++++++++++++- src/django_mysql/models/fields/__init__.py | 4 +-- src/django_mysql/models/fields/dynamic.py | 10 +++---- src/django_mysql/operations.py | 34 +++++++++++----------- src/django_mysql/status.py | 6 ++-- 8 files changed, 67 insertions(+), 34 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ad875596..dd5bffd6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -73,3 +73,7 @@ repos: rev: v1.11.2 hooks: - id: mypy + additional_dependencies: + - django-stubs==5.0.4 + - mysqlclient + - pytest==8.3.2 diff --git a/pyproject.toml b/pyproject.toml index 1fc4b11f..85e24dcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,20 +92,19 @@ enable_error_code = [ "redundant-expr", "truthy-bool", ] -check_untyped_defs = true -disallow_any_generics = true -disallow_incomplete_defs = true -disallow_untyped_defs = true mypy_path = "src/" namespace_packages = false -no_implicit_optional = true +plugins = [ "mypy_django_plugin.main" ] +strict = true warn_unreachable = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = "tests.*" allow_untyped_defs = true +[tool.django-stubs] +django_settings_module = "tests.settings" + [tool.rstcheck] ignore_directives = [ "automodule", diff --git a/src/django_mysql/compat.py b/src/django_mysql/compat.py index a072e556..97a0b089 100644 --- a/src/django_mysql/compat.py +++ b/src/django_mysql/compat.py @@ -6,6 +6,8 @@ from typing import TypeVar from typing import cast +__all__ = ("cache",) + if sys.version_info >= (3, 9): from functools import cache else: diff --git a/src/django_mysql/models/__init__.py b/src/django_mysql/models/__init__.py index 5b074b13..36ea1f52 100644 --- a/src/django_mysql/models/__init__.py +++ b/src/django_mysql/models/__init__.py @@ -4,7 +4,7 @@ from django_mysql.models.aggregates import BitOr from django_mysql.models.aggregates import BitXor from django_mysql.models.aggregates import GroupConcat -from django_mysql.models.base import Model # noqa +from django_mysql.models.base import Model from django_mysql.models.expressions import ListF from django_mysql.models.expressions import SetF from django_mysql.models.fields import Bit1BooleanField @@ -25,3 +25,31 @@ from django_mysql.models.query import SmartIterator from django_mysql.models.query import add_QuerySetMixin from django_mysql.models.query import pt_visual_explain + +__all__ = ( + "add_QuerySetMixin", + "ApproximateInt", + "Bit1BooleanField", + "BitAnd", + "BitOr", + "BitXor", + "DynamicField", + "EnumField", + "FixedCharField", + "GroupConcat", + "ListCharField", + "ListF", + "ListTextField", + "Model", + "NullBit1BooleanField", + "pt_visual_explain", + "QuerySet", + "QuerySetMixin", + "SetCharField", + "SetF", + "SetTextField", + "SizedBinaryField", + "SizedTextField", + "SmartChunkedIterator", + "SmartIterator", +) diff --git a/src/django_mysql/models/fields/__init__.py b/src/django_mysql/models/fields/__init__.py index f413947d..f2d0e011 100644 --- a/src/django_mysql/models/fields/__init__.py +++ b/src/django_mysql/models/fields/__init__.py @@ -12,7 +12,7 @@ from django_mysql.models.fields.sizes import SizedBinaryField from django_mysql.models.fields.sizes import SizedTextField -__all__ = [ +__all__ = ( "Bit1BooleanField", "DynamicField", "EnumField", @@ -24,4 +24,4 @@ "SetTextField", "SizedBinaryField", "SizedTextField", -] +) diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index db6891b3..67dcd1d0 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -25,9 +25,9 @@ from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ -from django_mysql.checks import mysql_connections from django_mysql.models.lookups import DynColHasKey from django_mysql.typing import DeconstructResult +from django_mysql.utils import mysql_connections try: import mariadb_dyncol @@ -89,7 +89,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: return errors def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] if mariadb_dyncol is None: errors.append( checks.Error( @@ -102,7 +102,7 @@ def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]: return errors def _check_mariadb_version(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] any_conn_works = any( (conn.vendor == "mysql" and conn.mysql_is_mariadb) @@ -121,7 +121,7 @@ def _check_mariadb_version(self) -> list[checks.CheckMessage]: return errors def _check_character_set(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] conn = None for _alias, check_conn in mysql_connections(): @@ -153,7 +153,7 @@ def _check_character_set(self) -> list[checks.CheckMessage]: def _check_spec_recursively( self, spec: Any, path: str = "" ) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] if not isinstance(spec, dict): errors.append( diff --git a/src/django_mysql/operations.py b/src/django_mysql/operations.py index ea1ff5b4..6b853dcc 100644 --- a/src/django_mysql/operations.py +++ b/src/django_mysql/operations.py @@ -2,7 +2,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.migrations.operations.base import Operation -from django.db.migrations.state import ModelState +from django.db.migrations.state import ProjectState from django.utils.functional import cached_property @@ -15,15 +15,15 @@ def __init__(self, name: str, soname: str) -> None: self.name = name self.soname = soname - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass # pragma: no cover def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: if not self.plugin_installed(schema_editor): schema_editor.execute( @@ -34,8 +34,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: if self.plugin_installed(schema_editor): schema_editor.execute("UNINSTALL PLUGIN %s" % self.name) @@ -63,15 +63,15 @@ class InstallSOName(Operation): def __init__(self, soname: str) -> None: self.soname = soname - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass # pragma: no cover def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: schema_editor.execute("INSTALL SONAME %s", (self.soname,)) @@ -79,8 +79,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: schema_editor.execute("UNINSTALL SONAME %s", (self.soname,)) @@ -100,15 +100,15 @@ def __init__( def reversible(self) -> bool: return self.from_engine is not None - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_state: ModelState, - to_state: ModelState, + from_state: ProjectState, + to_state: ProjectState, ) -> None: self._change_engine(app_label, schema_editor, to_state, engine=self.engine) @@ -116,8 +116,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_state: ModelState, - to_state: ModelState, + from_state: ProjectState, + to_state: ProjectState, ) -> None: if self.from_engine is None: raise NotImplementedError("You cannot reverse this operation") @@ -128,7 +128,7 @@ def _change_engine( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - to_state: ModelState, + to_state: ProjectState, engine: str, ) -> None: new_model = to_state.apps.get_model(app_label, self.name) diff --git a/src/django_mysql/status.py b/src/django_mysql/status.py index fea900ae..ecf68a56 100644 --- a/src/django_mysql/status.py +++ b/src/django_mysql/status.py @@ -6,7 +6,6 @@ from django.db import connections from django.db.backends.utils import CursorWrapper from django.db.utils import DEFAULT_DB_ALIAS -from django.utils.functional import SimpleLazyObject from django_mysql.exceptions import TimeoutError @@ -16,6 +15,7 @@ class BaseStatus: Base class for the status classes """ + __slots__ = ("db",) query = "" def __init__(self, using: str | None = None) -> None: @@ -127,5 +127,5 @@ class SessionStatus(BaseStatus): query = "SHOW SESSION STATUS" -global_status = SimpleLazyObject(GlobalStatus) -session_status = SimpleLazyObject(SessionStatus) +global_status = GlobalStatus() +session_status = SessionStatus() From 7ee3a9015acbe7275a2e7ea8b79aca1a15e51491 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:19:23 +0100 Subject: [PATCH 02/29] Finish hints for cache module --- src/django_mysql/cache.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/django_mysql/cache.py b/src/django_mysql/cache.py index 6cbcf146..230b5e78 100644 --- a/src/django_mysql/cache.py +++ b/src/django_mysql/cache.py @@ -19,6 +19,7 @@ from django.core.cache.backends.base import default_key_func from django.db import connections from django.db import router +from django.db.models import Model from django.utils.encoding import force_bytes from django.utils.module_loading import import_string @@ -62,7 +63,9 @@ def __init__(self, table: str, params: dict[str, Any]) -> None: super().__init__(params) self._table = table - class CacheEntry: + CacheEntry: type[Model] # force Mypy to accept duck typing + + class CacheEntry: # type: ignore [no-redef] _meta = Options(table) self.cache_model_class = CacheEntry @@ -183,7 +186,7 @@ def get_many( self, keys: Iterable[str], version: int | None = None ) -> dict[str, Any]: made_key_to_key = {self.make_key(key, version=version): key for key in keys} - made_keys = list(made_key_to_key.keys()) + made_keys: list[Any] = list(made_key_to_key.keys()) for key in made_keys: self.validate_key(key) @@ -266,7 +269,7 @@ def _base_set( return True else: # mode = 'add' # Use a special code in the add query for "did insert" - insert_id = cursor.lastrowid + insert_id: int = cursor.lastrowid return insert_id != 444 _set_many_query = collapse_spaces( @@ -416,7 +419,8 @@ def _base_delta( raise ValueError("Key '%s' not found, or not an integer" % key) # New value stored in insert_id - return cursor.lastrowid + result: int = cursor.lastrowid + return result # Looks a bit tangled to turn the blob back into an int for updating, but # it works. Stores the new value for insert_id() with LAST_INSERT_ID @@ -448,7 +452,7 @@ def touch( db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) with connections[db].cursor() as cursor: - affected_rows = cursor.execute( + affected_rows: int = cursor.execute( self._touch_query.format(table=table), [exp, key, self._now()] ) return affected_rows > 0 @@ -612,18 +616,20 @@ def delete_with_prefix(self, prefix: str, version: int | None = None) -> int: prefix = self.make_key(prefix + "%", version=version) with connections[db].cursor() as cursor: - return cursor.execute( + result: int = cursor.execute( """DELETE FROM {table} WHERE cache_key LIKE %s""".format( table=table ), (prefix,), ) + return result def cull(self) -> int: db = router.db_for_write(self.cache_model_class) table = connections[db].ops.quote_name(self._table) + num_deleted: int with connections[db].cursor() as cursor: # First, try just deleting expired keys num_deleted = cursor.execute( From 0072a12c165d34771a02ffe9e77b5154f337044b Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:19:38 +0100 Subject: [PATCH 03/29] Finish hints for utils --- src/django_mysql/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/django_mysql/utils.py b/src/django_mysql/utils.py index 6a98540f..9b4f9afd 100644 --- a/src/django_mysql/utils.py +++ b/src/django_mysql/utils.py @@ -162,7 +162,7 @@ def index_name(model: Model, *field_names: str, using: str = DEFAULT_DB_ALIAS) - ), (model._meta.db_table,) + column_names, ) - indexes = defaultdict(list) + indexes: defaultdict[str, list[str]] = defaultdict(list) for index_name, _, column_name in cursor.fetchall(): indexes[index_name].append(column_name) @@ -177,7 +177,7 @@ def get_list_sql(sequence: list[str] | tuple[str, ...]) -> str: return "({})".format(",".join("%s" for x in sequence)) -def mysql_connections() -> Generator[BaseDatabaseWrapper]: +def mysql_connections() -> Generator[[str, BaseDatabaseWrapper]]: conn_names = [DEFAULT_DB_ALIAS] + list(set(connections) - {DEFAULT_DB_ALIAS}) for alias in conn_names: connection = connections[alias] From 273b76c9be4053b529f2389f6b9c4967055e60a5 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:19:46 +0100 Subject: [PATCH 04/29] Finish hints for operations --- src/django_mysql/operations.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/django_mysql/operations.py b/src/django_mysql/operations.py index 6b853dcc..c5003bd5 100644 --- a/src/django_mysql/operations.py +++ b/src/django_mysql/operations.py @@ -48,7 +48,7 @@ def plugin_installed(self, schema_editor: BaseDatabaseSchemaEditor) -> bool: WHERE PLUGIN_NAME LIKE %s""", (self.name,), ) - count = cursor.fetchone()[0] + count: int = cursor.fetchone()[0] return count > 0 def describe(self) -> str: @@ -96,9 +96,7 @@ def __init__( self.engine = to_engine self.from_engine = from_engine - @property - def reversible(self) -> bool: - return self.from_engine is not None + self.reversible = self.from_engine is not None def state_forwards(self, app_label: str, state: ProjectState) -> None: pass From 7d8beefdbb0e1ec2d970359b9addf2c3f5d8bb2c Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:22:25 +0100 Subject: [PATCH 05/29] Finish hints for locks --- src/django_mysql/locks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/django_mysql/locks.py b/src/django_mysql/locks.py index 9470df1f..39d8b0f1 100644 --- a/src/django_mysql/locks.py +++ b/src/django_mysql/locks.py @@ -6,6 +6,7 @@ from django.db import connections from django.db.backends.utils import CursorWrapper from django.db.models import Model +from django.db.transaction import Atomic from django.db.transaction import TransactionManagementError from django.db.transaction import atomic from django.db.utils import DEFAULT_DB_ALIAS @@ -77,7 +78,8 @@ def is_held(self) -> bool: def holding_connection_id(self) -> int | None: with self.get_cursor() as cursor: cursor.execute("SELECT IS_USED_LOCK(%s)", (self.name,)) - return cursor.fetchone()[0] + result: int | None = cursor.fetchone()[0] + return result @classmethod def held_with_prefix( @@ -108,6 +110,7 @@ def __init__( self.read: list[str] = self._process_names(read) self.write: list[str] = self._process_names(write) self.db = DEFAULT_DB_ALIAS if using is None else using + self._atomic: Atomic | None = None def _process_names(self, names: list[str | type[Model]] | None) -> list[str]: """ @@ -170,6 +173,7 @@ def release( ) -> None: connection = connections[self.db] with connection.cursor() as cursor: + assert self._atomic is not None self._atomic.__exit__(exc_type, exc_value, exc_traceback) self._atomic = None cursor.execute("UNLOCK TABLES") From b99b21ecdb00d6fdf92018c7539a1844f9d9f7d7 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:23:47 +0100 Subject: [PATCH 06/29] Finish hints for management commands --- src/django_mysql/management/commands/cull_mysql_caches.py | 7 ++++--- src/django_mysql/management/commands/dbparams.py | 8 +++++--- .../management/commands/mysql_cache_migration.py | 3 ++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/django_mysql/management/commands/cull_mysql_caches.py b/src/django_mysql/management/commands/cull_mysql_caches.py index ae8d6d81..27e2683a 100644 --- a/src/django_mysql/management/commands/cull_mysql_caches.py +++ b/src/django_mysql/management/commands/cull_mysql_caches.py @@ -31,9 +31,10 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: help="Specify the cache alias(es) to cull.", ) - def handle( - self, *args: Any, verbosity: int, aliases: list[str], **options: Any - ) -> None: + def handle(self, *args: Any, **options: Any) -> None: + verbosity: int = options["verbosity"] + aliases: list[str] = options["aliases"] + if not aliases: aliases = list(settings.CACHES) diff --git a/src/django_mysql/management/commands/dbparams.py b/src/django_mysql/management/commands/dbparams.py index fa6791f9..4de2c3d7 100644 --- a/src/django_mysql/management/commands/dbparams.py +++ b/src/django_mysql/management/commands/dbparams.py @@ -51,9 +51,11 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: "pt-online-schema-change $(./manage.py dbparams --dsn)", ) - def handle( - self, *args: Any, alias: str, show_mysql: bool, show_dsn: bool, **options: Any - ) -> None: + def handle(self, *args: Any, **options: Any) -> None: + alias: str = options["alias"] + show_mysql: bool = options["show_mysql"] + show_dsn: bool = options["show_dsn"] + try: connection = connections[alias] except ConnectionDoesNotExist: diff --git a/src/django_mysql/management/commands/mysql_cache_migration.py b/src/django_mysql/management/commands/mysql_cache_migration.py index c73008cf..9222c23d 100644 --- a/src/django_mysql/management/commands/mysql_cache_migration.py +++ b/src/django_mysql/management/commands/mysql_cache_migration.py @@ -30,7 +30,8 @@ def add_arguments(self, parser: argparse.ArgumentParser) -> None: help="Specify the cache alias(es) to create migrations for.", ) - def handle(self, *args: Any, aliases: list[str], **options: Any) -> None: + def handle(self, *args: Any, **options: Any) -> None: + aliases: list[str] = options["aliases"] if not aliases: aliases = list(settings.CACHES) From 447dcc3de10630077b02d5a307f9a4c8f3db40ac Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:28:55 +0100 Subject: [PATCH 07/29] Finish hints for status --- src/django_mysql/status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/status.py b/src/django_mysql/status.py index ecf68a56..9fd76db2 100644 --- a/src/django_mysql/status.py +++ b/src/django_mysql/status.py @@ -57,7 +57,7 @@ def get_many(self, names: Iterable[str]) -> dict[str, int | float | bool | str]: ] ) - cursor.execute(query, names) + cursor.execute(query, tuple(names)) return {name: self._cast(value) for name, value in cursor.fetchall()} From a885a6b235ef2d811e5fa7727ec85c9d19dff7cd Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 14:33:24 +0100 Subject: [PATCH 08/29] Stricter signature for contribute_to_class --- src/django_mysql/models/fields/lists.py | 6 ++++-- src/django_mysql/models/fields/sets.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 5d25f01f..90f90bc7 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -163,8 +163,10 @@ def formfield(self, **kwargs: Any) -> FormField: defaults.update(kwargs) return super().formfield(**defaults) - def contribute_to_class(self, cls: type[Model], name: str, **kwargs: Any) -> None: - super().contribute_to_class(cls, name, **kwargs) + def contribute_to_class( + self, cls: type[Model], name: str, private_only: bool = False + ) -> None: + super().contribute_to_class(cls, name, private_only=private_only) self.base_field.model = cls diff --git a/src/django_mysql/models/fields/sets.py b/src/django_mysql/models/fields/sets.py index b7c0d6d4..344078ff 100644 --- a/src/django_mysql/models/fields/sets.py +++ b/src/django_mysql/models/fields/sets.py @@ -145,8 +145,10 @@ def formfield(self, **kwargs: Any) -> FormField: defaults.update(kwargs) return super().formfield(**defaults) - def contribute_to_class(self, cls: type[Model], name: str, **kwargs: Any) -> None: - super().contribute_to_class(cls, name, **kwargs) + def contribute_to_class( + self, cls: type[Model], name: str, private_only: bool = False + ) -> None: + super().contribute_to_class(cls, name, private_only=private_only) self.base_field.model = cls From 586531abf9f8b4272d9f4def8f415ad23dc9b3c2 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 15:12:49 +0100 Subject: [PATCH 09/29] Fix JSONExtract output_field arg --- src/django_mysql/models/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/models/functions.py b/src/django_mysql/models/functions.py index 62623f32..556f07be 100644 --- a/src/django_mysql/models/functions.py +++ b/src/django_mysql/models/functions.py @@ -182,7 +182,7 @@ def __init__( self, expression: ExpressionArgument, *paths: ExpressionArgument, - output_field: type[DjangoField] | None = None, + output_field: DjangoField | None = None, ) -> None: exprs = [expression] for path in paths: From 76393677a4d36416da074a86c4885cfd1d8e1e98 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 15:16:24 +0100 Subject: [PATCH 10/29] Fix types for IndexLookup.as_sql() --- src/django_mysql/models/fields/lists.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 90f90bc7..828a6f51 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -2,7 +2,6 @@ from typing import Any from typing import Callable -from typing import Iterable from typing import cast from django.core import checks @@ -14,6 +13,7 @@ from django.db.models import Model from django.db.models import TextField from django.db.models.expressions import BaseExpression +from django.db.models.sql.compiler import SQLCompiler from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ @@ -230,10 +230,10 @@ def __init__(self, index: int, *args: Any, **kwargs: Any) -> None: self.index = index def as_sql( - self, qn: Callable[[str], str], connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper + ) -> tuple[str, list[str | int]]: + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) # Put rhs on the left since that's the order FIND_IN_SET uses return f"(FIND_IN_SET({rhs}, {lhs}) = {self.index})", params From 107e55b76df07ca9ca6abb98e32bfa7d1216d6c1 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 22 Aug 2022 15:18:15 +0100 Subject: [PATCH 11/29] Fix type for GroupConcat arg 1 --- src/django_mysql/models/aggregates.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index 6370831c..58170761 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -5,7 +5,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Aggregate from django.db.models import CharField -from django.db.models import Expression from django.db.models.sql.compiler import SQLCompiler @@ -29,7 +28,7 @@ class GroupConcat(Aggregate): def __init__( self, - expression: Expression, + expression: Any, distinct: bool = False, separator: str | None = None, ordering: str | None = None, From e4449c73a9ce5d0cf03e7dd4a18789fb8141dc1e Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Fri, 26 Aug 2022 12:19:33 +0100 Subject: [PATCH 12/29] Fix some as_sql() signatures --- src/django_mysql/models/lookups.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/django_mysql/models/lookups.py b/src/django_mysql/models/lookups.py index 5b67b842..8eb922af 100644 --- a/src/django_mysql/models/lookups.py +++ b/src/django_mysql/models/lookups.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any -from typing import Callable from typing import Iterable from django.db.backends.base.base import BaseDatabaseWrapper @@ -24,11 +23,11 @@ class SoundsLike(Lookup): def as_sql( self, - qn: Callable[[str], str], + compiler: SQLCompiler, connection: BaseDatabaseWrapper, ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) return f"{lhs} SOUNDS LIKE {rhs}", params @@ -66,10 +65,10 @@ def get_prep_lookup(self) -> Any: return super().get_prep_lookup() def as_sql( - self, qn: Callable[[str], str], connection: BaseDatabaseWrapper + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) # Put rhs (and params) on the left since that's the order FIND_IN_SET uses params = tuple(rhs_params) + tuple(lhs_params) return f"FIND_IN_SET({rhs}, {lhs})", params @@ -86,9 +85,9 @@ class DynColHasKey(Lookup): lookup_name = "has_key" def as_sql( - self, qn: Callable[[str], str], connection: BaseDatabaseWrapper + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper ) -> tuple[str, Iterable[Any]]: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = tuple(lhs_params) + tuple(rhs_params) return f"COLUMN_EXISTS({lhs}, {rhs})", params From a6f7929391c5ab89e96bd805591fee3c7c5ffe26 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Sat, 27 Aug 2022 10:13:30 +0100 Subject: [PATCH 13/29] assert --- src/django_mysql/models/fields/sets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/django_mysql/models/fields/sets.py b/src/django_mysql/models/fields/sets.py index 344078ff..b050d1f3 100644 --- a/src/django_mysql/models/fields/sets.py +++ b/src/django_mysql/models/fields/sets.py @@ -169,6 +169,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: and isinstance(self.base_field, CharField) and self.size ): + assert self.base_field.max_length is not None max_size = ( # The chars used (self.size * (self.base_field.max_length)) From facdc6221da5e82b96d352dfe8931f491ce93a06 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Sun, 28 Aug 2022 18:53:32 +0100 Subject: [PATCH 14/29] mute isinstance --- src/django_mysql/models/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/models/functions.py b/src/django_mysql/models/functions.py index 556f07be..be697d59 100644 --- a/src/django_mysql/models/functions.py +++ b/src/django_mysql/models/functions.py @@ -403,7 +403,7 @@ def __init__( if not hasattr(name, "resolve_expression"): name = Value(name) - if isinstance(value, dict): + if isinstance(value, dict): # type: ignore [unreachable] raise ValueError("ColumnAdd with nested values is not supported") if not hasattr(value, "resolve_expression"): value = Value(value) From 7aeb183782b0900a06e14695e7330906a99bc6b9 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Tue, 30 Aug 2022 10:57:13 +0100 Subject: [PATCH 15/29] more fixes --- src/django_mysql/models/fields/lists.py | 5 ++--- src/django_mysql/models/query.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 828a6f51..8888be6d 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -14,7 +14,6 @@ from django.db.models import TextField from django.db.models.expressions import BaseExpression from django.db.models.sql.compiler import SQLCompiler -from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ from django_mysql.forms import SimpleListField @@ -74,7 +73,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: return errors @property - def description(self) -> Any: + def description(self) -> str: return _("List of %(base_description)s") % { "base_description": self.base_field.description } @@ -154,7 +153,7 @@ def value_to_string(self, obj: Any) -> str: vals = self.value_from_object(obj) return self.get_prep_value(vals) - def formfield(self, **kwargs: Any) -> FormField: + def formfield(self, **kwargs: Any) -> Any: defaults = { "form_class": SimpleListField, "base_field": self.base_field.formfield(), diff --git a/src/django_mysql/models/query.py b/src/django_mysql/models/query.py index 61a16837..bcb67854 100644 --- a/src/django_mysql/models/query.py +++ b/src/django_mysql/models/query.py @@ -698,7 +698,7 @@ def approx_count(queryset: models.QuerySet) -> int: ) # N.B. when we support more complex QuerySets they should be estimated # with 'EXPLAIN SELECT' - approx_count = cursor.fetchone()[0] + approx_count: int = cursor.fetchone()[0] return approx_count From 8758ed10b29cd8b9d56b1fa6f495062aa3c0bfd5 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Tue, 18 Oct 2022 19:32:29 +0100 Subject: [PATCH 16/29] Improve function types --- src/django_mysql/models/functions.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/django_mysql/models/functions.py b/src/django_mysql/models/functions.py index be697d59..95acd951 100644 --- a/src/django_mysql/models/functions.py +++ b/src/django_mysql/models/functions.py @@ -169,7 +169,8 @@ def get(cls, using: str = DEFAULT_DB_ALIAS) -> int: # database connections in Django, and the reason was not clear with connections[using].cursor() as cursor: cursor.execute("SELECT LAST_INSERT_ID()") - return cursor.fetchone()[0] + id_: int = cursor.fetchone()[0] + return id_ # JSON Functions @@ -258,7 +259,11 @@ def as_sql( if connection.vendor != "mysql": # pragma: no cover raise AssertionError("JSONValue only supports MySQL/MariaDB") json_string = json.dumps(self._data, allow_nan=False) - if connection.vendor == "mysql" and connection.mysql_is_mariadb: + if ( + connection.vendor == "mysql" + # type narrowed by vendor check + and connection.mysql_is_mariadb # type: ignore [attr-defined] + ): # MariaDB doesn't support explicit cast to JSON. return "JSON_EXTRACT(%s, '$')", (json_string,) else: @@ -270,7 +275,7 @@ def __init__( self, expression: ExpressionArgument, data: dict[ - str, + ExpressionArgument, ( ExpressionArgument | None @@ -288,12 +293,12 @@ def __init__( exprs = [expression] for path, value in data.items(): - if not hasattr(path, "resolve_expression"): + if not isinstance(path, Expression): path = Value(path) exprs.append(path) - if not hasattr(value, "resolve_expression"): + if not isinstance(value, Expression): value = JSONValue(value) exprs.append(value) @@ -392,7 +397,7 @@ def __init__( self, expression: ExpressionArgument, to_add: dict[ - str, + ExpressionArgument, ExpressionArgument | float | int | dt.date | dt.time | dt.datetime, ], ) -> None: @@ -400,12 +405,12 @@ def __init__( expressions = [expression] for name, value in to_add.items(): - if not hasattr(name, "resolve_expression"): + if not isinstance(name, Expression): name = Value(name) - if isinstance(value, dict): # type: ignore [unreachable] + if isinstance(value, dict): raise ValueError("ColumnAdd with nested values is not supported") - if not hasattr(value, "resolve_expression"): + if not isinstance(value, Expression): value = Value(value) expressions.extend((name, value)) From 42f84b0bfcb1823e954aef2c567301bc6a085827 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Tue, 18 Oct 2022 19:33:19 +0100 Subject: [PATCH 17/29] Add cast --- tests/testapp/test_dynamicfield.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index 7beaecf2..3684f1d9 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -2,6 +2,7 @@ import datetime as dt import json +from typing import cast from unittest import SkipTest from unittest import mock @@ -12,6 +13,7 @@ from django.db import connection from django.db import connections from django.db import models +from django.db.backends.mysql.base import DatabaseWrapper from django.db.migrations.writer import MigrationWriter from django.db.models import CharField from django.db.models import Transform @@ -27,7 +29,7 @@ class DynColTestCase(TestCase): @classmethod def setUpClass(cls): - if not connection.mysql_is_mariadb: + if not cast(DatabaseWrapper, connection).mysql_is_mariadb: raise SkipTest("Dynamic Columns require MariaDB") super().setUpClass() From 1d0bcd96d0710f42ac4ae9a76e498350d147a26a Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Wed, 19 Oct 2022 09:24:43 +0100 Subject: [PATCH 18/29] Fix AsType signature --- src/django_mysql/models/functions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/django_mysql/models/functions.py b/src/django_mysql/models/functions.py index 95acd951..fa252190 100644 --- a/src/django_mysql/models/functions.py +++ b/src/django_mysql/models/functions.py @@ -378,7 +378,11 @@ class AsType(Func): function = "" template = "%(expressions)s AS %(data_type)s" - def __init__(self, expression: ExpressionArgument, data_type: str) -> None: + def __init__( + self, + expression: Expression | str | float | int | dt.date | dt.time | dt.datetime, + data_type: str, + ) -> None: from django_mysql.models.fields.dynamic import KeyTransform if not hasattr(expression, "resolve_expression"): From 0d700019f3b8338e9ac2fc9a007c40bcd48ae7c0 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Wed, 19 Oct 2022 09:27:25 +0100 Subject: [PATCH 19/29] Fix name of pytest fixture to avoid collision --- tests/testapp/test_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testapp/test_cache.py b/tests/testapp/test_cache.py index cc19bd45..4d404b81 100644 --- a/tests/testapp/test_cache.py +++ b/tests/testapp/test_cache.py @@ -1301,7 +1301,7 @@ def test_cull_mysql_caches_bad_cache_name(self): @override_cache_settings() class MySQLCacheMigrationTests(MySQLCacheTableMixin, TransactionTestCase): @pytest.fixture(autouse=True) - def flake8_path(self, flake8_path): + def set_flake8_path(self, flake8_path): self.flake8_path = flake8_path def test_mysql_cache_migration(self): From 1ce99b7c18ab50fb184ada65f38ab8c4ac5848ca Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Wed, 19 Oct 2022 09:41:27 +0100 Subject: [PATCH 20/29] Fix some errors in cache tests --- tests/testapp/test_cache.py | 38 +++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/testapp/test_cache.py b/tests/testapp/test_cache.py index 4d404b81..60cb33a3 100644 --- a/tests/testapp/test_cache.py +++ b/tests/testapp/test_cache.py @@ -9,6 +9,7 @@ from typing import Any import pytest +from django.core.cache import BaseCache from django.core.cache import CacheKeyWarning from django.core.cache import cache from django.core.cache import caches @@ -100,7 +101,10 @@ def reverse_custom_key_func(full_key): } -def caches_setting_for_tests(options=None, **params): +def caches_setting_for_tests( + options: dict[str, Any] | None = None, + **params: Any, +) -> dict[str, Any]: # `params` are test specific overrides and `_caches_settings_base` is the # base config for the tests. # This results in the following search order: @@ -117,8 +121,10 @@ def caches_setting_for_tests(options=None, **params): # Spaces are used in the table name to ensure quoting/escaping is working def override_cache_settings( - BACKEND="django_mysql.cache.MySQLCache", LOCATION="test cache table", **kwargs -): + BACKEND: str = "django_mysql.cache.MySQLCache", + LOCATION: str = "test cache table", + **kwargs: Any, +) -> override_settings: return override_settings( CACHES=caches_setting_for_tests(BACKEND=BACKEND, LOCATION=LOCATION, **kwargs) ) @@ -128,13 +134,13 @@ class MySQLCacheTableMixin(TransactionTestCase): table_name = "test cache table" @classmethod - def create_table(self): + def create_table(self) -> None: sql = MySQLCache.create_table_sql.format(table_name=self.table_name) with connection.cursor() as cursor: cursor.execute(sql) @classmethod - def drop_table(self): + def drop_table(self) -> None: with connection.cursor() as cursor: cursor.execute("DROP TABLE `%s`" % self.table_name) @@ -153,10 +159,11 @@ def tearDownClass(cls): super().tearDownClass() cls.drop_table() - def table_count(self): + def table_count(self) -> int: with connection.cursor() as cursor: cursor.execute("SELECT COUNT(*) FROM `%s`" % self.table_name) - return cursor.fetchone()[0] + count: int = cursor.fetchone()[0] + return count # These tests were copied from django's tests/cache/tests.py file @@ -726,7 +733,7 @@ def test_cache_write_unpicklable_object(self): fetch_middleware = FetchFromCacheMiddleware(empty_response) request = self.factory.get("/cache/test") - request._cache_update_cache = True + request._cache_update_cache = True # type: ignore [attr-defined] get_cache_data = FetchFromCacheMiddleware(empty_response).process_request( request ) @@ -779,10 +786,10 @@ def test_get_or_set_version(self): cache.get_or_set("brian", 1979, version=2) with pytest.raises(TypeError, match=msg_re): - cache.get_or_set("brian") + cache.get_or_set("brian") # type: ignore [call-arg] with pytest.raises(TypeError, match=msg_re): - cache.get_or_set("brian", version=1) + cache.get_or_set("brian", version=1) # type: ignore [call-arg] assert cache.get("brian", version=1) is None assert cache.get_or_set("brian", 42, version=1) == 42 @@ -915,6 +922,7 @@ def func(key, *args): # Original tests def test_base_set_bad_value(self): + assert isinstance(cache, MySQLCache) with pytest.raises(ValueError) as excinfo: cache._base_set("foo", "key", "value") assert "'mode' should be" in str(excinfo.value) @@ -997,7 +1005,9 @@ def test_cull_deletes_expired_first(self): self._perform_cull_test(cull_cache, 30, 30) assert cull_cache.get("key") is None - def _perform_cull_test(self, cull_cache, initial_count, final_count): + def _perform_cull_test( + self, cull_cache: BaseCache, initial_count: int, final_count: int + ) -> None: # Create initial cache key entries. This will overflow the cache, # causing a cull. for i in range(1, initial_count + 1): @@ -1137,6 +1147,7 @@ def test_keys_with_prefix_version(self, cache_name): @override_cache_settings(KEY_FUNCTION=custom_key_func) def test_keys_with_prefix_with_bad_cache(self): + assert isinstance(cache, MySQLCache) with pytest.raises(ValueError) as excinfo: cache.keys_with_prefix("") assert str(excinfo.value).startswith("To use the _with_prefix commands") @@ -1176,6 +1187,7 @@ def test_get_with_prefix_version(self, cache_name): @override_cache_settings(KEY_FUNCTION=custom_key_func) def test_get_with_prefix_with_bad_cache(self): + assert isinstance(cache, MySQLCache) with pytest.raises(ValueError) as excinfo: cache.get_with_prefix("") assert str(excinfo.value).startswith("To use the _with_prefix commands") @@ -1233,6 +1245,7 @@ def test_delete_with_prefix_version(self, cache_name): @override_cache_settings(KEY_FUNCTION=custom_key_func) def test_delete_with_prefix_with_no_reverse_works(self): + assert isinstance(cache, MySQLCache) cache.set_many({"K1": "value", "K2": "value2", "B2": "Anothervalue"}) assert cache.delete_with_prefix("K") == 2 assert cache.get_many(["K1", "K2", "B2"]) == {"B2": "Anothervalue"} @@ -1262,6 +1275,7 @@ def test_mysql_cache_migration_no_mysql_caches(self): def test_cull_max_entries_minus_one(self): # cull with MAX_ENTRIES = -1 should never clear anything that is not # expired + assert isinstance(cache, MySQLCache) # one expired key cache.set("key", "value", 0.1) @@ -1341,7 +1355,7 @@ def test_mysql_cache_migration(self): operation.database_backwards("testapp", editor, new_state, state) assert not self.table_exists(self.table_name) - def table_exists(self, table_name): + def table_exists(self, table_name: str) -> bool: with connection.cursor() as cursor: cursor.execute( """SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES From dd2c058de28a78bb99ab64505e99d1f242b91280 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Wed, 19 Oct 2022 10:02:20 +0100 Subject: [PATCH 21/29] Fix return type of IndexLookup.as_sql --- src/django_mysql/models/fields/lists.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 8888be6d..0a2c8f71 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -233,7 +233,7 @@ def as_sql( ) -> tuple[str, list[str | int]]: lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) - params = tuple(lhs_params) + tuple(rhs_params) + params = list(lhs_params) + list(rhs_params) # Put rhs on the left since that's the order FIND_IN_SET uses return f"(FIND_IN_SET({rhs}, {lhs}) = {self.index})", params From a2c8e60a8ba9e0b84d91a9f21bdf99ac5f5e0ab2 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Wed, 19 Oct 2022 10:04:34 +0100 Subject: [PATCH 22/29] make formfield() methods return Any --- src/django_mysql/models/fields/dynamic.py | 3 +-- src/django_mysql/models/fields/sets.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index 67dcd1d0..97c6f49c 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -22,7 +22,6 @@ from django.db.models import TimeField from django.db.models import Transform from django.db.models.sql.compiler import SQLCompiler -from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ from django_mysql.models.lookups import DynColHasKey @@ -292,7 +291,7 @@ def deconstruct(self) -> DeconstructResult: kwargs["blank"] = False return name, path, args, kwargs - def formfield(self, *args: Any, **kwargs: Any) -> FormField | None: + def formfield(self, *args: Any, **kwargs: Any) -> Any: """ Disabled in forms - there is no sensible way of editing this """ diff --git a/src/django_mysql/models/fields/sets.py b/src/django_mysql/models/fields/sets.py index b050d1f3..40a9e1e6 100644 --- a/src/django_mysql/models/fields/sets.py +++ b/src/django_mysql/models/fields/sets.py @@ -11,7 +11,6 @@ from django.db.models import Model from django.db.models import TextField from django.db.models.expressions import BaseExpression -from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ from django_mysql.forms import SimpleSetField @@ -136,7 +135,7 @@ def value_to_string(self, obj: Any) -> str: vals = self.value_from_object(obj) return self.get_prep_value(vals) - def formfield(self, **kwargs: Any) -> FormField: + def formfield(self, **kwargs: Any) -> Any: defaults = { "form_class": SimpleSetField, "base_field": self.base_field.formfield(), From fc8af0aa20ae779b10eaf35860764586e152af41 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 11:57:25 +0100 Subject: [PATCH 23/29] Add extra model type asserts for deserialization tests --- tests/testapp/test_bit1_field.py | 2 ++ tests/testapp/test_dynamicfield.py | 1 + tests/testapp/test_listcharfield.py | 1 + tests/testapp/test_listtextfield.py | 1 + tests/testapp/test_setcharfield.py | 1 + tests/testapp/test_settextfield.py | 1 + 6 files changed, 7 insertions(+) diff --git a/tests/testapp/test_bit1_field.py b/tests/testapp/test_bit1_field.py index 47df7a12..92f8ae95 100644 --- a/tests/testapp/test_bit1_field.py +++ b/tests/testapp/test_bit1_field.py @@ -95,6 +95,7 @@ def test_loading(self): objs = list(serializers.deserialize("json", test_data)) assert len(objs) == 1 instance = objs[0].object + assert isinstance(instance, Bit1Model) assert not instance.flag_a assert instance.flag_b @@ -180,6 +181,7 @@ def test_loading(self): objs = list(serializers.deserialize("json", test_data)) assert len(objs) == 1 instance = objs[0].object + assert isinstance(instance, NullBit1Model) assert instance.flag is None else: diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index 3684f1d9..c07cab4a 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -495,6 +495,7 @@ def test_dumping(self): def test_loading(self): deserialized = list(serializers.deserialize("json", self.test_data)) instance = deserialized[0].object + assert isinstance(instance, DynamicModel) assert instance.attrs == {"a": "b"} diff --git a/tests/testapp/test_listcharfield.py b/tests/testapp/test_listcharfield.py index 28f8a850..8e1afa98 100644 --- a/tests/testapp/test_listcharfield.py +++ b/tests/testapp/test_listcharfield.py @@ -567,6 +567,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, CharListModel) assert instance.field == ["big", "leather", "comfy"] diff --git a/tests/testapp/test_listtextfield.py b/tests/testapp/test_listtextfield.py index eb0f8523..6476f291 100644 --- a/tests/testapp/test_listtextfield.py +++ b/tests/testapp/test_listtextfield.py @@ -358,6 +358,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, BigCharListModel) assert instance.field == ["big", "leather", "comfy"] def test_dumping_loading_empty(self): diff --git a/tests/testapp/test_setcharfield.py b/tests/testapp/test_setcharfield.py index 004a41d6..25e9066e 100644 --- a/tests/testapp/test_setcharfield.py +++ b/tests/testapp/test_setcharfield.py @@ -543,6 +543,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, CharSetModel) assert instance.field == {"big", "leather", "comfy"} diff --git a/tests/testapp/test_settextfield.py b/tests/testapp/test_settextfield.py index fe19f16c..ed8501ab 100644 --- a/tests/testapp/test_settextfield.py +++ b/tests/testapp/test_settextfield.py @@ -336,6 +336,7 @@ def test_loading(self): """ objs = list(serializers.deserialize("json", test_data)) instance = objs[0].object + assert isinstance(instance, BigCharSetModel) assert instance.field == {"big", "leather", "comfy"} def test_empty(self): From 873b6dab53937537d04ff6255163710264ebc942 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 14:37:12 +0100 Subject: [PATCH 24/29] Some extra hints --- src/django_mysql/utils.py | 4 +++- tests/testapp/test_locks.py | 10 +++++++--- tests/testapp/test_size_fields.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/django_mysql/utils.py b/src/django_mysql/utils.py index 9b4f9afd..2011decd 100644 --- a/src/django_mysql/utils.py +++ b/src/django_mysql/utils.py @@ -132,7 +132,9 @@ def collapse_spaces(string: str) -> str: return " ".join(filter(None, bits)) -def index_name(model: Model, *field_names: str, using: str = DEFAULT_DB_ALIAS) -> str: +def index_name( + model: type[Model], *field_names: str, using: str = DEFAULT_DB_ALIAS +) -> str: """ Returns the name of the index existing on field_names, or raises KeyError if no such index exists. diff --git a/tests/testapp/test_locks.py b/tests/testapp/test_locks.py index 2269675d..c3ca46e6 100644 --- a/tests/testapp/test_locks.py +++ b/tests/testapp/test_locks.py @@ -3,11 +3,13 @@ import queue from threading import Thread from typing import TYPE_CHECKING +from typing import cast import pytest from django.db import OperationalError from django.db import connection from django.db import connections +from django.db.backends.mysql.base import DatabaseWrapper from django.db.transaction import TransactionManagementError from django.db.transaction import atomic from django.test import TestCase @@ -34,7 +36,7 @@ class LockTests(TestCase): def setUpClass(cls): super().setUpClass() - cls.supports_lock_info = connection.mysql_is_mariadb + cls.supports_lock_info = cast(DatabaseWrapper, connection).mysql_is_mariadb if cls.supports_lock_info: with connection.cursor() as cursor: cursor.execute( @@ -229,7 +231,7 @@ def tearDown(self): Customer.objects.using("other").all().delete() super().tearDown() - def is_locked(self, connection_name, table_name): + def is_locked(self, connection_name: str, table_name: str) -> bool: conn = connections[connection_name] with conn.cursor() as cursor: cursor.execute( @@ -239,7 +241,9 @@ def is_locked(self, connection_name, table_name): rows = cursor.fetchall() if rows: assert len(rows) == 1 - return rows[0][2] > 0 + value = rows[0][2] + assert isinstance(value, int) + return value > 0 else: # pragma: no cover # MySQL 8+ closes the table really quickly. If it's closed, # it's not locked. diff --git a/tests/testapp/test_size_fields.py b/tests/testapp/test_size_fields.py index b480ab1a..ed97687d 100644 --- a/tests/testapp/test_size_fields.py +++ b/tests/testapp/test_size_fields.py @@ -22,7 +22,7 @@ forceDataError = override_mysql_variables(SQL_MODE="STRICT_TRANS_TABLES") -def migrate(name): +def migrate(name: str) -> None: call_command( "migrate", "testapp", name, verbosity=0, skip_checks=True, interactive=False ) From 25335191b9e89e7b6c96fe5c7e148491ddb56ce1 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 21:22:38 +0100 Subject: [PATCH 25/29] Fix formfield() methods --- src/django_mysql/models/fields/dynamic.py | 3 ++- src/django_mysql/models/fields/lists.py | 3 ++- src/django_mysql/models/fields/sets.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index 97c6f49c..63efbcc1 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -10,6 +10,7 @@ from typing import Union from typing import cast +from django import forms from django.core import checks from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import DateField @@ -291,7 +292,7 @@ def deconstruct(self) -> DeconstructResult: kwargs["blank"] = False return name, path, args, kwargs - def formfield(self, *args: Any, **kwargs: Any) -> Any: + def formfield(self, *args: Any, **kwargs: Any) -> forms.Field | None: """ Disabled in forms - there is no sensible way of editing this """ diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 0a2c8f71..52732f8e 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -4,6 +4,7 @@ from typing import Callable from typing import cast +from django import forms from django.core import checks from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import CharField @@ -153,7 +154,7 @@ def value_to_string(self, obj: Any) -> str: vals = self.value_from_object(obj) return self.get_prep_value(vals) - def formfield(self, **kwargs: Any) -> Any: + def formfield(self, **kwargs: Any) -> forms.Field | None: defaults = { "form_class": SimpleListField, "base_field": self.base_field.formfield(), diff --git a/src/django_mysql/models/fields/sets.py b/src/django_mysql/models/fields/sets.py index 40a9e1e6..24a29c8d 100644 --- a/src/django_mysql/models/fields/sets.py +++ b/src/django_mysql/models/fields/sets.py @@ -3,6 +3,7 @@ from typing import Any from typing import cast +from django import forms from django.core import checks from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import CharField @@ -135,7 +136,7 @@ def value_to_string(self, obj: Any) -> str: vals = self.value_from_object(obj) return self.get_prep_value(vals) - def formfield(self, **kwargs: Any) -> Any: + def formfield(self, **kwargs: Any) -> forms.Field | None: defaults = { "form_class": SimpleSetField, "base_field": self.base_field.formfield(), From a0874a213186c1fcb6e02e1bb9f8464be0db351f Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 21:24:25 +0100 Subject: [PATCH 26/29] Allow bad arg types for tests checking that --- tests/testapp/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testapp/test_models.py b/tests/testapp/test_models.py index f9251d27..ac8eb161 100644 --- a/tests/testapp/test_models.py +++ b/tests/testapp/test_models.py @@ -330,7 +330,7 @@ def test_force_index_at_least_one(self): def test_force_index_invalid_for(self): with pytest.raises(ValueError) as excinfo: - Author.objects.force_index("a", for_="INVALID") + Author.objects.force_index("a", for_="INVALID") # type: ignore [arg-type] assert "for_ must be one of" in str(excinfo.value) def test_index_hint_force_order_by(self): @@ -527,7 +527,7 @@ def test_objects_pk_range_reversed(self): def test_objects_pk_range_bad(self): with pytest.raises(ValueError) as excinfo: - list(Author.objects.iter_smart(pk_range="My Bad Value")) + list(Author.objects.iter_smart(pk_range="My Bad Value")) # type: ignore [arg-type] assert "Unrecognized value for pk_range" in str(excinfo.value) def test_pk_range_race_condition(self): From 842cbcf967e0abed778a856187b2f265561b5007 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 21:24:45 +0100 Subject: [PATCH 27/29] Pass connection to db_type() --- tests/testapp/test_size_fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testapp/test_size_fields.py b/tests/testapp/test_size_fields.py index ed97687d..399597f4 100644 --- a/tests/testapp/test_size_fields.py +++ b/tests/testapp/test_size_fields.py @@ -50,7 +50,7 @@ def test_binaryfield_default_length(self): # By default, SizedBinaryField should act like BinaryField field = SizedBinaryField() assert field.size_class == 4 - assert field.db_type(None) == "longblob" + assert field.db_type(connection) == "longblob" @atomic def test_binary_1_max_length(self): @@ -153,7 +153,7 @@ def test_textfield_default_length(self): # By default, SizedTextField should act like TextField field = SizedTextField() assert field.size_class == 4 - assert field.db_type(None) == "longtext" + assert field.db_type(connection) == "longtext" def test_tinytext_max_length(self): # Okay From 07f44d08ed3b492d80a2841e96ac56d5aa9789ac Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 21:26:06 +0100 Subject: [PATCH 28/29] Correct mysql_connections() --- src/django_mysql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/django_mysql/utils.py b/src/django_mysql/utils.py index 2011decd..5d5de374 100644 --- a/src/django_mysql/utils.py +++ b/src/django_mysql/utils.py @@ -179,7 +179,7 @@ def get_list_sql(sequence: list[str] | tuple[str, ...]) -> str: return "({})".format(",".join("%s" for x in sequence)) -def mysql_connections() -> Generator[[str, BaseDatabaseWrapper]]: +def mysql_connections() -> Generator[tuple[str, BaseDatabaseWrapper]]: conn_names = [DEFAULT_DB_ALIAS] + list(set(connections) - {DEFAULT_DB_ALIAS}) for alias in conn_names: connection = connections[alias] From de983de210a06e343cc592a85fb7e0d1790c2eed Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 9 Sep 2024 21:31:38 +0100 Subject: [PATCH 29/29] Correct types of source expression functions --- src/django_mysql/models/expressions.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/django_mysql/models/expressions.py b/src/django_mysql/models/expressions.py index 464314ba..35c8ff51 100644 --- a/src/django_mysql/models/expressions.py +++ b/src/django_mysql/models/expressions.py @@ -1,12 +1,14 @@ from __future__ import annotations from typing import Any -from typing import Iterable +from typing import Sequence from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import F from django.db.models import Value from django.db.models.expressions import BaseExpression +from django.db.models.expressions import Combinable +from django.db.models.expressions import Expression from django.db.models.sql.compiler import SQLCompiler from django_mysql.utils import collapse_spaces @@ -18,10 +20,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None: self.lhs = lhs self.rhs = rhs - def get_source_expressions(self) -> list[BaseExpression]: + def get_source_expressions(self) -> list[Expression]: return [self.lhs, self.rhs] - def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None: + def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: self.lhs, self.rhs = exprs @@ -138,10 +140,10 @@ def __init__(self, lhs: BaseExpression) -> None: super().__init__() self.lhs = lhs - def get_source_expressions(self) -> list[BaseExpression]: + def get_source_expressions(self) -> list[Expression]: return [self.lhs] - def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None: + def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: (self.lhs,) = exprs def as_sql( @@ -170,10 +172,10 @@ def __init__(self, lhs: BaseExpression) -> None: super().__init__() self.lhs = lhs - def get_source_expressions(self) -> list[BaseExpression]: + def get_source_expressions(self) -> list[Expression]: return [self.lhs] - def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None: + def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: (self.lhs,) = exprs def as_sql(