Skip to content

Commit

Permalink
Simplified testing, following pattern of other tests, need proper SSL…
Browse files Browse the repository at this point in the history
… setup with nginx to test ssl_context fully
  • Loading branch information
Alex Wojtowicz committed Feb 24, 2025
1 parent b947f24 commit 14c6074
Showing 1 changed file with 21 additions and 90 deletions.
111 changes: 21 additions & 90 deletions python/pyhive/tests/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,95 +87,6 @@ def test_complex(self, cursor):
# catch unicode/str
self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))

@mock.patch('thrift.transport.THttpClient.THttpClient')
@mock.patch('ssl.create_default_context')
def test_default_ssl_parameters(self, mock_create_context, mock_http_client):
"""Test that default SSL parameters work correctly when no custom context is provided"""
mock_default_context = mock.MagicMock()
mock_create_context.return_value = mock_default_context

try:
conn = hive.Connection(
host=_HOST,
scheme='https',
check_hostname='true',
ssl_cert='required'
)
except TTransportException:
pass # Expected since we're mocking the transport

# Verify default context was created and configured correctly
mock_create_context.assert_called_once()
self.assertEqual(mock_default_context.check_hostname, True)
self.assertEqual(mock_default_context.verify_mode, ssl.CERT_REQUIRED)

# Verify the transport was created with our SSL context
mock_http_client.assert_called_once()
call_kwargs = mock_http_client.call_args[1]
self.assertEqual(call_kwargs['ssl_context'], mock_default_context)

@mock.patch('thrift.transport.THttpClient.THttpClient')
@mock.patch('ssl.create_default_context')
def test_custom_ssl_context_overrides_defaults(self, mock_create_context, mock_http_client):
"""Test that custom SSL context overrides default SSL parameters"""
custom_context = mock.MagicMock()
custom_context.check_hostname = False
custom_context.verify_mode = ssl.CERT_NONE

try:
conn = hive.Connection(
host=_HOST,
scheme='https',
ssl_context=custom_context,
check_hostname='true',
ssl_cert='required'
)
except TTransportException:
pass # Expected since we're mocking the transport

# Verify default context was NOT created
mock_create_context.assert_not_called()

# Verify the transport was created with our custom context
mock_http_client.assert_called_once()
call_kwargs = mock_http_client.call_args[1]
self.assertEqual(call_kwargs['ssl_context'], custom_context)

def test_ssl_context_mtls_configuration(self):
"""Test that SSL context can be configured for mTLS"""
# Create and configure mTLS context before creating the connection
mtls_context = mock.MagicMock()
mtls_context.load_cert_chain = mock.MagicMock()

# Configure mTLS before passing to Connection
mtls_context.load_cert_chain(
certfile='client.crt',
keyfile='client.key'
)

conn = hive.Connection(
host=_HOST,
scheme='https',
ssl_context=mtls_context
)

# Verify mTLS was configured correctly
mtls_context.load_cert_chain.assert_called_once_with(
certfile='client.crt',
keyfile='client.key'
)

def test_http_connection_no_ssl(self):
"""Test that HTTP connections don't create SSL context"""
with mock.patch('ssl.create_default_context') as mock_create_context:
conn = hive.Connection(
host=_HOST,
scheme='http' # Note: using http instead of https
)

# Verify no SSL context was created for HTTP
mock_create_context.assert_not_called()

@with_cursor
def test_async(self, cursor):
cursor.execute('SELECT * FROM one_row', async_=True)
Expand Down Expand Up @@ -328,9 +239,29 @@ def test_custom_connection(self):
subprocess.check_call(['sudo', 'cp', orig_none, des])
_restart_hs2()

@pytest.mark.skip(reason="Need a proper setup for SSL context testing")
def test_basic_ssl_context(self):
"""Test that connection works with a custom SSL context that mimics the default behavior."""
# Create an SSL context similar to what Connection creates by default
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE

# Connect using the same parameters as self.connect() but with our custom context
with contextlib.closing(hive.connect(
host=_HOST,
port=10000,
configuration={'mapred.job.tracker': 'local'},
ssl_context=ssl_context
)) as connection:
with contextlib.closing(connection.cursor()) as cursor:
# Use the same query pattern as other tests
cursor.execute('SELECT 1 FROM one_row')
self.assertEqual(cursor.fetchall(), [(1,)])


def _restart_hs2():
subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart'])
with contextlib.closing(socket.socket()) as s:
while s.connect_ex(('localhost', 10000)) != 0:
time.sleep(1)
time.sleep(1)

0 comments on commit 14c6074

Please sign in to comment.