diff --git a/python/pyhive/tests/test_hive.py b/python/pyhive/tests/test_hive.py index b6e9c62b691..428a0211624 100644 --- a/python/pyhive/tests/test_hive.py +++ b/python/pyhive/tests/test_hive.py @@ -87,95 +87,6 @@ def test_complex(self, cursor): # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) - @mock.patch('thrift.transport.THttpClient.THttpClient') - @mock.patch('ssl.create_default_context') - def test_default_ssl_parameters(self, mock_create_context, mock_http_client): - """Test that default SSL parameters work correctly when no custom context is provided""" - mock_default_context = mock.MagicMock() - mock_create_context.return_value = mock_default_context - - try: - conn = hive.Connection( - host=_HOST, - scheme='https', - check_hostname='true', - ssl_cert='required' - ) - except TTransportException: - pass # Expected since we're mocking the transport - - # Verify default context was created and configured correctly - mock_create_context.assert_called_once() - self.assertEqual(mock_default_context.check_hostname, True) - self.assertEqual(mock_default_context.verify_mode, ssl.CERT_REQUIRED) - - # Verify the transport was created with our SSL context - mock_http_client.assert_called_once() - call_kwargs = mock_http_client.call_args[1] - self.assertEqual(call_kwargs['ssl_context'], mock_default_context) - - @mock.patch('thrift.transport.THttpClient.THttpClient') - @mock.patch('ssl.create_default_context') - def test_custom_ssl_context_overrides_defaults(self, mock_create_context, mock_http_client): - """Test that custom SSL context overrides default SSL parameters""" - custom_context = mock.MagicMock() - custom_context.check_hostname = False - custom_context.verify_mode = ssl.CERT_NONE - - try: - conn = hive.Connection( - host=_HOST, - scheme='https', - ssl_context=custom_context, - check_hostname='true', - ssl_cert='required' - ) - except TTransportException: - pass # Expected since we're mocking the transport - - # Verify default context was NOT created - mock_create_context.assert_not_called() - - # Verify the transport was created with our custom context - mock_http_client.assert_called_once() - call_kwargs = mock_http_client.call_args[1] - self.assertEqual(call_kwargs['ssl_context'], custom_context) - - def test_ssl_context_mtls_configuration(self): - """Test that SSL context can be configured for mTLS""" - # Create and configure mTLS context before creating the connection - mtls_context = mock.MagicMock() - mtls_context.load_cert_chain = mock.MagicMock() - - # Configure mTLS before passing to Connection - mtls_context.load_cert_chain( - certfile='client.crt', - keyfile='client.key' - ) - - conn = hive.Connection( - host=_HOST, - scheme='https', - ssl_context=mtls_context - ) - - # Verify mTLS was configured correctly - mtls_context.load_cert_chain.assert_called_once_with( - certfile='client.crt', - keyfile='client.key' - ) - - def test_http_connection_no_ssl(self): - """Test that HTTP connections don't create SSL context""" - with mock.patch('ssl.create_default_context') as mock_create_context: - conn = hive.Connection( - host=_HOST, - scheme='http' # Note: using http instead of https - ) - - # Verify no SSL context was created for HTTP - mock_create_context.assert_not_called() - @with_cursor def test_async(self, cursor): cursor.execute('SELECT * FROM one_row', async_=True) @@ -328,9 +239,29 @@ def test_custom_connection(self): subprocess.check_call(['sudo', 'cp', orig_none, des]) _restart_hs2() + @pytest.mark.skip(reason="Need a proper setup for SSL context testing") + def test_basic_ssl_context(self): + """Test that connection works with a custom SSL context that mimics the default behavior.""" + # Create an SSL context similar to what Connection creates by default + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + # Connect using the same parameters as self.connect() but with our custom context + with contextlib.closing(hive.connect( + host=_HOST, + port=10000, + configuration={'mapred.job.tracker': 'local'}, + ssl_context=ssl_context + )) as connection: + with contextlib.closing(connection.cursor()) as cursor: + # Use the same query pattern as other tests + cursor.execute('SELECT 1 FROM one_row') + self.assertEqual(cursor.fetchall(), [(1,)]) + def _restart_hs2(): subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart']) with contextlib.closing(socket.socket()) as s: while s.connect_ex(('localhost', 10000)) != 0: - time.sleep(1) + time.sleep(1) \ No newline at end of file