Skip to content

Commit

Permalink
Merge pull request #127 from qpfmtlcp/develop
Browse files Browse the repository at this point in the history
Make auth header prefix configurable
  • Loading branch information
belugame authored Oct 12, 2018
2 parents 823654d + 429bc3c commit 0e42503
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
3 changes: 3 additions & 0 deletions docs/settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ is used.
This is the minimum time in seconds that needs to pass for the token expiry to be updated
in the database.

## AUTH_HEADER_PREFIX
This is the Authorization header value prefix. The default is `Token`

# Constants `knox.settings`
Knox also provides some constants for information. These must not be changed in
external code; they are used in the model definitions in knox and an error will
Expand Down
16 changes: 10 additions & 6 deletions knox/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def compare_digest(a, b):
from django.conf import settings
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from django.contrib.auth import get_user_model

from rest_framework import exceptions
from rest_framework.authentication import (
Expand All @@ -21,7 +22,9 @@ def compare_digest(a, b):
from knox.settings import CONSTANTS, knox_settings
from knox.signals import token_expired

User = settings.AUTH_USER_MODEL
User = get_user_model()

username_field = getattr(User, 'USERNAME_FIELD', 'username')


class TokenAuthentication(BaseAuthentication):
Expand All @@ -40,8 +43,9 @@ class TokenAuthentication(BaseAuthentication):

def authenticate(self, request):
auth = get_authorization_header(request).split()

if not auth or auth[0].lower() != b'token':
prefix = knox_settings.AUTH_HEADER_PREFIX.encode()

if not auth or auth[0].lower() != prefix.lower():
return None
if len(auth) == 1:
msg = _('Invalid token header. No credentials provided.')
Expand Down Expand Up @@ -93,18 +97,18 @@ def validate_user(self, auth_token):
return (auth_token.user, auth_token)

def authenticate_header(self, request):
return 'Token'
return knox_settings.AUTH_HEADER_PREFIX

def _cleanup_token(self, auth_token):
for other_token in auth_token.user.auth_token_set.all():
if other_token.digest != auth_token.digest and other_token.expires is not None:
if other_token.expires < timezone.now():
other_token.delete()
username = other_token.user.username
username = getattr(other_token.user, username_field)
token_expired.send(sender=self.__class__, username=username, source="other_token")
if auth_token.expires is not None:
if auth_token.expires < timezone.now():
username = auth_token.user.username
username = getattr(auth_token.user, username_field)
auth_token.delete()
token_expired.send(sender=self.__class__, username=username, source="auth_token")
return True
Expand Down
1 change: 1 addition & 0 deletions knox/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'TOKEN_LIMIT_PER_USER': None,
'AUTO_REFRESH': False,
'MIN_REFRESH_INTERVAL': 60,
'AUTH_HEADER_PREFIX': 'Token',
}

IMPORT_STRINGS = {
Expand Down
16 changes: 16 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def get_basic_auth_header(username, password):
user_serializer_knox = knox_settings.defaults.copy()
user_serializer_knox["USER_SERIALIZER"] = UserSerializer

auth_header_prefix_knox = knox_settings.defaults.copy()
auth_header_prefix_knox["AUTH_HEADER_PREFIX"] = 'Baerer'

class AuthTestCase(TestCase):

def setUp(self):
Expand Down Expand Up @@ -296,3 +299,16 @@ def test_does_not_exceed_on_expired_keys(self):
self.assertIn('token', response.data)
self.assertEqual(failed_response.status_code, 403)
self.assertEqual(failed_response.data, {"error": "Maximum amount of tokens allowed per user exceeded."})

def test_invalid_prefix_return_401(self):

with override_settings(REST_KNOX=auth_header_prefix_knox):
reload_module(auth)
token = AuthToken.objects.create(user=self.user)
self.client.credentials(HTTP_AUTHORIZATION=('Token %s' % token))
failed_response = self.client.get(root_url)
self.client.credentials(HTTP_AUTHORIZATION=('Baerer %s' % token))
response = self.client.get(root_url)
reload_module(auth)
self.assertEqual(failed_response.status_code, 401)
self.assertEqual(response.status_code, 200)

0 comments on commit 0e42503

Please sign in to comment.