From 69c7f92905d1f77b5cf265c6a2f38692eeb1149f Mon Sep 17 00:00:00 2001 From: Adam Shedivy <50843283+ajshedivy@users.noreply.github.com> Date: Tue, 28 Jan 2025 12:42:00 -0600 Subject: [PATCH] TLS Support (#74) * add tests for get server certificate * update change log --- CHANGELOG.md | 1 + mapepire_python/websocket.py | 15 +++++++-- tests/tls_test.py | 63 ++++++++++++++++++++++++++++++------ 3 files changed, 67 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51d6581..bd03010 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +- add TLS support ## [v0.2.0](https://github.com/Mapepire-IBMi/mapepire-python/releases/tag/v0.2.0) - 2024-11-26 - replace `websocket-client` with `websockets` diff --git a/mapepire_python/websocket.py b/mapepire_python/websocket.py index 6d621b5..1f87ee7 100644 --- a/mapepire_python/websocket.py +++ b/mapepire_python/websocket.py @@ -6,6 +6,7 @@ from websockets import ConcurrencyError, ConnectionClosed, InvalidHandshake, InvalidURI from mapepire_python.data_types import DaemonServer +from mapepire_python.ssl import get_certificate ReturnType = TypeVar("ReturnType") @@ -24,8 +25,18 @@ def _create_ssl_context(self, db2_server: DaemonServer): if db2_server.ignoreUnauthorized: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - elif db2_server.ca: - ssl_context.load_verify_locations(cadata=db2_server.ca) + else: + if db2_server.ca: + ssl_context.load_verify_locations(cadata=db2_server.ca) + else: + cert = get_certificate(db2_server) + if cert: + ssl_context.load_verify_locations(cadata=cert) + else: + raise ssl.SSLError("Failed to retrieve server certificate") + + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED return ssl_context diff --git a/tests/tls_test.py b/tests/tls_test.py index a2afa01..9a8e9d5 100644 --- a/tests/tls_test.py +++ b/tests/tls_test.py @@ -1,6 +1,11 @@ import os +import ssl +import pytest + +from mapepire_python.client.sql_job import SQLJob from mapepire_python.data_types import DaemonServer +from mapepire_python.ssl import get_certificate # Fetch environment variables server = os.getenv("VITE_SERVER") @@ -15,15 +20,53 @@ creds = DaemonServer(host=server, port=port, user=user, password=password, ignoreUnauthorized=False) -# def test_get_cert(): -# cert = get_certificate(creds) -# print(cert) -# assert cert != None +def test_get_cert(): + cert = get_certificate(creds) + print(cert) + assert cert != None + + +def test_verify_cert(): + cert = get_certificate(creds) + creds.ignoreUnauthorized = False + creds.ca = cert + job = SQLJob() + result = job.connect(creds) + assert result["success"] + + +def test_verify_cert_not_provided(): + creds.ignoreUnauthorized = False + job = SQLJob() + result = job.connect(creds) + assert result["success"] + +def test_bad_cert(): + badCert = """-----BEGIN CERTIFICATE----- +mIIDhTCCAm2gAwIBAgIEYRpOADANBgkqhkiG9w0BAQsFADBzMRAwDgYDVQQIEwdV +bmtub3duMRAwDgYDVQQGEwdVbmtub3duMRYwFAYDVQQKEw1EYjIgZm9yIElCTSBp +MRowGAYDVQQLExFXZWIgU29ja2V0IFNlcnZlcjEZMBcGA1UEAxMQT1NTQlVJTEQu +UlpLSC5ERTAeFw0yNDA4MjMxODE2MDJaFw0zNDA4MjUxODE2MDJaMHMxEDAOBgNV +BAgTB1Vua25vd24xEDAOBgNVBAYTB1Vua25vd24xFjAUBgNVBAoTDURiMiBmb3Ig +SUJNIGkxGjAYBgNVBAsTEVdlYiBTb2NrZXQgU2VydmVyMRkwFwYDVQQDExBPU1NC +VUlMRC5SWktILkRFMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAhKBx +5KoTsBs3dHibT/j8ycApa6teJOclaiCl9fX5IwKP0dli5qZ91t5sZ+51qS3mgLny +zMWBCaSIQYsuDEE374lHYpYB6wh/00VE1NseJpHbqbCQz1GSUz/d4tK4R1qx0Gv0 +lpKOd8/oMLUZ24FCEUKaqeQBxTzlQxkI9t2DbIRwS6U6oc4uj5DN2EIU+mfLb17y +j8iA7VMKsRmoke2vyOLXJJYJeASNI02AbHcbYkd6BaoyNeb3BlpssEhgZribWmdy +FhrJldpGtJyirvABaZQaEFelEqmSVbdPWccX3JWQdorZoNVXCypxJatxOZAhCg6f +iu3AceHUr+dMAS8z4QIDAQABoyEwHzAdBgNVHQ4EFgQU6VvyCjQ5574xtCg0oypV +zHP0vAMwDQYJKoZIhvcNAQELBQADggEBAD4bKhansD+uuUPYaIvPwyclr4zPvuyg +QAFu5oILqddzgPGIwogbxTxQkNjEGyorFJj1vJBCVIq4zJJ0DIv57BK/oVMy4Byl +6zMhJTjS74assgjCq1pVjIBtc2PCfiWxzo0wQCOEL8gsNCy/w5EaIATKfLtx6+Fd +CHsadf7fvJnLnK3FXOStAnN31ISSTwsvsRobdXX70nlYM/2OaZQsIlndftVRbI39 +2+94KHciPSwo/4fu+FLuvOm37GS+/ST3BKDSvwRJRxUc0r8lo1STiQz0cXC6uqDd +79/VBUN4NLZ3mBVk2FGuazIu9n80+o0fI5sg1ucQ/hBt8WR8iQ6sZUc= +-----END CERTIFICATE-----""" -# def test_verify_cert(): -# cert = get_certificate(creds) -# creds.ca = cert -# job = SQLJob() -# result = job.connect(creds) -# print(result) + creds.ca = badCert + creds.ignoreUnauthorized = False + job = SQLJob() + with pytest.raises(ssl.SSLError) as err: + res = job.connect(creds)