Skip to content

Commit

Permalink
Merge pull request #181 from bkemper24/main
Browse files Browse the repository at this point in the history
Add support for Proof Key for Code Exchange (PKCE)
  • Loading branch information
bkemper24 authored Jun 7, 2024
2 parents 5f81609 + 41619cc commit c147ef4
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 7 deletions.
22 changes: 22 additions & 0 deletions doc/source/getting-started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,28 @@ The authentication code can also be specified using one of the following environ
- CASAUTHCODE
- VIYAAUTHCODE

Beginning with release v1.14.0, the SWAT package supports using Proof Key for Code Exchange ( PKCE )
when using authentication codes to obtain an OAuth token with HTTP. Python 3.6 or later is required
for PKCE.

To use PKCE, specify the pkce=True parameter in the :class:`CAS` constructor. When specifying pkce=True,
do not specify the authcode parameter. You will be provided a URL to use to obtain the
authentication code and prompted to enter the authentication code obtained from that URL.

.. ipython:: python
:verbatim:
conn = swat.CAS('https://my-cas-host.com:443/cas-shared-default-http/',
pkce=True)
The pkce parameter can also be specified using one of the following environment variables

- CAS_PKCE
- VIYA_PKCE
- CASPKCE
- VIYAPKCE

Kerberos
~~~~~~~~~~~~~~~~~~~~~

Expand Down
85 changes: 79 additions & 6 deletions swat/cas/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ class CAS(object):
The path to the SSL certificates for the CAS server.
authcode : string, optional
Authorization code from SASLogon used to retrieve an OAuth token.
pkce : boolean, optional
Use Proof Key for Code Exchange to obtain the Authorization code
**kwargs : any, optional
Arbitrary keyword arguments used for internal purposes only.
Expand Down Expand Up @@ -353,7 +355,7 @@ def _get_connection_info(cls, hostname, port, username, password, protocol, path
def __init__(self, hostname=None, port=None, username=None, password=None,
session=None, locale=None, nworkers=None, name=None,
authinfo=None, protocol=None, path=None, ssl_ca_list=None,
authcode=None, **kwargs):
authcode=None, pkce=False, **kwargs):

# Filter session options allowed as parameters
_kwargs = {}
Expand Down Expand Up @@ -399,11 +401,23 @@ def __init__(self, hostname=None, port=None, username=None, password=None,
soptions = a2n(getsoptions(session=session, locale=locale,
nworkers=nworkers, protocol=protocol))

# Check for Proof Key for Code Exchange
pkce = pkce or cf.get_option('cas.pkce')
# Check for authcode authentication
authcode = authcode or cf.get_option('cas.authcode')
if protocol in ['http', 'https'] and authcode:
if protocol in ['http', 'https'] and (authcode or pkce):
username = None
password = type(self)._get_token(authcode=authcode, url=hostname)
verifystring = None
if pkce:
if authcode:
# User will be prompted for authcode,
# do not enter it in CAS() when using pkce
raise SWATError('Do not specify authcode with pkce')
# Get the authcode from SASLogon using Proof Key for Code Exchange
authcode, verifystring = type(self)._get_authcode(url=hostname)
# Get the OAuth token from SASLogon
password = type(self)._get_token(authcode=authcode, url=hostname,
verifystring=verifystring, pkce=pkce)

# Create error handler
try:
Expand Down Expand Up @@ -538,7 +552,8 @@ def _id_generator():

@classmethod
def _get_token(cls, username=None, password=None, authcode=None,
client_id=None, client_secret=None, url=None):
client_id=None, client_secret=None, url=None,
verifystring=None, pkce=False):
''' Retrieve token from Viya installation '''
from .rest.connection import _print_request, _setup_ssl

Expand All @@ -552,10 +567,21 @@ def _get_token(cls, username=None, password=None, authcode=None,
client_id = client_id or cf.get_option('cas.client_id') or 'SWAT'

authcode = authcode or cf.get_option('cas.authcode')
pkce = pkce or cf.get_option('cas.pkce')

if authcode:
client_secret = client_secret or cf.get_option('cas.client_secret') or ''
body = {'grant_type': 'authorization_code', 'code': authcode,
'client_id': client_id, 'client_secret': client_secret}

if pkce:
if verifystring is None:
raise SWATError('A code verifier must be supplied for pkce')

body = {'grant_type': 'authorization_code',
'code': authcode, 'code_verifier': verifystring,
'client_id': client_id, 'client_secret': client_secret}
else:
body = {'grant_type': 'authorization_code', 'code': authcode,
'client_id': client_id, 'client_secret': client_secret}
else:
username = username or cf.get_option('cas.username')
password = password or cf.get_option('cas.token')
Expand All @@ -567,11 +593,58 @@ def _get_token(cls, username=None, password=None, authcode=None,
data=urlencode(body))

if resp.status_code >= 300:
logger.debug('Token request resulted in status code %d : \n %s',
resp.status_code, resp.json())
raise SWATError('Token request resulted in a status of %s' %
resp.status_code)

return resp.json()['access_token']

@classmethod
def _get_authcode(cls, url=None, client_id=None, client_secret=None):
'''
Generate the Proof Key for Code Exchange URL to retrieve the authentication code
from the Viya installation.
Wait for the user to provide the authentication code
'''
try:
# The secrets package was introduced in Python 3.6
import secrets
except ImportError:
raise SWATError("Python 3.6 or later is required for "
"Proof Key for Code Exchange.")

import hashlib
import base64

client_id = client_id or cf.get_option('cas.client_id') or 'SWAT'
client_secret = client_secret or cf.get_option('cas.client_secret') or ''

# Generate the URL for the authcode request
cv = secrets.token_urlsafe(32)
cvh = hashlib.sha256(cv.encode('ascii')).digest()
cvhe = base64.urlsafe_b64encode(cvh)
cc = cvhe.decode('ascii')[:-1]
# Note, for pkce "cc" is provided in the authcode request
# and "cv" is provided in the OAuth token request
purl = ("/SASLogon/oauth/authorize?client_id={}&response_type=code"
"&code_challenge_method=S256&code_challenge={}").format(client_id, cc)
authurl = urljoin(url, purl)

# Display the URL to the user and wait while they go off and get the authcode
# to respond to the prompt
msg = ("Please enter the authorization code obtained from the following url : "
"\n {} \n").format(authurl)
authcode = input(msg)

# trim leading trailing whitespace and verify something was entered
authcode = authcode.strip()
if len(authcode) == 0:
raise SWATError(
"You must provide an authorization code to connect to the CAS server")

return authcode, cv

def _gen_id(self):
''' Generate an ID unique to the session '''
import numpy
Expand Down
5 changes: 5 additions & 0 deletions swat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ def check_tz(value):
'Using "http" or "https" will use the REST interface.',
environ='CAS_PROTOCOL')

register_option('cas.pkce', 'boolean', check_boolean, False,
'Indicates whether or not Proof Key for Code Exchange should\n'
'be used to obtain an authorization code.',
environ=['CAS_PKCE', 'VIYA_PKCE'])


def get_default_cafile():
''' Retrieve the default CA file in the ssl module '''
Expand Down
2 changes: 1 addition & 1 deletion swat/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_suboptions(self):
'connection_retries', 'connection_retry_interval',
'dataset', 'debug', 'exception_on_severity',
'hostname', 'missing',
'port', 'print_messages', 'protocol',
'pkce', 'port', 'print_messages', 'protocol',
'reflection_levels', 'ssl_ca_list', 'token',
'trace_actions', 'trace_ui_actions', 'username'])

Expand Down

0 comments on commit c147ef4

Please sign in to comment.