diff --git a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py index b07b8355..caa6e228 100644 --- a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py +++ b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py @@ -96,7 +96,6 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): """ - import json import logging from typing import List, Mapping @@ -113,7 +112,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 PYDEPS = ["jsonschema"] diff --git a/lib/charms/hydra/v0/oauth.py b/lib/charms/hydra/v0/oauth.py index a12137c7..8d36e96a 100644 --- a/lib/charms/hydra/v0/oauth.py +++ b/lib/charms/hydra/v0/oauth.py @@ -48,23 +48,16 @@ def _set_client_config(self): ``` """ -import inspect import json import logging import re -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Dict, List, Mapping, Optional import jsonschema -from ops.charm import ( - CharmBase, - RelationBrokenEvent, - RelationChangedEvent, - RelationCreatedEvent, - RelationDepartedEvent, -) +from ops.charm import CharmBase, RelationBrokenEvent, RelationChangedEvent, RelationCreatedEvent from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents -from ops.model import Relation, Secret, TooManyRelatedAppsError +from ops.model import Relation, Secret, SecretNotFoundError, TooManyRelatedAppsError # The unique Charmhub library identifier, never change it LIBID = "a3a301e325e34aac80a2d633ef61fe97" @@ -74,12 +67,20 @@ def _set_client_config(self): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 10 + +PYDEPS = ["jsonschema"] + logger = logging.getLogger(__name__) DEFAULT_RELATION_NAME = "oauth" -ALLOWED_GRANT_TYPES = ["authorization_code", "refresh_token", "client_credentials"] +ALLOWED_GRANT_TYPES = [ + "authorization_code", + "refresh_token", + "client_credentials", + "urn:ietf:params:oauth:grant-type:device_code", +] ALLOWED_CLIENT_AUTHN_METHODS = ["client_secret_basic", "client_secret_post"] CLIENT_SECRET_FIELD = "secret" @@ -127,6 +128,7 @@ def _set_client_config(self): }, "groups": {"type": "string", "default": None}, "ca_chain": {"type": "array", "items": {"type": "string"}, "default": []}, + "jwt_access_token": {"type": "string", "default": "False"}, }, "required": [ "issuer_url", @@ -153,13 +155,13 @@ def _set_client_config(self): "type": "array", "default": None, "items": { - "enum": ["authorization_code", "client_credentials", "refresh_token"], + "enum": ALLOWED_GRANT_TYPES, "type": "string", }, }, "token_endpoint_auth_method": { "type": "string", - "enum": ["client_secret_basic", "client_secret_post"], + "enum": ALLOWED_CLIENT_AUTHN_METHODS, "default": "client_secret_basic", }, }, @@ -200,11 +202,32 @@ def _dump_data(data: Dict, schema: Optional[Dict] = None) -> Dict: ret[k] = json.dumps(v) except json.JSONDecodeError as e: raise DataValidationError(f"Failed to encode relation json: {e}") + elif isinstance(v, bool): + ret[k] = str(v) else: ret[k] = v return ret +def strtobool(val: str) -> bool: + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + """ + if not isinstance(val, str): + raise ValueError(f"invalid value type {type(val)}") + + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError(f"invalid truth value {val}") + + class OAuthRelation(Object): """A class containing helper methods for oauth relation.""" @@ -291,11 +314,22 @@ class OauthProviderConfig: client_secret: Optional[str] = None groups: Optional[str] = None ca_chain: Optional[str] = None + jwt_access_token: Optional[bool] = False @classmethod def from_dict(cls, dic: Dict) -> "OauthProviderConfig": """Generate OauthProviderConfig instance from dict.""" - return cls(**{k: v for k, v in dic.items() if k in inspect.signature(cls).parameters}) + jwt_access_token = False + if "jwt_access_token" in dic: + jwt_access_token = strtobool(dic["jwt_access_token"]) + return cls( + jwt_access_token=jwt_access_token, + **{ + k: v + for k, v in dic.items() + if k in [f.name for f in fields(cls)] and k != "jwt_access_token" + }, + ) class OAuthInfoChangedEvent(EventBase): @@ -315,6 +349,7 @@ def snapshot(self) -> Dict: def restore(self, snapshot: Dict) -> None: """Restore event.""" + super().restore(snapshot) self.client_id = snapshot["client_id"] self.client_secret_id = snapshot["client_secret_id"] @@ -454,7 +489,9 @@ def is_client_created(self, relation_id: Optional[int] = None) -> bool: and "client_secret_id" in relation.data[relation.app] ) - def get_provider_info(self, relation_id: Optional[int] = None) -> OauthProviderConfig: + def get_provider_info( + self, relation_id: Optional[int] = None + ) -> Optional[OauthProviderConfig]: """Get the provider information from the databag.""" if len(self.model.relations) == 0: return None @@ -647,8 +684,8 @@ def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) self._get_client_config_from_relation_data, ) self.framework.observe( - events.relation_departed, - self._on_relation_departed, + events.relation_broken, + self._on_relation_broken, ) def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> None: @@ -696,7 +733,7 @@ def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> def _get_secret_label(self, relation: Relation) -> str: return f"client_secret_{relation.id}" - def _on_relation_departed(self, event: RelationDepartedEvent) -> None: + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: # Workaround for https://github.com/canonical/operator/issues/888 self._pop_relation_data(event.relation.id) @@ -711,8 +748,12 @@ def _create_juju_secret(self, client_secret: str, relation: Relation) -> Secret: return juju_secret def _delete_juju_secret(self, relation: Relation) -> None: - secret = self.model.get_secret(label=self._get_secret_label(relation)) - secret.remove_all_revisions() + try: + secret = self.model.get_secret(label=self._get_secret_label(relation)) + except SecretNotFoundError: + return + else: + secret.remove_all_revisions() def set_provider_info_in_relation_data( self, @@ -725,6 +766,7 @@ def set_provider_info_in_relation_data( scope: str, groups: Optional[str] = None, ca_chain: Optional[str] = None, + jwt_access_token: Optional[bool] = False, ) -> None: """Put the provider information in the databag.""" if not self.model.unit.is_leader(): @@ -738,6 +780,7 @@ def set_provider_info_in_relation_data( "userinfo_endpoint": userinfo_endpoint, "jwks_endpoint": jwks_endpoint, "scope": scope, + "jwt_access_token": jwt_access_token, } if groups: data["groups"] = groups @@ -760,5 +803,5 @@ def set_client_credentials_in_relation_data( # TODO: What if we are refreshing the client_secret? We need to add a # new revision for that secret = self._create_juju_secret(client_secret, relation) - data = dict(client_id=client_id, client_secret_id=secret.id) + data = {"client_id": client_id, "client_secret_id": secret.id} relation.data[self.model.app].update(_dump_data(data)) diff --git a/lib/charms/observability_libs/v0/cert_handler.py b/lib/charms/observability_libs/v0/cert_handler.py index 0fc610ff..275cf7db 100644 --- a/lib/charms/observability_libs/v0/cert_handler.py +++ b/lib/charms/observability_libs/v0/cert_handler.py @@ -40,25 +40,25 @@ from typing import List, Optional, Union, cast try: - from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore + from charms.tls_certificates_interface.v2.tls_certificates import ( # type: ignore AllCertificatesInvalidatedEvent, CertificateAvailableEvent, CertificateExpiringEvent, CertificateInvalidatedEvent, - TLSCertificatesRequiresV3, + TLSCertificatesRequiresV2, generate_csr, generate_private_key, ) except ImportError as e: raise ImportError( - "failed to import charms.tls_certificates_interface.v3.tls_certificates; " + "failed to import charms.tls_certificates_interface.v2.tls_certificates; " "Either the library itself is missing (please get it through charmcraft fetch-lib) " "or one of its dependencies is unmet." ) from e import logging -from ops.charm import CharmBase, RelationBrokenEvent +from ops.charm import CharmBase from ops.framework import EventBase, EventSource, Object, ObjectEvents from ops.model import Relation @@ -67,7 +67,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 0 -LIBPATCH = 12 +LIBPATCH = 14 def is_ip_address(value: str) -> bool: @@ -132,7 +132,7 @@ def __init__( self.peer_relation_name = peer_relation_name self.certificates_relation_name = certificates_relation_name - self.certificates = TLSCertificatesRequiresV3(self.charm, self.certificates_relation_name) + self.certificates = TLSCertificatesRequiresV2(self.charm, self.certificates_relation_name) self.framework.observe( self.charm.on.config_changed, @@ -158,10 +158,6 @@ def __init__( self.certificates.on.all_certificates_invalidated, # pyright: ignore self._on_all_certificates_invalidated, ) - self.framework.observe( - self.charm.on[self.certificates_relation_name].relation_broken, # pyright: ignore - self._on_certificates_relation_broken, - ) # Peer relation events self.framework.observe( @@ -289,7 +285,7 @@ def _generate_csr( if clear_cert: self._ca_cert = "" self._server_cert = "" - self._chain = "" + self._chain = [] def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: """Get the certificate from the event and store it in a peer relation. @@ -311,7 +307,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: if event_csr == self._csr: self._ca_cert = event.ca self._server_cert = event.certificate - self._chain = event.chain_as_pem() + self._chain = event.chain self.on.cert_changed.emit() # pyright: ignore @property @@ -382,29 +378,21 @@ def _server_cert(self, value: str): rel.data[self.charm.unit].update({"certificate": value}) @property - def _chain(self) -> str: + def _chain(self) -> List[str]: if self._peer_relation: - if chain := self._peer_relation.data[self.charm.unit].get("chain", ""): - chain = json.loads(chain) - - # In a previous version of this lib, chain used to be a list. - # Convert the List[str] to str, per - # https://github.com/canonical/tls-certificates-interface/pull/141 - if isinstance(chain, list): - chain = "\n\n".join(reversed(chain)) - - return cast(str, chain) - return "" + if chain := self._peer_relation.data[self.charm.unit].get("chain", []): + return cast(list, json.loads(cast(str, chain))) + return [] @_chain.setter - def _chain(self, value: str): + def _chain(self, value: List[str]): # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. rel = self._peer_relation assert rel is not None # For type checker rel.data[self.charm.unit].update({"chain": json.dumps(value)}) @property - def chain(self) -> str: + def chain(self) -> List[str]: """Return the ca chain.""" return self._chain @@ -433,13 +421,11 @@ def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> Non self.on.cert_changed.emit() # pyright: ignore def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: - # Do what you want with this information, probably remove all certificates - # Note: assuming "limit: 1" in metadata - self._generate_csr(overwrite=True, clear_cert=True) - self.on.cert_changed.emit() # pyright: ignore - - def _on_certificates_relation_broken(self, event: RelationBrokenEvent) -> None: """Clear the certificates data when removing the relation.""" + # Note: assuming "limit: 1" in metadata + # The "certificates_relation_broken" event is converted to "all invalidated" custom + # event by the tls-certificates library. Per convention, we let the lib manage the + # relation and we do not observe "certificates_relation_broken" directly. if self._peer_relation: private_key = self._private_key # This is a workaround for https://bugs.launchpad.net/juju/+bug/2024583 @@ -447,4 +433,5 @@ def _on_certificates_relation_broken(self, event: RelationBrokenEvent) -> None: if private_key: self._peer_relation.data[self.charm.unit].update({"private_key": private_key}) + # We do not generate a CSR here because the relation is gone. self.on.cert_changed.emit() # pyright: ignore diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index be967686..e3d35c6f 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -178,7 +178,7 @@ def __init__(self, *args): - `scrape_timeout` - `proxy_url` - `relabel_configs` -- `metrics_relabel_configs` +- `metric_relabel_configs` - `sample_limit` - `label_limit` - `label_name_length_limit` @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 46 +LIBPATCH = 47 PYDEPS = ["cosl"] @@ -377,7 +377,7 @@ def _on_scrape_targets_changed(self, event): "scrape_timeout", "proxy_url", "relabel_configs", - "metrics_relabel_configs", + "metric_relabel_configs", "sample_limit", "label_limit", "label_name_length_limit", diff --git a/lib/charms/tempo_k8s/v1/charm_tracing.py b/lib/charms/tempo_k8s/v1/charm_tracing.py index 5932c8d8..2dbdddd6 100644 --- a/lib/charms/tempo_k8s/v1/charm_tracing.py +++ b/lib/charms/tempo_k8s/v1/charm_tracing.py @@ -9,21 +9,57 @@ This means that, if your charm is related to, for example, COS' Tempo charm, you will be able to inspect in real time from the Grafana dashboard the execution flow of your charm. -To start using this library, you need to do two things: +# Quickstart +Fetch the following charm libs (and ensure the minimum version/revision numbers are satisfied): + + charmcraft fetch-lib charms.tempo_k8s.v2.tracing # >= 1.10 + charmcraft fetch-lib charms.tempo_k8s.v1.charm_tracing # >= 2.7 + +Then edit your charm code to include: + +```python +# import the necessary charm libs +from charms.tempo_k8s.v2.tracing import TracingEndpointRequirer, charm_tracing_config +from charms.tempo_k8s.v1.charm_tracing import charm_tracing + +# decorate your charm class with charm_tracing: +@charm_tracing( + # forward-declare the instance attributes that the instrumentor will look up to obtain the + # tempo endpoint and server certificate + tracing_endpoint="tracing_endpoint", + server_cert="server_cert" +) +class MyCharm(CharmBase): + _path_to_cert = "/path/to/cert.crt" + # path to cert file **in the charm container**. Its presence will be used to determine whether + # the charm is ready to use tls for encrypting charm traces. If your charm does not support tls, + # you can ignore this and pass None to charm_tracing_config. + # If you do support TLS, you'll need to make sure that the server cert is copied to this location + # and kept up to date so the instrumentor can use it. + + def __init__(self, ...): + ... + self.tracing = TracingEndpointRequirer(self, ...) + self.tracing_endpoint, self.server_cert = charm_tracing_config(self.tracing, self._path_to_cert) +``` + +# Detailed usage +To use this library, you need to do two things: 1) decorate your charm class with `@trace_charm(tracing_endpoint="my_tracing_endpoint")` -2) add to your charm a "my_tracing_endpoint" (you can name this attribute whatever you like) **property** -that returns an otlp http/https endpoint url. If you are using the `TracingEndpointProvider` as -`self.tracing = TracingEndpointProvider(self)`, the implementation could be: +2) add to your charm a "my_tracing_endpoint" (you can name this attribute whatever you like) +**property**, **method** or **instance attribute** that returns an otlp http/https endpoint url. +If you are using the ``charms.tempo_k8s.v2.tracing.TracingEndpointRequirer`` as +``self.tracing = TracingEndpointRequirer(self)``, the implementation could be: ``` @property def my_tracing_endpoint(self) -> Optional[str]: '''Tempo endpoint for charm tracing''' if self.tracing.is_ready(): - return self.tracing.otlp_http_endpoint() + return self.tracing.get_endpoint("otlp_http") else: return None ``` @@ -33,19 +69,52 @@ def my_tracing_endpoint(self) -> Optional[str]: - every event as a span (including custom events) - every charm method call (except dunders) as a span -if you wish to add more fine-grained information to the trace, you can do so by getting a hold of the tracer like so: + +## TLS support +If your charm integrates with a TLS provider which is also trusted by the tracing provider (the Tempo charm), +you can configure ``charm_tracing`` to use TLS by passing a ``server_cert`` parameter to the decorator. + +If your charm is not trusting the same CA as the Tempo endpoint it is sending traces to, +you'll need to implement a cert-transfer relation to obtain the CA certificate from the same +CA that Tempo is using. + +For example: +``` +from charms.tempo_k8s.v1.charm_tracing import trace_charm +@trace_charm( + tracing_endpoint="my_tracing_endpoint", + server_cert="_server_cert" +) +class MyCharm(CharmBase): + self._server_cert = "/path/to/server.crt" + ... + + def on_tls_changed(self, e) -> Optional[str]: + # update the server cert on the charm container for charm tracing + Path(self._server_cert).write_text(self.get_server_cert()) + + def on_tls_broken(self, e) -> Optional[str]: + # remove the server cert so charm_tracing won't try to use tls anymore + Path(self._server_cert).unlink() +``` + + +## More fine-grained manual instrumentation +if you wish to add more spans to the trace, you can do so by getting a hold of the tracer like so: ``` import opentelemetry ... - @property - def tracer(self) -> opentelemetry.trace.Tracer: - return opentelemetry.trace.get_tracer(type(self).__name__) +def get_tracer(self) -> opentelemetry.trace.Tracer: + return opentelemetry.trace.get_tracer(type(self).__name__) ``` By default, the tracer is named after the charm type. If you wish to override that, you can pass -a different `service_name` argument to `trace_charm`. +a different ``service_name`` argument to ``trace_charm``. + +See the official opentelemetry Python SDK documentation for usage: +https://opentelemetry-python.readthedocs.io/en/latest/ -*Upgrading from `v0`:* +## Upgrading from `v0` If you are upgrading from `charm_tracing` v0, you need to take the following steps (assuming you already have the newest version of the library in your charm): @@ -55,8 +124,9 @@ def tracer(self) -> opentelemetry.trace.Tracer: `opentelemetry-exporter-otlp-proto-http>=1.21.0`. -2) Update the charm method referenced to from `@trace` and `@trace_charm`, -to return from `TracingEndpointRequirer.otlp_http_endpoint()` instead of `grpc_http`. For example: +2) Update the charm method referenced to from ``@trace`` and ``@trace_charm``, +to return from ``TracingEndpointRequirer.get_endpoint("otlp_http")`` instead of ``grpc_http``. +For example: ``` from charms.tempo_k8s.v0.charm_tracing import trace_charm @@ -72,7 +142,7 @@ class MyCharm(CharmBase): def my_tracing_endpoint(self) -> Optional[str]: '''Tempo endpoint for charm tracing''' if self.tracing.is_ready(): - return self.tracing.otlp_grpc_endpoint() + return self.tracing.otlp_grpc_endpoint() # OLD API, DEPRECATED. else: return None ``` @@ -93,15 +163,67 @@ class MyCharm(CharmBase): def my_tracing_endpoint(self) -> Optional[str]: '''Tempo endpoint for charm tracing''' if self.tracing.is_ready(): - return self.tracing.otlp_http_endpoint() + return self.tracing.get_endpoint("otlp_http") # NEW API, use this. else: return None ``` -3) If you were passing a certificate using `server_cert`, you need to change it to provide an *absolute* path to -the certificate file. +3) If you were passing a certificate (str) using `server_cert`, you need to change it to +provide an *absolute* path to the certificate file instead. """ + +def _remove_stale_otel_sdk_packages(): + """Hack to remove stale opentelemetry sdk packages from the charm's python venv. + + See https://github.com/canonical/grafana-agent-operator/issues/146 and + https://bugs.launchpad.net/juju/+bug/2058335 for more context. This patch can be removed after + this juju issue is resolved and sufficient time has passed to expect most users of this library + have migrated to the patched version of juju. When this patch is removed, un-ignore rule E402 for this file in the pyproject.toml (see setting + [tool.ruff.lint.per-file-ignores] in pyproject.toml). + + This only has an effect if executed on an upgrade-charm event. + """ + # all imports are local to keep this function standalone, side-effect-free, and easy to revert later + import os + + if os.getenv("JUJU_DISPATCH_PATH") != "hooks/upgrade-charm": + return + + import logging + import shutil + from collections import defaultdict + + from importlib_metadata import distributions + + otel_logger = logging.getLogger("charm_tracing_otel_patcher") + otel_logger.debug("Applying _remove_stale_otel_sdk_packages patch on charm upgrade") + # group by name all distributions starting with "opentelemetry_" + otel_distributions = defaultdict(list) + for distribution in distributions(): + name = distribution._normalized_name # type: ignore + if name.startswith("opentelemetry_"): + otel_distributions[name].append(distribution) + + otel_logger.debug(f"Found {len(otel_distributions)} opentelemetry distributions") + + # If we have multiple distributions with the same name, remove any that have 0 associated files + for name, distributions_ in otel_distributions.items(): + if len(distributions_) <= 1: + continue + + otel_logger.debug(f"Package {name} has multiple ({len(distributions_)}) distributions.") + for distribution in distributions_: + if not distribution.files: # Not None or empty list + path = distribution._path # type: ignore + otel_logger.info(f"Removing empty distribution of {name} at {path}.") + shutil.rmtree(path) + + otel_logger.debug("Successfully applied _remove_stale_otel_sdk_packages patch. ") + + +_remove_stale_otel_sdk_packages() + import functools import inspect import logging @@ -122,18 +244,20 @@ def my_tracing_endpoint(self) -> Optional[str]: ) import opentelemetry +import ops from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import Span, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import INVALID_SPAN, Tracer -from opentelemetry.trace import get_current_span as otlp_get_current_span from opentelemetry.trace import ( + INVALID_SPAN, + Tracer, get_tracer, get_tracer_provider, set_span_in_context, set_tracer_provider, ) +from opentelemetry.trace import get_current_span as otlp_get_current_span from ops.charm import CharmBase from ops.framework import Framework @@ -146,14 +270,22 @@ def my_tracing_endpoint(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 9 +LIBPATCH = 15 PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] logger = logging.getLogger("tracing") +dev_logger = logging.getLogger("tracing-dev") +# set this to 0 if you are debugging/developing this library source +dev_logger.setLevel(logging.CRITICAL) + +_CharmType = Type[CharmBase] # the type CharmBase and any subclass thereof +_C = TypeVar("_C", bound=_CharmType) +_T = TypeVar("_T", bound=type) +_F = TypeVar("_F", bound=Type[Callable]) tracer: ContextVar[Tracer] = ContextVar("tracer") -_GetterType = Union[Callable[[CharmBase], Optional[str]], property] +_GetterType = Union[Callable[[_CharmType], Optional[str]], property] CHARM_TRACING_ENABLED = "CHARM_TRACING_ENABLED" @@ -199,9 +331,22 @@ def _get_tracer() -> Optional[Tracer]: try: return tracer.get() except LookupError: + # fallback: this course-corrects for a user error where charm_tracing symbols are imported + # from different paths (typically charms.tempo_k8s... and lib.charms.tempo_k8s...) try: ctx: Context = copy_context() if context_tracer := _get_tracer_from_context(ctx): + logger.warning( + "Tracer not found in `tracer` context var. " + "Verify that you're importing all `charm_tracing` symbols from the same module path. \n" + "For example, DO" + ": `from charms.lib...charm_tracing import foo, bar`. \n" + "DONT: \n" + " \t - `from charms.lib...charm_tracing import foo` \n" + " \t - `from lib...charm_tracing import bar` \n" + "For more info: https://python-notes.curiousefficiency.org/en/latest/python" + "_concepts/import_traps.html#the-double-import-trap" + ) return context_tracer.get() else: return None @@ -219,11 +364,6 @@ def _span(name: str) -> Generator[Optional[Span], Any, Any]: yield None -_C = TypeVar("_C", bound=Type[CharmBase]) -_T = TypeVar("_T", bound=type) -_F = TypeVar("_F", bound=Type[Callable]) - - class TracingError(RuntimeError): """Base class for errors raised by this module.""" @@ -232,60 +372,78 @@ class UntraceableObjectError(TracingError): """Raised when an object you're attempting to instrument cannot be autoinstrumented.""" -def _get_tracing_endpoint(tracing_endpoint_getter, self, charm): - if isinstance(tracing_endpoint_getter, property): - tracing_endpoint = tracing_endpoint_getter.__get__(self) - else: # method or callable - tracing_endpoint = tracing_endpoint_getter(self) +class TLSError(TracingError): + """Raised when the tracing endpoint is https but we don't have a cert yet.""" + + +def _get_tracing_endpoint( + tracing_endpoint_attr: str, + charm_instance: object, + charm_type: type, +): + _tracing_endpoint = getattr(charm_instance, tracing_endpoint_attr) + if callable(_tracing_endpoint): + tracing_endpoint = _tracing_endpoint() + else: + tracing_endpoint = _tracing_endpoint if tracing_endpoint is None: - logger.debug( - f"{charm}.{tracing_endpoint_getter} returned None; quietly disabling " - f"charm_tracing for the run." - ) return + elif not isinstance(tracing_endpoint, str): raise TypeError( - f"{charm}.{tracing_endpoint_getter} should return a tempo endpoint (string); " + f"{charm_type.__name__}.{tracing_endpoint_attr} should resolve to a tempo endpoint (string); " f"got {tracing_endpoint} instead." ) - else: - logger.debug(f"Setting up span exporter to endpoint: {tracing_endpoint}/v1/traces") + + dev_logger.debug(f"Setting up span exporter to endpoint: {tracing_endpoint}/v1/traces") return f"{tracing_endpoint}/v1/traces" -def _get_server_cert(server_cert_getter, self, charm): - if isinstance(server_cert_getter, property): - server_cert = server_cert_getter.__get__(self) - else: # method or callable - server_cert = server_cert_getter(self) +def _get_server_cert( + server_cert_attr: str, + charm_instance: ops.CharmBase, + charm_type: Type[ops.CharmBase], +): + _server_cert = getattr(charm_instance, server_cert_attr) + if callable(_server_cert): + server_cert = _server_cert() + else: + server_cert = _server_cert if server_cert is None: logger.warning( - f"{charm}.{server_cert_getter} returned None; sending traces over INSECURE connection." + f"{charm_type}.{server_cert_attr} is None; sending traces over INSECURE connection." ) return elif not Path(server_cert).is_absolute(): raise ValueError( - f"{charm}.{server_cert_getter} should return a valid tls cert absolute path (string | Path)); " + f"{charm_type}.{server_cert_attr} should resolve to a valid tls cert absolute path (string | Path)); " f"got {server_cert} instead." ) return server_cert def _setup_root_span_initializer( - charm: Type[CharmBase], - tracing_endpoint_getter: _GetterType, - server_cert_getter: Optional[_GetterType], + charm_type: _CharmType, + tracing_endpoint_attr: str, + server_cert_attr: Optional[str], service_name: Optional[str] = None, ): """Patch the charm's initializer.""" - original_init = charm.__init__ + original_init = charm_type.__init__ @functools.wraps(original_init) def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): + # we're using 'self' here because this is charm init code, makes sense to read what's below + # from the perspective of the charm. Self.unit.name... + original_init(self, framework, *args, **kwargs) + # we call this from inside the init context instead of, say, _autoinstrument, because we want it to + # be checked on a per-charm-instantiation basis, not on a per-type-declaration one. if not is_enabled(): + # this will only happen during unittesting, hopefully, so it's fine to log a + # bit more verbosely logger.info("Tracing DISABLED: skipping root span initialization") return @@ -297,38 +455,41 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): # default service name isn't just app name because it could conflict with the workload service name _service_name = service_name or f"{self.app.name}-charm" + unit_name = self.unit.name + # apply hacky patch to remove stale opentelemetry sdk packages on upgrade-charm. + # it could be trouble if someone ever decides to implement their own tracer parallel to + # ours and before the charm has inited. We assume they won't. resource = Resource.create( attributes={ "service.name": _service_name, "compose_service": _service_name, "charm_type": type(self).__name__, # juju topology - "juju_unit": self.unit.name, + "juju_unit": unit_name, "juju_application": self.app.name, "juju_model": self.model.name, "juju_model_uuid": self.model.uuid, } ) provider = TracerProvider(resource=resource) - try: - tracing_endpoint = _get_tracing_endpoint(tracing_endpoint_getter, self, charm) - except Exception: - # if anything goes wrong with retrieving the endpoint, we go on with tracing disabled. - # better than breaking the charm. - logger.exception( - f"exception retrieving the tracing " - f"endpoint from {charm}.{tracing_endpoint_getter}; " - f"proceeding with charm_tracing DISABLED. " - ) - return + + # if anything goes wrong with retrieving the endpoint, we let the exception bubble up. + tracing_endpoint = _get_tracing_endpoint(tracing_endpoint_attr, self, charm_type) if not tracing_endpoint: + # tracing is off if tracing_endpoint is None return server_cert: Optional[Union[str, Path]] = ( - _get_server_cert(server_cert_getter, self, charm) if server_cert_getter else None + _get_server_cert(server_cert_attr, self, charm_type) if server_cert_attr else None ) + if tracing_endpoint.startswith("https://") and not server_cert: + raise TLSError( + "Tracing endpoint is https, but no server_cert has been passed." + "Please point @trace_charm to a `server_cert` attr." + ) + exporter = OTLPSpanExporter( endpoint=tracing_endpoint, certificate_file=str(Path(server_cert).absolute()) if server_cert else None, @@ -341,16 +502,18 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): _tracer = get_tracer(_service_name) # type: ignore _tracer_token = tracer.set(_tracer) - dispatch_path = os.getenv("JUJU_DISPATCH_PATH", "") + dispatch_path = os.getenv("JUJU_DISPATCH_PATH", "") # something like hooks/install + event_name = dispatch_path.split("/")[1] if "/" in dispatch_path else dispatch_path + root_span_name = f"{unit_name}: {event_name} event" + span = _tracer.start_span(root_span_name, attributes={"juju.dispatch_path": dispatch_path}) # all these shenanigans are to work around the fact that the opentelemetry tracing API is built # on the assumption that spans will be used as contextmanagers. # Since we don't (as we need to close the span on framework.commit), # we need to manually set the root span as current. - span = _tracer.start_span("charm exec", attributes={"juju.dispatch_path": dispatch_path}) ctx = set_span_in_context(span) - # log a trace id so we can look it up in tempo. + # log a trace id, so we can pick it up from the logs (and jhack) to look it up in tempo. root_trace_id = hex(span.get_span_context().trace_id)[2:] # strip 0x prefix logger.debug(f"Starting root trace with id={root_trace_id!r}.") @@ -358,6 +521,7 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): @contextmanager def wrap_event_context(event_name: str): + dev_logger.info(f"entering event context: {event_name}") # when the framework enters an event context, we create a span. with _span("event: " + event_name) as event_context_span: if event_context_span: @@ -371,6 +535,7 @@ def wrap_event_context(event_name: str): @functools.wraps(original_close) def wrap_close(): + dev_logger.info("tearing down tracer and flushing traces") span.end() opentelemetry.context.detach(span_token) # type: ignore tracer.reset(_tracer_token) @@ -382,7 +547,7 @@ def wrap_close(): framework.close = wrap_close return - charm.__init__ = wrap_init + charm_type.__init__ = wrap_init # type: ignore def trace_charm( @@ -390,7 +555,7 @@ def trace_charm( server_cert: Optional[str] = None, service_name: Optional[str] = None, extra_types: Sequence[type] = (), -): +) -> Callable[[_T], _T]: """Autoinstrument the decorated charm with tracing telemetry. Use this function to get out-of-the-box traces for all events emitted on this charm and all @@ -398,7 +563,7 @@ def trace_charm( Usage: >>> from charms.tempo_k8s.v1.charm_tracing import trace_charm - >>> from charms.tempo_k8s.v1.tracing import TracingEndpointProvider + >>> from charms.tempo_k8s.v1.tracing import TracingEndpointRequirer >>> from ops import CharmBase >>> >>> @trace_charm( @@ -408,7 +573,7 @@ def trace_charm( >>> >>> def __init__(self, framework: Framework): >>> ... - >>> self.tracing = TracingEndpointProvider(self) + >>> self.tracing = TracingEndpointRequirer(self) >>> >>> @property >>> def tempo_otlp_http_endpoint(self) -> Optional[str]: @@ -417,24 +582,28 @@ def trace_charm( >>> else: >>> return None >>> - :param server_cert: method or property on the charm type that returns an - optional absolute path to a tls certificate to be used when sending traces to a remote server. - If it returns None, an _insecure_ connection will be used. - :param tracing_endpoint: name of a property on the charm type that returns an - optional (fully resolvable) tempo url. If None, tracing will be effectively disabled. Else, traces will be - pushed to that endpoint. + + :param tracing_endpoint: name of a method, property or attribute on the charm type that returns an + optional (fully resolvable) tempo url to which the charm traces will be pushed. + If None, tracing will be effectively disabled. + :param server_cert: name of a method, property or attribute on the charm type that returns an + optional absolute path to a CA certificate file to be used when sending traces to a remote server. + If it returns None, an _insecure_ connection will be used. To avoid errors in transient + situations where the endpoint is already https but there is no certificate on disk yet, it + is recommended to disable tracing (by returning None from the tracing_endpoint) altogether + until the cert has been written to disk. :param service_name: service name tag to attach to all traces generated by this charm. Defaults to the juju application name this charm is deployed under. :param extra_types: pass any number of types that you also wish to autoinstrument. For example, charm libs, relation endpoint wrappers, workload abstractions, ... """ - def _decorator(charm_type: Type[CharmBase]): + def _decorator(charm_type: _T) -> _T: """Autoinstrument the wrapped charmbase type.""" _autoinstrument( charm_type, - tracing_endpoint_getter=getattr(charm_type, tracing_endpoint), - server_cert_getter=getattr(charm_type, server_cert) if server_cert else None, + tracing_endpoint_attr=tracing_endpoint, + server_cert_attr=server_cert, service_name=service_name, extra_types=extra_types, ) @@ -444,12 +613,12 @@ def _decorator(charm_type: Type[CharmBase]): def _autoinstrument( - charm_type: Type[CharmBase], - tracing_endpoint_getter: _GetterType, - server_cert_getter: Optional[_GetterType] = None, + charm_type: _T, + tracing_endpoint_attr: str, + server_cert_attr: Optional[str] = None, service_name: Optional[str] = None, extra_types: Sequence[type] = (), -) -> Type[CharmBase]: +) -> _T: """Set up tracing on this charm class. Use this function to get out-of-the-box traces for all events emitted on this charm and all @@ -461,29 +630,32 @@ def _autoinstrument( >>> from ops.main import main >>> _autoinstrument( >>> MyCharm, - >>> tracing_endpoint_getter=MyCharm.tempo_otlp_http_endpoint, + >>> tracing_endpoint_attr="tempo_otlp_http_endpoint", >>> service_name="MyCharm", >>> extra_types=(Foo, Bar) >>> ) >>> main(MyCharm) :param charm_type: the CharmBase subclass to autoinstrument. - :param server_cert_getter: method or property on the charm type that returns an - optional absolute path to a tls certificate to be used when sending traces to a remote server. - This needs to be a valid path to a certificate. - :param tracing_endpoint_getter: method or property on the charm type that returns an - optional tempo url. If None, tracing will be effectively disabled. Else, traces will be - pushed to that endpoint. + :param tracing_endpoint_attr: name of a method, property or attribute on the charm type that returns an + optional (fully resolvable) tempo url to which the charm traces will be pushed. + If None, tracing will be effectively disabled. + :param server_cert_attr: name of a method, property or attribute on the charm type that returns an + optional absolute path to a CA certificate file to be used when sending traces to a remote server. + If it returns None, an _insecure_ connection will be used. To avoid errors in transient + situations where the endpoint is already https but there is no certificate on disk yet, it + is recommended to disable tracing (by returning None from the tracing_endpoint) altogether + until the cert has been written to disk. :param service_name: service name tag to attach to all traces generated by this charm. Defaults to the juju application name this charm is deployed under. :param extra_types: pass any number of types that you also wish to autoinstrument. For example, charm libs, relation endpoint wrappers, workload abstractions, ... """ - logger.info(f"instrumenting {charm_type}") + dev_logger.info(f"instrumenting {charm_type}") _setup_root_span_initializer( charm_type, - tracing_endpoint_getter, - server_cert_getter=server_cert_getter, + tracing_endpoint_attr, + server_cert_attr=server_cert_attr, service_name=service_name, ) trace_type(charm_type) @@ -500,46 +672,66 @@ def trace_type(cls: _T) -> _T: It assumes that this class is only instantiated after a charm type decorated with `@trace_charm` has been instantiated. """ - logger.info(f"instrumenting {cls}") + dev_logger.info(f"instrumenting {cls}") for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): - logger.info(f"discovered {method}") + dev_logger.info(f"discovered {method}") if method.__name__.startswith("__"): - logger.info(f"skipping {method} (dunder)") + dev_logger.info(f"skipping {method} (dunder)") continue - new_method = trace_method(method) - if isinstance(inspect.getattr_static(cls, method.__name__), staticmethod): + # the span title in the general case should be: + # method call: MyCharmWrappedMethods.b + # if the method has a name (functools.wrapped or regular method), let + # _trace_callable use its default algorithm to determine what name to give the span. + trace_method_name = None + try: + qualname_c0 = method.__qualname__.split(".")[0] + if not hasattr(cls, method.__name__): + # if the callable doesn't have a __name__ (probably a decorated method), + # it probably has a bad qualname too (such as my_decorator..wrapper) which is not + # great for finding out what the trace is about. So we use the method name instead and + # add a reference to the decorator name. Result: + # method call: @my_decorator(MyCharmWrappedMethods.b) + trace_method_name = f"@{qualname_c0}({cls.__name__}.{name})" + except Exception: # noqa: failsafe + pass + + new_method = trace_method(method, name=trace_method_name) + + if isinstance(inspect.getattr_static(cls, name), staticmethod): new_method = staticmethod(new_method) setattr(cls, name, new_method) return cls -def trace_method(method: _F) -> _F: +def trace_method(method: _F, name: Optional[str] = None) -> _F: """Trace this method. A span will be opened when this method is called and closed when it returns. """ - return _trace_callable(method, "method") + return _trace_callable(method, "method", name=name) -def trace_function(function: _F) -> _F: +def trace_function(function: _F, name: Optional[str] = None) -> _F: """Trace this function. A span will be opened when this function is called and closed when it returns. """ - return _trace_callable(function, "function") + return _trace_callable(function, "function", name=name) -def _trace_callable(callable: _F, qualifier: str) -> _F: - logger.info(f"instrumenting {callable}") +def _trace_callable(callable: _F, qualifier: str, name: Optional[str] = None) -> _F: + dev_logger.info(f"instrumenting {callable}") # sig = inspect.signature(callable) @functools.wraps(callable) def wrapped_function(*args, **kwargs): # type: ignore - name = getattr(callable, "__qualname__", getattr(callable, "__name__", str(callable))) - with _span(f"{qualifier} call: {name}"): # type: ignore + name_ = name or getattr( + callable, "__qualname__", getattr(callable, "__name__", str(callable)) + ) + with _span(f"{qualifier} call: {name_}"): # type: ignore return callable(*args, **kwargs) # type: ignore # wrapped_function.__signature__ = sig diff --git a/lib/charms/tempo_k8s/v2/tracing.py b/lib/charms/tempo_k8s/v2/tracing.py index 8b9fb4f3..dfb23365 100644 --- a/lib/charms/tempo_k8s/v2/tracing.py +++ b/lib/charms/tempo_k8s/v2/tracing.py @@ -107,7 +107,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 PYDEPS = ["pydantic"] @@ -116,14 +116,13 @@ def __init__(self, *args): DEFAULT_RELATION_NAME = "tracing" RELATION_INTERFACE_NAME = "tracing" +# Supported list rationale https://github.com/canonical/tempo-coordinator-k8s-operator/issues/8 ReceiverProtocol = Literal[ "zipkin", - "kafka", - "opencensus", - "tempo_http", - "tempo_grpc", "otlp_grpc", "otlp_http", + "jaeger_grpc", + "jaeger_thrift_http", ] RawReceiver = Tuple[ReceiverProtocol, str] @@ -141,14 +140,12 @@ class TransportProtocolType(str, enum.Enum): grpc = "grpc" -receiver_protocol_to_transport_protocol = { +receiver_protocol_to_transport_protocol: Dict[ReceiverProtocol, TransportProtocolType] = { "zipkin": TransportProtocolType.http, - "kafka": TransportProtocolType.http, - "opencensus": TransportProtocolType.http, - "tempo_http": TransportProtocolType.http, - "tempo_grpc": TransportProtocolType.grpc, "otlp_grpc": TransportProtocolType.grpc, "otlp_http": TransportProtocolType.http, + "jaeger_thrift_http": TransportProtocolType.http, + "jaeger_grpc": TransportProtocolType.grpc, } """A mapping between telemetry protocols and their corresponding transport protocol. """ diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py index cbdd80d1..aa4704c7 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -111,6 +111,7 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, + recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -276,13 +277,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy +import ipaddress import json import logging import uuid from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from ipaddress import IPv4Address from typing import List, Literal, Optional, Union from cryptography import x509 @@ -316,7 +317,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 10 +LIBPATCH = 17 PYDEPS = ["cryptography", "jsonschema"] @@ -453,11 +454,35 @@ class ProviderCertificate: ca: str chain: List[str] revoked: bool + expiry_time: datetime + expiry_notification_time: Optional[datetime] = None def chain_as_pem(self) -> str: """Return full certificate chain as a PEM string.""" return "\n\n".join(reversed(self.chain)) + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "relation_id": self.relation_id, + "application_name": self.application_name, + "csr": self.csr, + "certificate": self.certificate, + "ca": self.ca, + "chain": self.chain, + "revoked": self.revoked, + "expiry_time": self.expiry_time.isoformat(), + "expiry_notification_time": self.expiry_notification_time.isoformat() + if self.expiry_notification_time + else None, + } + ) + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -682,21 +707,49 @@ def _get_closest_future_time( ) -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], + requirer_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider and by the requirer. + Time recommended by the provider is preferred, + then time recommended by the requirer, + then dynamically calculated time. Args: - certificate (str): x509 certificate as a string + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + requirer_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the requirer. Returns: - Optional[datetime]: Expiry datetime or None + datetime: Time to notify the user about the certificate expiry. """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after_utc - except ValueError: - logger.warning("Could not load certificate.") - return None + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = ( + expiry_time - timedelta(hours=provider_recommended_notification_time) + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + + if requirer_recommended_notification_time is not None: + requirer_recommended_notification_time = abs(requirer_recommended_notification_time) + requirer_recommendation_time_delta = ( + expiry_time - timedelta(hours=requirer_recommended_notification_time) + ) + if validity_start_time < requirer_recommendation_time_delta: + return requirer_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) def generate_ca( @@ -965,6 +1018,8 @@ def generate_csr( # noqa: C901 organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -983,6 +1038,8 @@ def generate_csr( # noqa: C901 organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) @@ -1008,13 +1065,19 @@ def generate_csr( # noqa: C901 subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -1030,6 +1093,13 @@ def generate_csr( # noqa: C901 return signed_certificate.public_bytes(serialization.Encoding.PEM) +def get_sha256_hex(data: str) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(data.encode()) + return digest.finalize().hex() + + def csr_matches_certificate(csr: str, cert: str) -> bool: """Check if a CSR matches a certificate. @@ -1039,25 +1109,16 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: Returns: bool: True/False depending on whether the CSR matches the certificate. """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): return False return True @@ -1135,6 +1196,7 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificate to relation data. @@ -1144,6 +1206,8 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain + recommended_expiry_notification_time (int): + Time in hours before the certificate expires to notify the user. Returns: None @@ -1161,6 +1225,7 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, + "recommended_expiry_notification_time": recommended_expiry_notification_time, } provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) @@ -1227,6 +1292,7 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificates to relation data. @@ -1236,6 +1302,8 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID + recommended_expiry_notification_time (int): + Recommended time in hours before the certificate expires to notify the user. Returns: None @@ -1257,6 +1325,7 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], + recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: @@ -1310,6 +1379,13 @@ def get_provider_certificates( provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) for certificate in provider_certificates: + try: + certificate_object = x509.load_pem_x509_certificate( + data=certificate["certificate"].encode() + ) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue provider_certificate = ProviderCertificate( relation_id=relation.id, application_name=relation.app.name, @@ -1318,6 +1394,10 @@ def get_provider_certificates( ca=certificate["ca"], chain=certificate["chain"], revoked=certificate.get("revoked", False), + expiry_time=certificate_object.not_valid_after_utc, + expiry_notification_time=certificate.get( + "recommended_expiry_notification_time" + ), ) certificates.append(provider_certificate) return certificates @@ -1475,15 +1555,17 @@ def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: int = 168, + expiry_notification_time: Optional[int] = None, ): """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Time difference between now and expiry (in hours). - Used to trigger the CertificateExpiring event. Default: 7 days. + expiry_notification_time (int): Number of hours prior to certificate expiry. + Used to trigger the CertificateExpiring event. + This value is used as a recommendation only, + The actual value is calculated taking into account the provider's recommendation. """ super().__init__(charm, relationship_name) if not JujuVersion.from_environ().has_secrets: @@ -1544,9 +1626,25 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: if not certificate: logger.warning("No certificate found in relation data - Skipping") continue + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue ca = provider_certificate_dict.get("ca") chain = provider_certificate_dict.get("chain", []) csr = provider_certificate_dict.get("certificate_signing_request") + recommended_expiry_notification_time = provider_certificate_dict.get( + "recommended_expiry_notification_time" + ) + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=validity_start_time, + expiry_time=expiry_time, + provider_recommended_notification_time=recommended_expiry_notification_time, + requirer_recommended_notification_time=self.expiry_notification_time, + ) if not csr: logger.warning("No CSR found in relation data - Skipping") continue @@ -1559,6 +1657,8 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: ca=ca, chain=chain, revoked=revoked, + expiry_time=expiry_time, + expiry_notification_time=expiry_notification_time, ) provider_certificates.append(provider_certificate) return provider_certificates @@ -1708,13 +1808,9 @@ def get_expiring_certificates(self) -> List[ProviderCertificate]: expiring_certificates: List[ProviderCertificate] = [] for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - expiry_time = _get_certificate_expiry_time(cert.certificate) - if not expiry_time: + if not cert.expiry_time or not cert.expiry_notification_time: continue - expiry_notification_time = expiry_time - timedelta( - hours=self.expiry_notification_time - ) - if datetime.now(timezone.utc) > expiry_notification_time: + if datetime.now(timezone.utc) > cert.expiry_notification_time: expiring_certificates.append(cert) return expiring_certificates @@ -1774,9 +1870,15 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ] for certificate in provider_certificates: if certificate.csr in requirer_csrs: + csr_in_sha256_hex = get_sha256_hex(certificate.csr) if certificate.revoked: with suppress(SecretNotFoundError): - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + logger.debug( + "Removing secret with label %s", + f"{LIBID}-{csr_in_sha256_hex}", + ) + secret = self.model.get_secret( + label=f"{LIBID}-{csr_in_sha256_hex}") secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", @@ -1787,16 +1889,24 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ) else: try: - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") - secret.set_content({"certificate": certificate.certificate}) + logger.debug( + "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + secret.set_content( + {"certificate": certificate.certificate, "csr": certificate.csr} + ) secret.set_info( - expire=self._get_next_secret_expiry_time(certificate.certificate), + expire=self._get_next_secret_expiry_time(certificate), ) except SecretNotFoundError: + logger.debug( + "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) secret = self.charm.unit.add_secret( - {"certificate": certificate.certificate}, - label=f"{LIBID}-{certificate.csr}", - expire=self._get_next_secret_expiry_time(certificate.certificate), + {"certificate": certificate.certificate, "csr": certificate.csr}, + label=f"{LIBID}-{csr_in_sha256_hex}", + expire=self._get_next_secret_expiry_time(certificate), ) self.on.certificate_available.emit( certificate_signing_request=certificate.csr, @@ -1805,7 +1915,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: chain=certificate.chain, ) - def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: """Return the expiry time or expiry notification time. Extracts the expiry time from the provided certificate, calculates the @@ -1813,17 +1923,18 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: the future. Args: - certificate: x509 certificate + certificate: ProviderCertificate object Returns: Optional[datetime]: None if the certificate expiry time cannot be read, next expiry time otherwise. """ - expiry_time = _get_certificate_expiry_time(certificate) - if not expiry_time: + if not certificate.expiry_time or not certificate.expiry_notification_time: return None - expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) - return _get_closest_future_time(expiry_notification_time, expiry_time) + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: """Handle Relation Broken Event. @@ -1857,27 +1968,26 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: """ if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return - csr = event.secret.label[len(f"{LIBID}-") :] + csr = event.secret.get_content()["csr"] provider_certificate = self._find_certificate_in_relation_data(csr) if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(provider_certificate.certificate) - if not expiry_time: + if not provider_certificate.expiry_time: # A secret expired but matching certificate is invalid. Cleaning up event.secret.remove_all_revisions() return - if datetime.now(timezone.utc) < expiry_time: + if datetime.now(timezone.utc) < provider_certificate.expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( certificate=provider_certificate.certificate, - expiry=expiry_time.isoformat(), + expiry=provider_certificate.expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(provider_certificate.certificate), + expire=provider_certificate.expiry_time, ) else: logger.warning("Certificate is expired")