Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add service token authentication mechanism #75

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions openshift/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ objects:
configMapKeyRef:
name: bayesian-config
key: keycloak-url
- name: BAYESIAN_AUTH_PUBLIC_KEYS_URL
valueFrom:
configMapKeyRef:
name: bayesian-config
key: auth-url
- name: BAYESIAN_JWT_AUDIENCE
value: "fabric8-online-platform,openshiftio-public"
image: "${DOCKER_REGISTRY}/${DOCKER_IMAGE}:${IMAGE_TAG}"
Expand Down
71 changes: 67 additions & 4 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import jwt
from os import getenv


from exceptions import HTTPError
from utils import fetch_public_key
from utils import fetch_public_key, fetch_service_public_keys


def decode_token(token):
def decode_user_token(token):
"""Decode the authorization token read from the request header."""
if token is None:
return {}
Expand Down Expand Up @@ -38,6 +37,40 @@ def decode_token(token):
return decoded_token


def decode_service_token(token): # pragma: no cover
"""Decode OSIO service token."""
# TODO: Merge this function and user token function once audience is removed from user tokens.
if token is None:
return {}

if token.startswith('Bearer '):
_, token = token.split(' ', 1)

pub_keys = fetch_service_public_keys(current_app)
decoded_token = None

# Since we have multiple public keys, we need to verify against every public key.
# Token can be decoded by any one of the available public keys.
for pub_key in pub_keys:
try:
pub_key = pub_key.get("key", "")
pub_key = '-----BEGIN PUBLIC KEY-----\n{pkey}\n-----END PUBLIC KEY-----'\
.format(pkey=pub_key)
decoded_token = jwt.decode(token, pub_key, algorithms=['RS256'])
except jwt.InvalidTokenError:
current_app.logger.error("Auth token couldn't be decoded for public key: {}"
.format(pub_key))
decoded_token = None

if decoded_token:
break

if not decoded_token:
raise jwt.InvalidTokenError('Auth token cannot be verified.')

return decoded_token


def get_token_from_auth_header():
"""Get the authorization token read from the request header."""
return request.headers.get('Authorization')
Expand All @@ -62,7 +95,37 @@ def wrapper(*args, **kwargs):
lgr = current_app.logger

try:
decoded = decode_token(get_token_from_auth_header())
decoded = decode_user_token(get_token_from_auth_header())
if not decoded:
lgr.exception('Provide an Authorization token with the API request')
raise HTTPError(401, 'Authentication failed - token missing')

lgr.info('Successfuly authenticated user {e} using JWT'.
format(e=decoded.get('email')))
except jwt.ExpiredSignatureError as exc:
lgr.exception('Expired JWT token')
raise HTTPError(401, 'Authentication failed - token has expired') from exc
except Exception as exc:
lgr.exception('Failed decoding JWT token')
raise HTTPError(401, 'Authentication failed - could not decode JWT token') from exc

return view(*args, **kwargs)

return wrapper


def service_token_required(view): # pragma: no cover
"""Check if the request contains a valid service token."""
@wraps(view)
def wrapper(*args, **kwargs):
# Disable authentication for local setup
if getenv('DISABLE_AUTHENTICATION') in ('1', 'True', 'true'):
return view(*args, **kwargs)

lgr = current_app.logger

try:
decoded = decode_service_token(get_token_from_auth_header())
if not decoded:
lgr.exception('Provide an Authorization token with the API request')
raise HTTPError(401, 'Authentication failed - token missing')
Expand Down
4 changes: 2 additions & 2 deletions src/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask_cors import CORS
from utils import DatabaseIngestion, scan_repo, validate_request_data, retrieve_worker_result
from f8a_worker.setup_celery import init_selinon
from auth import login_required
from auth import login_required, service_token_required
from exceptions import HTTPError

app = Flask(__name__)
Expand Down Expand Up @@ -160,7 +160,7 @@ def user_repo_scan():


@app.route('/api/v1/user-repo/notify', methods=['POST'])
@login_required
@service_token_required
def notify_user():
"""
Endpoint for notifying security vulnerability in a repository.
Expand Down
25 changes: 25 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,28 @@ def fetch_public_key(app):
app.public_key = None

return app.public_key


def fetch_service_public_keys(app): # pragma: no cover
"""Get public keys for OSIO service account. Currently, there are three public keys."""
if not getattr(app, "service_public_keys", []):
auth_url = os.getenv('BAYESIAN_AUTH_PUBLIC_KEYS_URL', '')
if auth_url:
try:
auth_url = auth_url.strip('/') + '/api/token/keys?format=pem'
result = requests.get(auth_url, timeout=0.5)
app.logger.info('Fetching public key from %s, status %d, result: %s',
auth_url, result.status_code, result.text)
except requests.exceptions.Timeout:
app.logger.error('Timeout fetching public key from %s', auth_url)
return ''
if result.status_code != 200:
return ''

keys = result.json().get('keys', [])
app.service_public_keys = keys

else:
app.service_public_keys = None

return app.service_public_keys
12 changes: 6 additions & 6 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,39 +68,39 @@ def mocked_get_audiences_3():
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
def test_decode_token_invalid_input_1(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
assert decode_token(None) == {}
assert decode_user_token(None) == {}


@patch("auth.get_audiences", side_effect=mocked_get_audiences)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
def test_decode_token_invalid_input_2(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
with pytest.raises(Exception):
assert decode_token("Foobar") is None
assert decode_user_token("Foobar") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1)
def test_decode_token_invalid_input_3(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
with pytest.raises(Exception):
assert decode_token("Bearer ") is None
assert decode_user_token("Bearer ") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_2)
def test_decode_token_invalid_input_4(mocked_fetch_public_key, mocked_get_audiences):
"""Test the invalid input handling during token decoding."""
with pytest.raises(Exception):
assert decode_token("Bearer ") is None
assert decode_user_token("Bearer ") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences_2)
@patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_2)
def test_decode_token_invalid_input_5(mocked_fetch_public_key, mocked_get_audiences):
"""Test the handling wrong JWT tokens."""
with pytest.raises(Exception):
assert decode_token("Bearer something") is None
assert decode_user_token("Bearer something") is None


@patch("auth.get_audiences", side_effect=mocked_get_audiences_3)
Expand All @@ -112,7 +112,7 @@ def test_decode_token_invalid_input_6(mocked_fetch_public_key, mocked_get_audien
'aud': 'foo:bar'
}
token = jwt.encode(payload, PRIVATE_KEY, algorithm='RS256').decode("utf-8")
assert decode_token(token) is not None
assert decode_user_token(token) is not None


def test_audiences():
Expand Down