Skip to content

Commit

Permalink
TLS Support (#74)
Browse files Browse the repository at this point in the history
* add tests for get server certificate

* update change log
  • Loading branch information
ajshedivy authored Jan 28, 2025
1 parent 23b6a9a commit 69c7f92
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
- add TLS support

## [v0.2.0](https://github.com/Mapepire-IBMi/mapepire-python/releases/tag/v0.2.0) - 2024-11-26
- replace `websocket-client` with `websockets`
Expand Down
15 changes: 13 additions & 2 deletions mapepire_python/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from websockets import ConcurrencyError, ConnectionClosed, InvalidHandshake, InvalidURI

from mapepire_python.data_types import DaemonServer
from mapepire_python.ssl import get_certificate

ReturnType = TypeVar("ReturnType")

Expand All @@ -24,8 +25,18 @@ def _create_ssl_context(self, db2_server: DaemonServer):
if db2_server.ignoreUnauthorized:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif db2_server.ca:
ssl_context.load_verify_locations(cadata=db2_server.ca)
else:
if db2_server.ca:
ssl_context.load_verify_locations(cadata=db2_server.ca)
else:
cert = get_certificate(db2_server)
if cert:
ssl_context.load_verify_locations(cadata=cert)
else:
raise ssl.SSLError("Failed to retrieve server certificate")

ssl_context.check_hostname = True
ssl_context.verify_mode = ssl.CERT_REQUIRED
return ssl_context


Expand Down
63 changes: 53 additions & 10 deletions tests/tls_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
import ssl

import pytest

from mapepire_python.client.sql_job import SQLJob
from mapepire_python.data_types import DaemonServer
from mapepire_python.ssl import get_certificate

# Fetch environment variables
server = os.getenv("VITE_SERVER")
Expand All @@ -15,15 +20,53 @@
creds = DaemonServer(host=server, port=port, user=user, password=password, ignoreUnauthorized=False)


# def test_get_cert():
# cert = get_certificate(creds)
# print(cert)
# assert cert != None
def test_get_cert():
cert = get_certificate(creds)
print(cert)
assert cert != None


def test_verify_cert():
cert = get_certificate(creds)
creds.ignoreUnauthorized = False
creds.ca = cert
job = SQLJob()
result = job.connect(creds)
assert result["success"]


def test_verify_cert_not_provided():
creds.ignoreUnauthorized = False
job = SQLJob()
result = job.connect(creds)
assert result["success"]


def test_bad_cert():
badCert = """-----BEGIN CERTIFICATE-----
mIIDhTCCAm2gAwIBAgIEYRpOADANBgkqhkiG9w0BAQsFADBzMRAwDgYDVQQIEwdV
bmtub3duMRAwDgYDVQQGEwdVbmtub3duMRYwFAYDVQQKEw1EYjIgZm9yIElCTSBp
MRowGAYDVQQLExFXZWIgU29ja2V0IFNlcnZlcjEZMBcGA1UEAxMQT1NTQlVJTEQu
UlpLSC5ERTAeFw0yNDA4MjMxODE2MDJaFw0zNDA4MjUxODE2MDJaMHMxEDAOBgNV
BAgTB1Vua25vd24xEDAOBgNVBAYTB1Vua25vd24xFjAUBgNVBAoTDURiMiBmb3Ig
SUJNIGkxGjAYBgNVBAsTEVdlYiBTb2NrZXQgU2VydmVyMRkwFwYDVQQDExBPU1NC
VUlMRC5SWktILkRFMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAhKBx
5KoTsBs3dHibT/j8ycApa6teJOclaiCl9fX5IwKP0dli5qZ91t5sZ+51qS3mgLny
zMWBCaSIQYsuDEE374lHYpYB6wh/00VE1NseJpHbqbCQz1GSUz/d4tK4R1qx0Gv0
lpKOd8/oMLUZ24FCEUKaqeQBxTzlQxkI9t2DbIRwS6U6oc4uj5DN2EIU+mfLb17y
j8iA7VMKsRmoke2vyOLXJJYJeASNI02AbHcbYkd6BaoyNeb3BlpssEhgZribWmdy
FhrJldpGtJyirvABaZQaEFelEqmSVbdPWccX3JWQdorZoNVXCypxJatxOZAhCg6f
iu3AceHUr+dMAS8z4QIDAQABoyEwHzAdBgNVHQ4EFgQU6VvyCjQ5574xtCg0oypV
zHP0vAMwDQYJKoZIhvcNAQELBQADggEBAD4bKhansD+uuUPYaIvPwyclr4zPvuyg
QAFu5oILqddzgPGIwogbxTxQkNjEGyorFJj1vJBCVIq4zJJ0DIv57BK/oVMy4Byl
6zMhJTjS74assgjCq1pVjIBtc2PCfiWxzo0wQCOEL8gsNCy/w5EaIATKfLtx6+Fd
CHsadf7fvJnLnK3FXOStAnN31ISSTwsvsRobdXX70nlYM/2OaZQsIlndftVRbI39
2+94KHciPSwo/4fu+FLuvOm37GS+/ST3BKDSvwRJRxUc0r8lo1STiQz0cXC6uqDd
79/VBUN4NLZ3mBVk2FGuazIu9n80+o0fI5sg1ucQ/hBt8WR8iQ6sZUc=
-----END CERTIFICATE-----"""

# def test_verify_cert():
# cert = get_certificate(creds)
# creds.ca = cert
# job = SQLJob()
# result = job.connect(creds)
# print(result)
creds.ca = badCert
creds.ignoreUnauthorized = False
job = SQLJob()
with pytest.raises(ssl.SSLError) as err:
res = job.connect(creds)

0 comments on commit 69c7f92

Please sign in to comment.