Skip to content

Commit

Permalink
Refactor GenerateUuid4 function
Browse files Browse the repository at this point in the history
  • Loading branch information
hugorodgerbrown committed Sep 19, 2023
1 parent 502b9a9 commit 68462a2
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 46 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*.egg-info
*.bak
*.db
.env
.coverage
.tox
.venv
Expand Down
2 changes: 1 addition & 1 deletion anonymiser/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .expressions import GenerateUuid4
from .functions import GenerateUuid4

__all__ = ["GenerateUuid4"]
42 changes: 0 additions & 42 deletions anonymiser/db/expressions.py

This file was deleted.

119 changes: 119 additions & 0 deletions anonymiser/db/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any

from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.sql.compiler import SQLCompiler


class GenerateUuid4(models.Func):
"""
Generate a new UUID (v4) value.
Most databases support some form of UUID generation, but the syntax
varies between them. This expression is only implemented for SQLite
and PostgreSQL.
The expression can be used to generate a UUID value for a field
where the value needs to be generated by the database, rather than
by the application.
As an example, if you have a model with a UUID field and you want to
roll all the ids, if you try to do this:
>>> User.objects.all().update(uuid=uuid.uuid4())
You will end up with every row in the database having the same UUID.
This is because the uuid.uuid4() function is evaluated once and the
same value is used for every row:
UPDATE "user" SET "uuid" =
"3742138f-1399-47b2-a721-1710abefded6"
If you want to generate a new UUID for each row whilst updating with
a single "UPDATE" SQL statement, you can use this expression:
>>> User.objects.all().update(uuid=GenerateUuid4())
Which will generate something like this:
UPDATE "user" SET "uuid" = <function_to_generate_uuid>
The specific function that is called is vendor specific.
"""

output_field = models.UUIDField()

def as_sql(
self,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
**extra_context: Any,
) -> tuple[str, list]:
"""
Generate the SQL fragment for the expression.
Note that this is only called when `as_{{vendor}}` method does
not exist for the expression.
(From https://docs.djangoproject.com/en/4.2/ref/models/lookups/)
"Returns a tuple (sql, params), where sql is the SQL string, and
params is the list or tuple of query parameters. The compiler is
an SQLCompiler object, which has a compile() method that can be
used to compile other expressions. The connection is the
connection used to execute the query.
Calling expression.as_sql() is usually incorrect - instead
compiler.compile(expression) should be used. The
compiler.compile() method will take care of calling
vendor-specific methods of the expression.
Custom keyword arguments may be defined on this method if it's
likely that as_vendorname() methods or subclasses will need to
supply data to override the generation of the SQL string. See
Func.as_sql() for example usage."
"""
raise NotImplementedError(
f"GenerateUuid4 is not implemented for {connection.vendor}"
)

def as_sqlite(
self,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
**extra_context: Any,
) -> tuple[str, list]:
"""
Generate the SQL fragment for the expression in SQLite format.
(From https://docs.djangoproject.com/en/4.2/ref/models/lookups/)
"Works like as_sql() method. When an expression is compiled by
compiler.compile(), Django will first try to call
as_vendorname(), where vendorname is the vendor name of the
backend used for executing the query. The vendorname is one of
postgresql, oracle, sqlite, or mysql for Django's built-in
backends."
"""
return super().as_sql(
compiler,
connection,
function="HEX(RANDOMBLOB(16))",
# override the default template as otherwise we end up
# trying to append parentheses: HEX(RANDOMBLOB(16))()
template="%(function)s",
**extra_context,
)

def as_postgresql(
self,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
**extra_context: Any,
) -> tuple[str, list]:
return super().as_sql(
compiler, connection, function="gen_random_uuid", **extra_context
)
4 changes: 2 additions & 2 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.db import connection
from django.db.backends.utils import CursorWrapper

from anonymiser.db.expressions import GenerateUuid4
from anonymiser.db.functions import GenerateUuid4

from .models import User

Expand All @@ -14,7 +14,7 @@
"vendor,sql_func",
[
("sqlite", "HEX(RANDOMBLOB(16))"),
("postgresql", "get_random_uuid()"),
("postgresql", "gen_random_uuid()"),
],
)
@mock.patch.object(CursorWrapper, "execute")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from django.db import models

from anonymiser.db.expressions import GenerateUuid4
from anonymiser.db.functions import GenerateUuid4
from anonymiser.models import FieldSummaryData

from .anonymisers import BadUserAnonymiser, UserAnonymiser, UserRedacter
Expand Down

0 comments on commit 68462a2

Please sign in to comment.