From 68462a25961e79472636ab411d9f6caf0549061e Mon Sep 17 00:00:00 2001 From: Hugo Rodger-Brown Date: Tue, 19 Sep 2023 17:11:32 +0100 Subject: [PATCH] Refactor GenerateUuid4 function --- .gitignore | 1 + anonymiser/db/__init__.py | 2 +- anonymiser/db/expressions.py | 42 ------------- anonymiser/db/functions.py | 119 +++++++++++++++++++++++++++++++++++ tests/test_expressions.py | 4 +- tests/test_models.py | 2 +- 6 files changed, 124 insertions(+), 46 deletions(-) delete mode 100644 anonymiser/db/expressions.py create mode 100644 anonymiser/db/functions.py diff --git a/.gitignore b/.gitignore index b49a0e8..b6d6fa6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.egg-info *.bak *.db +.env .coverage .tox .venv diff --git a/anonymiser/db/__init__.py b/anonymiser/db/__init__.py index 06bf4d9..f37a458 100644 --- a/anonymiser/db/__init__.py +++ b/anonymiser/db/__init__.py @@ -1,3 +1,3 @@ -from .expressions import GenerateUuid4 +from .functions import GenerateUuid4 __all__ = ["GenerateUuid4"] diff --git a/anonymiser/db/expressions.py b/anonymiser/db/expressions.py deleted file mode 100644 index 0fc1894..0000000 --- a/anonymiser/db/expressions.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Any - -from django.db import models -from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.models.sql.compiler import SQLCompiler - -# sql methods return a tuple of (sql, params) -EXPR_RETURN_TYPE = tuple[str, list] - - -class GenerateUuid4(models.Func): - """Run uuid_generate_v4() Postgres function.""" - - output_field = models.UUIDField() - - def as_sql( - self, - compiler: SQLCompiler, - connection: BaseDatabaseWrapper, - **extra_context: Any, - ) -> EXPR_RETURN_TYPE: - if connection.vendor in ("sqlite", "postgresql"): - return super().as_sql(compiler, connection, **extra_context) - raise NotImplementedError( - f"GenerateUuid4 is not implemented for {connection.vendor}" - ) - - def as_sqlite( - self, - compiler: SQLCompiler, - connection: BaseDatabaseWrapper, - **extra_context: Any, - ) -> EXPR_RETURN_TYPE: - return "HEX(RANDOMBLOB(16))", [] - - def as_postgresql( - self, - compiler: SQLCompiler, - connection: BaseDatabaseWrapper, - **extra_context: Any, - ) -> EXPR_RETURN_TYPE: - return "get_random_uuid()", [] diff --git a/anonymiser/db/functions.py b/anonymiser/db/functions.py new file mode 100644 index 0000000..b9cb43e --- /dev/null +++ b/anonymiser/db/functions.py @@ -0,0 +1,119 @@ +from typing import Any + +from django.db import models +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models.sql.compiler import SQLCompiler + + +class GenerateUuid4(models.Func): + """ + Generate a new UUID (v4) value. + + Most databases support some form of UUID generation, but the syntax + varies between them. This expression is only implemented for SQLite + and PostgreSQL. + + The expression can be used to generate a UUID value for a field + where the value needs to be generated by the database, rather than + by the application. + + As an example, if you have a model with a UUID field and you want to + roll all the ids, if you try to do this: + + >>> User.objects.all().update(uuid=uuid.uuid4()) + + You will end up with every row in the database having the same UUID. + This is because the uuid.uuid4() function is evaluated once and the + same value is used for every row: + + UPDATE "user" SET "uuid" = + "3742138f-1399-47b2-a721-1710abefded6" + + If you want to generate a new UUID for each row whilst updating with + a single "UPDATE" SQL statement, you can use this expression: + + >>> User.objects.all().update(uuid=GenerateUuid4()) + + Which will generate something like this: + + UPDATE "user" SET "uuid" = + + The specific function that is called is vendor specific. + + """ + + output_field = models.UUIDField() + + def as_sql( + self, + compiler: SQLCompiler, + connection: BaseDatabaseWrapper, + **extra_context: Any, + ) -> tuple[str, list]: + """ + Generate the SQL fragment for the expression. + + Note that this is only called when `as_{{vendor}}` method does + not exist for the expression. + + (From https://docs.djangoproject.com/en/4.2/ref/models/lookups/) + + "Returns a tuple (sql, params), where sql is the SQL string, and + params is the list or tuple of query parameters. The compiler is + an SQLCompiler object, which has a compile() method that can be + used to compile other expressions. The connection is the + connection used to execute the query. + + Calling expression.as_sql() is usually incorrect - instead + compiler.compile(expression) should be used. The + compiler.compile() method will take care of calling + vendor-specific methods of the expression. + + Custom keyword arguments may be defined on this method if it's + likely that as_vendorname() methods or subclasses will need to + supply data to override the generation of the SQL string. See + Func.as_sql() for example usage." + + """ + raise NotImplementedError( + f"GenerateUuid4 is not implemented for {connection.vendor}" + ) + + def as_sqlite( + self, + compiler: SQLCompiler, + connection: BaseDatabaseWrapper, + **extra_context: Any, + ) -> tuple[str, list]: + """ + Generate the SQL fragment for the expression in SQLite format. + + (From https://docs.djangoproject.com/en/4.2/ref/models/lookups/) + + "Works like as_sql() method. When an expression is compiled by + compiler.compile(), Django will first try to call + as_vendorname(), where vendorname is the vendor name of the + backend used for executing the query. The vendorname is one of + postgresql, oracle, sqlite, or mysql for Django's built-in + backends." + + """ + return super().as_sql( + compiler, + connection, + function="HEX(RANDOMBLOB(16))", + # override the default template as otherwise we end up + # trying to append parentheses: HEX(RANDOMBLOB(16))() + template="%(function)s", + **extra_context, + ) + + def as_postgresql( + self, + compiler: SQLCompiler, + connection: BaseDatabaseWrapper, + **extra_context: Any, + ) -> tuple[str, list]: + return super().as_sql( + compiler, connection, function="gen_random_uuid", **extra_context + ) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 53d4354..634ff2b 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -4,7 +4,7 @@ from django.db import connection from django.db.backends.utils import CursorWrapper -from anonymiser.db.expressions import GenerateUuid4 +from anonymiser.db.functions import GenerateUuid4 from .models import User @@ -14,7 +14,7 @@ "vendor,sql_func", [ ("sqlite", "HEX(RANDOMBLOB(16))"), - ("postgresql", "get_random_uuid()"), + ("postgresql", "gen_random_uuid()"), ], ) @mock.patch.object(CursorWrapper, "execute") diff --git a/tests/test_models.py b/tests/test_models.py index 6f51d40..5236e3d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,7 +3,7 @@ import pytest from django.db import models -from anonymiser.db.expressions import GenerateUuid4 +from anonymiser.db.functions import GenerateUuid4 from anonymiser.models import FieldSummaryData from .anonymisers import BadUserAnonymiser, UserAnonymiser, UserRedacter