Skip to content

Commit

Permalink
Support checking service ports with ssl connection
Browse files Browse the repository at this point in the history
Related-Bug: #1920770
  • Loading branch information
dosaboy committed Mar 12, 2024
1 parent b78107d commit bd09435
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
24 changes: 24 additions & 0 deletions charmhelpers/contrib/network/ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import subprocess
import socket
import ssl

from functools import partial

Expand Down Expand Up @@ -542,6 +543,29 @@ def port_has_listener(address, port):
return not (bool(result))


def port_has_listener_ssl(address, port, key, cert, ca_cert):
"""
Returns True if the address:port is open and being listened to,
else False.
@param address: an IP address or hostname
@param port: integer port
@param: cert: path to cert
@param: key: path to key
@param: ca_cert: path to ca cert
"""
hostname = address
context = ssl.create_default_context()
context.check_hostname = False
context.load_cert_chain(cert, key)
context.load_verify_locations(ca_cert)
try:
with socket.create_connection((hostname, port)) as sock:
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
return True
except ConnectionRefusedError:
return False

def assert_charm_supports_ipv6():
"""Check whether we are able to support charms ipv6."""
release = lsb_release()['DISTRIB_CODENAME'].lower()
Expand Down
28 changes: 24 additions & 4 deletions charmhelpers/contrib/openstack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
get_ipv6_addr,
is_ipv6,
port_has_listener,
port_has_listener_ssl,
)

from charmhelpers.core.host import (
Expand Down Expand Up @@ -1207,12 +1208,19 @@ def _ows_check_services_running(services, ports):
return ows_check_services_running(services, ports)


def ows_check_services_running(services, ports):
def ows_check_services_running(services, ports, use_ssl_check=False,
ssl_check_info=None):
"""Check that the services that should be running are actually running
and that any ports specified are being listened to.
@param services: list of strings OR dictionary specifying services/ports
@param ports: list of ports
@param use_ssl_check: Set to True if you want to use ssl to check
service ports rather than netcat. Default is
False.
@param ssl_check_info: If use_ssl_check is True this is a dict of
{key: <path to key>, cert: <path to cert>} used
by ssl when checking ports using SSL.
@returns state, message: strings or None, None
"""
messages = []
Expand All @@ -1228,7 +1236,10 @@ def ows_check_services_running(services, ports):
# also verify that the ports that should be open are open
# NB, that ServiceManager objects only OPTIONALLY have ports
map_not_open, ports_open = (
_check_listening_on_services_ports(services))
_check_listening_on_services_ports(
services,
use_ssl_check=use_ssl_check,
ssl_check_info=ssl_check_info))
if not all(ports_open):
# find which service has missing ports. They are in service
# order which makes it a bit easier.
Expand Down Expand Up @@ -1302,7 +1313,9 @@ def _check_running_services(services):
return list(zip(services, services_running)), services_running


def _check_listening_on_services_ports(services, test=False):
def _check_listening_on_services_ports(services, test=False,
use_ssl_check=False,
ssl_check_info=None):
"""Check that the unit is actually listening (has the port open) on the
ports that the service specifies are open. If test is True then the
function returns the services with ports that are open rather than
Expand All @@ -1316,7 +1329,14 @@ def _check_listening_on_services_ports(services, test=False):
"""
test = not (not (test)) # ensure test is True or False
all_ports = list(itertools.chain(*services.values()))
ports_states = [port_has_listener('0.0.0.0', p) for p in all_ports]
if use_ssl_check:
f_port_listener_check = lambda *args: port_has_listener_ssl(
*args,
**ssl_check_info)
else:
f_port_listener_check = port_has_listener

ports_states = [f_port_listener_check('0.0.0.0', p) for p in all_ports]
map_ports = OrderedDict()
matched_ports = [p for p, opened in zip(all_ports, ports_states)
if opened == test] # essentially opened xor test
Expand Down

0 comments on commit bd09435

Please sign in to comment.