Skip to content

Commit

Permalink
Make K8s cert-manager compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
lloesche committed Oct 4, 2023
1 parent 5db69f7 commit 0b12355
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 43 deletions.
9 changes: 6 additions & 3 deletions fixca/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from resotolib.event import EventType, add_event_listener
from resotolib.x509 import gen_csr, gen_rsa_key, write_cert_to_file, write_key_to_file
from .args import parse_args
from .ca import get_ca, WebApp, CaApp
from .ca import CA, WebApp, CaApp
from threading import Event


Expand All @@ -28,7 +28,7 @@ def main() -> None:

add_event_listener(EventType.SHUTDOWN, shutdown)

CA = get_ca(namespace=args.namespace, secret_name=args.secret)
CA.initialize(namespace=args.namespace, secret_name=args.secret)

common_name = "ca.fix"
cert_key = gen_rsa_key()
Expand All @@ -52,8 +52,11 @@ def main() -> None:
web_port=args.port,
ssl_cert=cert_path,
ssl_key=key_path,
extra_config={
"tools.proxy.on": True,
},
)
web_server.mount("/ca", CaApp(get_ca(), args.psk))
web_server.mount("/ca", CaApp(CA, args.psk))

web_server.daemon = True
web_server.start()
Expand Down
4 changes: 2 additions & 2 deletions fixca/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def parse_args(add_args: List[Callable]) -> Namespace:
parser.add_argument(
"--namespace",
dest="namespace",
help="K8s namespace (default: fix)",
default=os.environ.get("FIXCA_NAMESPACE", "fix"),
help="K8s namespace (default: cert-manager)",
default=os.environ.get("FIXCA_NAMESPACE", "cert-manager"),
)
parser.add_argument(
"--secret",
Expand Down
185 changes: 150 additions & 35 deletions fixca/ca.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import cherrypy
from functools import wraps
from prometheus_client.exposition import generate_latest, CONTENT_TYPE_LATEST
from typing import Optional, Dict, Callable, Tuple, Union
from typing import Optional, Dict, Callable, Tuple, Union, Any, List
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.x509.base import Certificate, CertificateSigningRequest
from resotolib.logger import log
Expand All @@ -14,54 +15,132 @@
load_csr_from_bytes,
load_cert_from_bytes,
load_key_from_bytes,
gen_rsa_key,
gen_csr,
gen_ca_bundle_bytes,
)
from resotolib.jwt import encode_jwt, decode_jwt_from_headers
from .k8s import get_secret, set_secret
from .utils import str_to_bool


CA: Optional["CertificateAuthority"] = None
PSK: Optional[Union[str, Certificate, RSAPublicKey]] = None
class CertificateAuthority:
def __init__(self):
self.cert = None
self.__key = None
self.__initialized = False

