From 9f1a600b20686f07bb82323d65c08d8bd684f07f Mon Sep 17 00:00:00 2001 From: Hugo Rodger-Brown Date: Tue, 19 Sep 2023 13:13:02 +0100 Subject: [PATCH] Add SQLite support for generating UUIDs --- anonymiser/db/expressions.py | 40 ++++++++++++++++++++++++++++-------- tests/test_expressions.py | 40 ++++++++++++++++++++++++++++++++++++ tests/test_models.py | 22 ++------------------ 3 files changed, 74 insertions(+), 28 deletions(-) create mode 100644 tests/test_expressions.py diff --git a/anonymiser/db/expressions.py b/anonymiser/db/expressions.py index bcb4598..0fc1894 100644 --- a/anonymiser/db/expressions.py +++ b/anonymiser/db/expressions.py @@ -1,18 +1,42 @@ +from typing import Any + from django.db import models +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models.sql.compiler import SQLCompiler + +# sql methods return a tuple of (sql, params) +EXPR_RETURN_TYPE = tuple[str, list] class GenerateUuid4(models.Func): """Run uuid_generate_v4() Postgres function.""" - function = "uuid_generate_v4" output_field = models.UUIDField() - arity = 0 - def as_sqlite(self, compiler, connection, **extra_context): # type: ignore - raise NotImplementedError("SQLite does not support native generation of UUIDs.") + def as_sql( + self, + compiler: SQLCompiler, + connection: BaseDatabaseWrapper, + **extra_context: Any, + ) -> EXPR_RETURN_TYPE: + if connection.vendor in ("sqlite", "postgresql"): + return super().as_sql(compiler, connection, **extra_context) + raise NotImplementedError( + f"GenerateUuid4 is not implemented for {connection.vendor}" + ) - def as_mysql(self, compiler, connection, **extra_context): # type: ignore - raise NotImplementedError + def as_sqlite( + self, + compiler: SQLCompiler, + connection: BaseDatabaseWrapper, + **extra_context: Any, + ) -> EXPR_RETURN_TYPE: + return "HEX(RANDOMBLOB(16))", [] - def as_oracle(self, compiler, connection, **extra_context): # type: ignore - raise NotImplementedError + def as_postgresql( + self, + compiler: SQLCompiler, + connection: BaseDatabaseWrapper, + **extra_context: Any, + ) -> EXPR_RETURN_TYPE: + return "get_random_uuid()", [] diff --git a/tests/test_expressions.py b/tests/test_expressions.py new file mode 100644 index 0000000..53d4354 --- /dev/null +++ b/tests/test_expressions.py @@ -0,0 +1,40 @@ +from unittest import mock + +import pytest +from django.db import connection +from django.db.backends.utils import CursorWrapper + +from anonymiser.db.expressions import GenerateUuid4 + +from .models import User + + +@pytest.mark.django_db +@pytest.mark.parametrize( + "vendor,sql_func", + [ + ("sqlite", "HEX(RANDOMBLOB(16))"), + ("postgresql", "get_random_uuid()"), + ], +) +@mock.patch.object(CursorWrapper, "execute") +def test_generate_uuid4( + mock_execute: mock.MagicMock, vendor: str, sql_func: str +) -> None: + uuid_expression = GenerateUuid4() + with mock.patch.object(connection, "vendor", vendor): + assert connection.vendor == vendor + User.objects.update(uuid=uuid_expression) + assert mock_execute.call_args[0][0] == ( + f'UPDATE "tests_user" SET "uuid" = {sql_func}' # noqa: S608 + ) + + +@pytest.mark.django_db +@pytest.mark.parametrize("vendor", ["mysql", "oracle"]) +def test_unsupported_databases_engines(vendor: str) -> None: + uuid_expression = GenerateUuid4() + with mock.patch.object(connection, "vendor", vendor): + assert connection.vendor == vendor + with pytest.raises(NotImplementedError): + uuid_expression.as_sql(None, connection) diff --git a/tests/test_models.py b/tests/test_models.py index 49d1891..6f51d40 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,9 +1,7 @@ -from unittest import mock, skipUnless +from unittest import mock import pytest -from django.conf import settings -from django.db import connection, models -from django.db.backends.utils import CursorWrapper +from django.db import models from anonymiser.db.expressions import GenerateUuid4 from anonymiser.models import FieldSummaryData @@ -126,21 +124,6 @@ def test_bad_anonymiser() -> None: @pytest.mark.django_db class TestRedaction: - @pytest.fixture(autouse=settings.IS_POSTGRES) - def activate_postgresql_uuid(self) -> None: - """Activate the uuid-ossp extension in the test database.""" - with connection.cursor() as cursor: - cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') - - @skipUnless(settings.IS_POSTGRES, "Test requires Postgres.") - @mock.patch.object(CursorWrapper, "execute") - def test_generate_uuid4(self, mock_execute: mock.MagicMock) -> None: - User.objects.update(uuid=GenerateUuid4()) - assert ( - mock_execute.call_args[0][0] - == 'UPDATE "tests_user" SET "uuid" = uuid_generate_v4()' - ) - def test_redact_queryset_none( self, user: User, user_redacter: UserRedacter ) -> None: @@ -195,7 +178,6 @@ def test_redact_queryset__field_overrides( user.refresh_from_db() assert user.location == "Area 51" - @skipUnless(settings.IS_POSTGRES, "Test requires Postgres.") def test_redact_queryset__field_overrides__postgres( self, user: User,