Skip to content

Commit

Permalink
fix otp bypass token model
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 24, 2024
1 parent 702e656 commit a7e8bc4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 33 deletions.
21 changes: 9 additions & 12 deletions codeforlife/user/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.2.20 on 2023-09-29 17:53
# Generated by Django 3.2.20 on 2024-01-24 18:42

import django.contrib.auth.models
import django.core.validators
Expand Down Expand Up @@ -43,6 +43,14 @@ class Migration(migrations.Migration):
'abstract': False,
},
),
migrations.CreateModel(
name='OtpBypassToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')),
],
),
migrations.CreateModel(
name='AuthFactor',
fields=[
Expand All @@ -65,15 +73,4 @@ class Migration(migrations.Migration):
'unique_together': {('session', 'auth_factor')},
},
),
migrations.CreateModel(
name='OtpBypassToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('token', models.CharField(max_length=8, validators=[django.core.validators.MinLengthValidator(8)])),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='otp_bypass_tokens', to='user.user')),
],
options={
'unique_together': {('user', 'token')},
},
),
]
29 changes: 23 additions & 6 deletions codeforlife/user/models/otp_bypass_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.utils.crypto import get_random_string

from . import user


class OtpBypassToken(models.Model):
length = 8
allowed_chars = "abcdefghijklmnopqrstuv"
max_count = 10
max_count_validation_error = ValidationError(
f"Exceeded max count of {max_count}"
Expand Down Expand Up @@ -51,13 +54,10 @@ def key(otp_bypass_token: OtpBypassToken):
)

token = models.CharField(
max_length=8,
validators=[MinLengthValidator(8)],
max_length=length,
validators=[MinLengthValidator(length)],
)

class Meta:
unique_together = ["user", "token"]

def save(self, *args, **kwargs):
if self.id is None:
if (
Expand All @@ -69,7 +69,24 @@ def save(self, *args, **kwargs):
return super().save(*args, **kwargs)

def check_token(self, token: str):
if check_password(token, self.token):
if check_password(token.lower(), self.token):
self.delete()
return True
return False

@classmethod
def generate_tokens(cls, count: int = max_count):
"""Generates a number of tokens.
Args:
count: The number of tokens to generate. Default to max.
Returns:
Raw tokens that are random and unique.
"""

tokens: t.Set[str] = set()
while len(tokens) < count:
tokens.add(get_random_string(cls.length, cls.allowed_chars))

return tokens
65 changes: 50 additions & 15 deletions codeforlife/user/tests/models/test_otp_bypass_token.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
"""
© Ocado Group
Created on 24/01/2024 at 16:17:22(+00:00).
"""

from unittest.mock import call, patch

from django.contrib.auth.hashers import check_password
from django.core.exceptions import ValidationError
from django.test import TestCase
from django.utils.crypto import get_random_string

from ...models import OtpBypassToken, User
from ...models import OtpBypassToken, User, otp_bypass_token


class TestOtpBypassToken(TestCase):
def setUp(self):
self.user = User.objects.get(id=2)

def test_bulk_create(self):
token = get_random_string(8)
token = next(iter(OtpBypassToken.generate_tokens(1)))
otp_bypass_tokens = OtpBypassToken.objects.bulk_create(
[OtpBypassToken(user=self.user, token=token)]
)
Expand All @@ -20,16 +26,13 @@ def test_bulk_create(self):
with self.assertRaises(ValidationError):
OtpBypassToken.objects.bulk_create(
[
OtpBypassToken(
user=self.user,
token=get_random_string(8),
)
for _ in range(OtpBypassToken.max_count)
OtpBypassToken(user=self.user, token=token)
for token in OtpBypassToken.generate_tokens()
]
)

def test_create(self):
token = get_random_string(8)
token = next(iter(OtpBypassToken.generate_tokens(1)))
otp_bypass_token = OtpBypassToken.objects.create(
user=self.user, token=token
)
Expand All @@ -38,22 +41,21 @@ def test_create(self):

OtpBypassToken.objects.bulk_create(
[
OtpBypassToken(
user=self.user,
token=get_random_string(8),
OtpBypassToken(user=self.user, token=token)
for token in OtpBypassToken.generate_tokens(
OtpBypassToken.max_count - 1
)
for _ in range(OtpBypassToken.max_count - 1)
]
)

with self.assertRaises(ValidationError):
OtpBypassToken.objects.create(
user=self.user,
token=get_random_string(8),
token=next(iter(OtpBypassToken.generate_tokens(1))),
)

def test_check_token(self):
token = get_random_string(8)
token = next(iter(OtpBypassToken.generate_tokens(1)))
otp_bypass_token = OtpBypassToken.objects.create(
user=self.user, token=token
)
Expand All @@ -65,3 +67,36 @@ def test_check_token(self):
user=otp_bypass_token.user,
token=otp_bypass_token.token,
)

def test_generate_tokens(self):
"""
Generates a number of unique tokens.
"""

count = 3
get_random_string_side_effect = [
"aaaaaaaa",
"aaaaaaaa",
"bbbbbbbb",
"cccccccc",
]

with patch.object(
otp_bypass_token,
"get_random_string",
side_effect=get_random_string_side_effect,
) as get_random_string:
tokens = OtpBypassToken.generate_tokens(count)
assert len(tokens) == count
assert tokens == {
"aaaaaaaa",
"bbbbbbbb",
"cccccccc",
}

get_random_string.assert_has_calls(
[
call(OtpBypassToken.length, OtpBypassToken.allowed_chars)
for _ in range(len(get_random_string_side_effect))
]
)

0 comments on commit a7e8bc4

Please sign in to comment.