@staticmethod
def requires_initialized(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(ca_instance: "CertificateAuthority", *args: Any, **kwargs: Any) -> Any:
if not ca_instance.initialized:
raise Exception("CA not initialized")
return func(ca_instance, *args, **kwargs)

class CertificateAuthority:
def __init__(self, ca_key: RSAPrivateKey, ca_cert: Certificate):
self.ca_key = ca_key
self.ca_cert = ca_cert
return wrapper

@requires_initialized
def sign(self, csr: CertificateSigningRequest) -> Certificate:
return sign_csr(csr, self.ca_key, self.ca_cert)
return sign_csr(csr, self.__key, self.cert)

def initialize(self, namespace: str = "cert-manager", secret_name: str = "fix-ca") -> None:
self.__key, self.cert = self.__load_ca_data(namespace=namespace, secret_name=secret_name)
self.__initialized = True

def load_ca_data(namespace: str = "fix", secret_name: str = "fix-ca") -> Tuple[RSAPrivateKey, Certificate]:
log.info("Loading CA data")
ca_secret = get_secret(namespace=namespace, secret_name=secret_name)
@property
def initialized(self) -> bool:
return self.__initialized

if isinstance(ca_secret, dict) and (not "key" in ca_secret or not "cert" in ca_secret):
ca_secret = None
log.error("CA secret is missing key or cert")
@staticmethod
def __load_ca_data(
namespace: str = "cert-manager", secret_name: str = "fix-ca"
) -> Tuple[RSAPrivateKey, Certificate]:
log.info("Loading CA data")
ca_secret = get_secret(namespace=namespace, secret_name=secret_name)

if isinstance(ca_secret, dict) and (not "tls.key" in ca_secret or not "tls.crt" in ca_secret):
ca_secret = None
log.error("CA secret is missing key or cert")

if ca_secret is None:
log.debug("Bootstrapping a new CA")
key, cert = bootstrap_ca(common_name="FIX Certification Authority")
ca_secret = {
"tls.key": key_to_bytes(key).decode("utf-8"),
"tls.crt": cert_to_bytes(cert).decode("utf-8"),
}
set_secret(namespace=namespace, secret_name=secret_name, data=ca_secret)
else:
log.debug("Loading existing CA")
key_bytes, cert_bytes = ca_secret["tls.key"].encode(), ca_secret["tls.crt"].encode()
key = load_key_from_bytes(key_bytes)
cert = load_cert_from_bytes(cert_bytes)

if ca_secret is None:
log.debug("Bootstrapping a new CA")
key, cert = bootstrap_ca(common_name="FIX Certification Authority")
ca_secret = {
"key": key_to_bytes(key).decode("utf-8"),
"cert": cert_to_bytes(cert).decode("utf-8"),
return key, cert

@requires_initialized
def generate(
self,
common_name: str,
san_dns_names: Optional[List[str]] = None,
san_ip_addresses: Optional[List[str]] = None,
) -> Tuple[RSAPrivateKey, Certificate]:
if san_dns_names is None:
san_dns_names = []
elif isinstance(san_dns_names, str):
san_dns_names = [san_dns_names]
if san_ip_addresses is None:
san_ip_addresses = []
elif isinstance(san_ip_addresses, str):
san_ip_addresses = [san_ip_addresses]

cert_key = gen_rsa_key()
cert_csr = gen_csr(
cert_key,
common_name=common_name,
san_dns_names=san_dns_names,
san_ip_addresses=san_ip_addresses,
include_loopback=False,
connect_to_ips=None,
discover_local_dns_names=False,
discover_local_ip_addresses=False,
)
cert_crt = self.sign(cert_csr)
return cert_key, cert_crt

def store_secret(
self,
cert_key: RSAPrivateKey,
cert_crt: Certificate,
namespace: str,
secret_name: str,
key_cert: str = "cert.pem",
key_key: str = "cert.key",
key_ca: str = "ca.pem",
key_ca_bundle: str = "ca.bundle.pem",
include_ca_cert: bool = False,
include_ca_bundle: bool = False,
) -> None:
log.info(f"Storing certificate {cert_crt.subject.rfc4514_string()} in {namespace}/{secret_name}")
secret = {
key_cert: cert_to_bytes(cert_crt).decode("utf-8"),
key_key: key_to_bytes(cert_key).decode("utf-8"),
}
set_secret(namespace=namespace, secret_name=secret_name, data=ca_secret)
else:
log.debug("Loading existing CA")
key_bytes, cert_bytes = ca_secret["key"].encode(), ca_secret["cert"].encode()
key = load_key_from_bytes(key_bytes)
cert = load_cert_from_bytes(cert_bytes)
if include_ca_cert:
secret[key_ca] = cert_to_bytes(self.cert).decode("utf-8")
if include_ca_bundle:
secret[key_ca_bundle] = gen_ca_bundle_bytes(self.cert).decode("utf-8")

return key, cert
set_secret(
namespace=namespace,
secret_name=secret_name,
data=secret,
)


def get_ca(namespace: str = "fix", secret_name: str = "fix-ca") -> CertificateAuthority:
global CA
if CA is None:
CA = CertificateAuthority(*load_ca_data(namespace=namespace, secret_name=secret_name))
return CA
CA: CertificateAuthority = CertificateAuthority()
PSK: Optional[Union[str, Certificate, RSAPublicKey]] = None


def jwt_check():
Expand Down Expand Up @@ -94,7 +173,7 @@ def __init__(
"tools.staticdir.on": True,
"tools.staticdir.dir": f"{local_path}/static",
}
self.ca = get_ca()
self.ca = CA
self.config = {"/": config}
self.health_conditions = health_conditions if health_conditions is not None else {}
if self.mountpoint not in ("/", ""):
Expand Down Expand Up @@ -132,14 +211,14 @@ def __init__(self, ca: CertificateAuthority, psk_or_cert: Union[str, Certificate
@cherrypy.tools.allow(methods=["GET"])
def cert(self) -> bytes:
assert self.psk_or_cert is not None
fingerprint = cert_fingerprint(self.ca.ca_cert)
fingerprint = cert_fingerprint(self.ca.cert)
cherrypy.response.headers["Content-Type"] = "application/x-pem-file"
cherrypy.response.headers["SHA256-Fingerprint"] = fingerprint
cherrypy.response.headers["Content-Disposition"] = 'attachment; filename="fix_root_ca.pem"'
cherrypy.response.headers["Authorization"] = "Bearer " + encode_jwt(
{"sha256_fingerprint": fingerprint}, self.psk_or_cert
)
return cert_to_bytes(self.ca.ca_cert)
return cert_to_bytes(self.ca.cert)

@cherrypy.expose
@cherrypy.tools.allow(methods=["POST"])
Expand All @@ -159,3 +238,39 @@ def sign(self) -> bytes:
cherrypy.response.headers["SHA256-Fingerprint"] = cert_fingerprint(crt)
cherrypy.response.headers["Content-Disposition"] = f'attachment; filename="{filename}"'
return cert_to_bytes(crt)

@cherrypy.expose
@cherrypy.tools.json_out()
@cherrypy.tools.json_in()
@cherrypy.tools.allow(methods=["POST"])
@cherrypy.tools.jwt_check()
def generate(self) -> bytes:
try:
request_json = cherrypy.request.json
remote_addr = cherrypy.request.remote.ip
include_ca_cert = str_to_bool(request_json.get("include_ca_cert", False))
include_ca_bundle = str_to_bool(request_json.get("include_ca_bundle", False))
common_name = request_json.get("common_name", remote_addr)
san_dns_name = request_json.get("common_name", "localhost")
cert_key, cert_crt = self.ca.generate(
common_name=common_name,
san_dns_names=[san_dns_name],
san_ip_addresses=[remote_addr],
)
secret_key_cert = request_json.get("key_cert", "cert.pem")
secret_key_key = request_json.get("key_key", "cert.key")
secret_key_ca = request_json.get("key_ca", "ca.pem")
secret_key_ca_bundle = request_json.get("key_ca_bundle", "ca.bundle.pem")
secret = {
secret_key_cert: cert_to_bytes(cert_crt).decode("utf-8"),
secret_key_key: key_to_bytes(cert_key).decode("utf-8"),
}
if include_ca_cert:
secret[secret_key_ca] = cert_to_bytes(self.ca.cert).decode("utf-8")
if include_ca_bundle:
secret[secret_key_ca_bundle] = gen_ca_bundle_bytes(self.ca.cert).decode("utf-8")
except Exception:
cherrypy.response.status = 400
return {"error": "Invalid request"}

return secret
8 changes: 7 additions & 1 deletion fixca/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@
from resotolib.logger import log
from kubernetes import client, config
from kubernetes.client.exceptions import ApiException
from .utils import memoize


def k8s_client() -> client.CoreV1Api:
k8s_config_load()
return client.CoreV1Api()


@memoize()
def k8s_config_load() -> None:
try:
config.load_incluster_config()
except config.config_exception.ConfigException:
Expand All @@ -15,7 +22,6 @@ def k8s_client() -> client.CoreV1Api:
except config.config_exception.ConfigException as e:
log.critical(f"Failed to load Kubernetes config: {e}")
sys.exit(1)
return client.CoreV1Api()


def get_secret(namespace: str, secret_name: str) -> Optional[dict[str, str]]:
Expand Down
4 changes: 2 additions & 2 deletions fixca/static/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
<link rel="manifest" href="site.webmanifest">
<meta name="msapplication-TileColor" content="#da532c">
<meta name="theme-color" content="#ffffff">
<title>FIX Certification Authority</title>
<title>FIX Certificate Authority</title>
</head>
<body>
<h1>FIX Certification Authority</h1><br/>
<h1>FIX Certificate Authority</h1><br/>
<table class="primary">
<thead>
<tr>
Expand Down
41 changes: 41 additions & 0 deletions fixca/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time
from functools import wraps
from typing import Callable, Any, Tuple, Dict, Union, TypeVar


def str_to_bool(s: Union[str, bool]) -> bool:
return str(s).lower() in ("true", "1", "yes")


RT = TypeVar("RT")


def memoize(ttl: int = 60, cleanup_interval: int = 600) -> Callable:
state = {"last_cleanup": 0}
cache: Dict[Tuple[Callable, Tuple, frozenset], Tuple[RT, float]] = {}

def decorating_function(user_function: Callable[..., RT]) -> Callable[..., RT]:
@wraps(user_function)
def wrapper(*args: Any, **kwargs: Any) -> RT:
nonlocal cache
now = time.time()
key = (user_function, args, frozenset(kwargs.items()))
if key in cache:
result, timestamp = cache[key]
if now - timestamp < ttl:
return result

result = user_function(*args, **kwargs)
cache[key] = (result, now)

nonlocal state
if now - state["last_cleanup"] > cleanup_interval:
for k in [k for k, v in cache.items() if now - v[1] >= ttl]:
cache.pop(k)
state["last_cleanup"] = now

return result

return wrapper

return decorating_function

0 comments on commit 0b12355

Please sign in to comment.