diff --git a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py index b07b8355..caa6e228 100644 --- a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py +++ b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py @@ -96,7 +96,6 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): """ - import json import logging from typing import List, Mapping @@ -113,7 +112,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 PYDEPS = ["jsonschema"] diff --git a/lib/charms/hydra/v0/oauth.py b/lib/charms/hydra/v0/oauth.py index a12137c7..8d36e96a 100644 --- a/lib/charms/hydra/v0/oauth.py +++ b/lib/charms/hydra/v0/oauth.py @@ -48,23 +48,16 @@ def _set_client_config(self): ``` """ -import inspect import json import logging import re -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Dict, List, Mapping, Optional import jsonschema -from ops.charm import ( - CharmBase, - RelationBrokenEvent, - RelationChangedEvent, - RelationCreatedEvent, - RelationDepartedEvent, -) +from ops.charm import CharmBase, RelationBrokenEvent, RelationChangedEvent, RelationCreatedEvent from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents -from ops.model import Relation, Secret, TooManyRelatedAppsError +from ops.model import Relation, Secret, SecretNotFoundError, TooManyRelatedAppsError # The unique Charmhub library identifier, never change it LIBID = "a3a301e325e34aac80a2d633ef61fe97" @@ -74,12 +67,20 @@ def _set_client_config(self): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 6 +LIBPATCH = 10 + +PYDEPS = ["jsonschema"] + logger = logging.getLogger(__name__) DEFAULT_RELATION_NAME = "oauth" -ALLOWED_GRANT_TYPES = ["authorization_code", "refresh_token", "client_credentials"] +ALLOWED_GRANT_TYPES = [ + "authorization_code", + "refresh_token", + "client_credentials", + "urn:ietf:params:oauth:grant-type:device_code", +] ALLOWED_CLIENT_AUTHN_METHODS = ["client_secret_basic", "client_secret_post"] CLIENT_SECRET_FIELD = "secret" @@ -127,6 +128,7 @@ def _set_client_config(self): }, "groups": {"type": "string", "default": None}, "ca_chain": {"type": "array", "items": {"type": "string"}, "default": []}, + "jwt_access_token": {"type": "string", "default": "False"}, }, "required": [ "issuer_url", @@ -153,13 +155,13 @@ def _set_client_config(self): "type": "array", "default": None, "items": { - "enum": ["authorization_code", "client_credentials", "refresh_token"], + "enum": ALLOWED_GRANT_TYPES, "type": "string", }, }, "token_endpoint_auth_method": { "type": "string", - "enum": ["client_secret_basic", "client_secret_post"], + "enum": ALLOWED_CLIENT_AUTHN_METHODS, "default": "client_secret_basic", }, }, @@ -200,11 +202,32 @@ def _dump_data(data: Dict, schema: Optional[Dict] = None) -> Dict: ret[k] = json.dumps(v) except json.JSONDecodeError as e: raise DataValidationError(f"Failed to encode relation json: {e}") + elif isinstance(v, bool): + ret[k] = str(v) else: ret[k] = v return ret +def strtobool(val: str) -> bool: + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + """ + if not isinstance(val, str): + raise ValueError(f"invalid value type {type(val)}") + + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError(f"invalid truth value {val}") + + class OAuthRelation(Object): """A class containing helper methods for oauth relation.""" @@ -291,11 +314,22 @@ class OauthProviderConfig: client_secret: Optional[str] = None groups: Optional[str] = None ca_chain: Optional[str] = None + jwt_access_token: Optional[bool] = False @classmethod def from_dict(cls, dic: Dict) -> "OauthProviderConfig": """Generate OauthProviderConfig instance from dict.""" - return cls(**{k: v for k, v in dic.items() if k in inspect.signature(cls).parameters}) + jwt_access_token = False + if "jwt_access_token" in dic: + jwt_access_token = strtobool(dic["jwt_access_token"]) + return cls( + jwt_access_token=jwt_access_token, + **{ + k: v + for k, v in dic.items() + if k in [f.name for f in fields(cls)] and k != "jwt_access_token" + }, + ) class OAuthInfoChangedEvent(EventBase): @@ -315,6 +349,7 @@ def snapshot(self) -> Dict: def restore(self, snapshot: Dict) -> None: """Restore event.""" + super().restore(snapshot) self.client_id = snapshot["client_id"] self.client_secret_id = snapshot["client_secret_id"] @@ -454,7 +489,9 @@ def is_client_created(self, relation_id: Optional[int] = None) -> bool: and "client_secret_id" in relation.data[relation.app] ) - def get_provider_info(self, relation_id: Optional[int] = None) -> OauthProviderConfig: + def get_provider_info( + self, relation_id: Optional[int] = None + ) -> Optional[OauthProviderConfig]: """Get the provider information from the databag.""" if len(self.model.relations) == 0: return None @@ -647,8 +684,8 @@ def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) self._get_client_config_from_relation_data, ) self.framework.observe( - events.relation_departed, - self._on_relation_departed, + events.relation_broken, + self._on_relation_broken, ) def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> None: @@ -696,7 +733,7 @@ def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> def _get_secret_label(self, relation: Relation) -> str: return f"client_secret_{relation.id}" - def _on_relation_departed(self, event: RelationDepartedEvent) -> None: + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: # Workaround for https://github.com/canonical/operator/issues/888 self._pop_relation_data(event.relation.id) @@ -711,8 +748,12 @@ def _create_juju_secret(self, client_secret: str, relation: Relation) -> Secret: return juju_secret def _delete_juju_secret(self, relation: Relation) -> None: - secret = self.model.get_secret(label=self._get_secret_label(relation)) - secret.remove_all_revisions() + try: + secret = self.model.get_secret(label=self._get_secret_label(relation)) + except SecretNotFoundError: + return + else: + secret.remove_all_revisions() def set_provider_info_in_relation_data( self, @@ -725,6 +766,7 @@ def set_provider_info_in_relation_data( scope: str, groups: Optional[str] = None, ca_chain: Optional[str] = None, + jwt_access_token: Optional[bool] = False, ) -> None: """Put the provider information in the databag.""" if not self.model.unit.is_leader(): @@ -738,6 +780,7 @@ def set_provider_info_in_relation_data( "userinfo_endpoint": userinfo_endpoint, "jwks_endpoint": jwks_endpoint, "scope": scope, + "jwt_access_token": jwt_access_token, } if groups: data["groups"] = groups @@ -760,5 +803,5 @@ def set_client_credentials_in_relation_data( # TODO: What if we are refreshing the client_secret? We need to add a # new revision for that secret = self._create_juju_secret(client_secret, relation) - data = dict(client_id=client_id, client_secret_id=secret.id) + data = {"client_id": client_id, "client_secret_id": secret.id} relation.data[self.model.app].update(_dump_data(data)) diff --git a/lib/charms/observability_libs/v0/cert_handler.py b/lib/charms/observability_libs/v0/cert_handler.py index 0fc610ff..275cf7db 100644 --- a/lib/charms/observability_libs/v0/cert_handler.py +++ b/lib/charms/observability_libs/v0/cert_handler.py @@ -40,25 +40,25 @@ from typing import List, Optional, Union, cast try: - from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore + from charms.tls_certificates_interface.v2.tls_certificates import ( # type: ignore AllCertificatesInvalidatedEvent, CertificateAvailableEvent, CertificateExpiringEvent, CertificateInvalidatedEvent, - TLSCertificatesRequiresV3, + TLSCertificatesRequiresV2, generate_csr, generate_private_key, ) except ImportError as e: raise ImportError( - "failed to import charms.tls_certificates_interface.v3.tls_certificates; " + "failed to import charms.tls_certificates_interface.v2.tls_certificates; " "Either the library itself is missing (please get it through charmcraft fetch-lib) " "or one of its dependencies is unmet." ) from e import logging -from ops.charm import CharmBase, RelationBrokenEvent +from ops.charm import CharmBase from ops.framework import EventBase, EventSource, Object, ObjectEvents from ops.model import Relation @@ -67,7 +67,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 0 -LIBPATCH = 12 +LIBPATCH = 14 def is_ip_address(value: str) -> bool: @@ -132,7 +132,7 @@ def __init__( self.peer_relation_name = peer_relation_name self.certificates_relation_name = certificates_relation_name - self.certificates = TLSCertificatesRequiresV3(self.charm, self.certificates_relation_name) + self.certificates = TLSCertificatesRequiresV2(self.charm, self.certificates_relation_name) self.framework.observe( self.charm.on.config_changed, @@ -158,10 +158,6 @@ def __init__( self.certificates.on.all_certificates_invalidated, # pyright: ignore self._on_all_certificates_invalidated, ) - self.framework.observe( - self.charm.on[self.certificates_relation_name].relation_broken, # pyright: ignore - self._on_certificates_relation_broken, - ) # Peer relation events self.framework.observe( @@ -289,7 +285,7 @@ def _generate_csr( if clear_cert: self._ca_cert = "" self._server_cert = "" - self._chain = "" + self._chain = [] def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: """Get the certificate from the event and store it in a peer relation. @@ -311,7 +307,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: if event_csr == self._csr: self._ca_cert = event.ca self._server_cert = event.certificate - self._chain = event.chain_as_pem() + self._chain = event.chain self.on.cert_changed.emit() # pyright: ignore @property @@ -382,29 +378,21 @@ def _server_cert(self, value: str): rel.data[self.charm.unit].update({"certificate": value}) @property - def _chain(self) -> str: + def _chain(self) -> List[str]: if self._peer_relation: - if chain := self._peer_relation.data[self.charm.unit].get("chain", ""): - chain = json.loads(chain) - - # In a previous version of this lib, chain used to be a list. - # Convert the List[str] to str, per - # https://github.com/canonical/tls-certificates-interface/pull/141 - if isinstance(chain, list): - chain = "\n\n".join(reversed(chain)) - - return cast(str, chain) - return "" + if chain := self._peer_relation.data[self.charm.unit].get("chain", []): + return cast(list, json.loads(cast(str, chain))) + return [] @_chain.setter - def _chain(self, value: str): + def _chain(self, value: List[str]): # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. rel = self._peer_relation assert rel is not None # For type checker rel.data[self.charm.unit].update({"chain": json.dumps(value)}) @property - def chain(self) -> str: + def chain(self) -> List[str]: """Return the ca chain.""" return self._chain @@ -433,13 +421,11 @@ def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> Non self.on.cert_changed.emit() # pyright: ignore def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: - # Do what you want with this information, probably remove all certificates - # Note: assuming "limit: 1" in metadata - self._generate_csr(overwrite=True, clear_cert=True) - self.on.cert_changed.emit() # pyright: ignore - - def _on_certificates_relation_broken(self, event: RelationBrokenEvent) -> None: """Clear the certificates data when removing the relation.""" + # Note: assuming "limit: 1" in metadata + # The "certificates_relation_broken" event is converted to "all invalidated" custom + # event by the tls-certificates library. Per convention, we let the lib manage the + # relation and we do not observe "certificates_relation_broken" directly. if self._peer_relation: private_key = self._private_key # This is a workaround for https://bugs.launchpad.net/juju/+bug/2024583 @@ -447,4 +433,5 @@ def _on_certificates_relation_broken(self, event: RelationBrokenEvent) -> None: if private_key: self._peer_relation.data[self.charm.unit].update({"private_key": private_key}) + # We do not generate a CSR here because the relation is gone. self.on.cert_changed.emit() # pyright: ignore diff --git a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py index 2ab8a22c..34dd0264 100644 --- a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py +++ b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py @@ -4,7 +4,7 @@ """# KubernetesComputeResourcesPatch Library. This library is designed to enable developers to more simply patch the Kubernetes compute resource -limits and requests created by Juju during the deployment of a sidecar charm. +limits and requests created by Juju during the deployment of a charm. When initialised, this library binds a handler to the parent charm's `config-changed` event. The config-changed event is used because it is guaranteed to fire on startup, on upgrade and on @@ -76,6 +76,17 @@ def _resource_spec_from_config(self) -> ResourceRequirements: return ResourceRequirements(limits=spec, requests=spec) ``` +If you wish to pull the state of the resources patch operation and set the charm unit status based on that patch result, +you can achieve that using `get_status()` function. +```python +class SomeCharm(CharmBase): + def __init__(self, *args): + #... + self.framework.observe(self.on.collect_unit_status, self._on_collect_unit_status) + #... + def _on_collect_unit_status(self, event: CollectStatusEvent): + event.add_status(self.resources_patch.get_status()) +``` Additionally, you may wish to use mocks in your charm's unit testing to ensure that the library does not try to make any API calls, or open any files during testing that are unlikely to be @@ -83,12 +94,14 @@ def _resource_spec_from_config(self) -> ResourceRequirements: ```python # ... +from ops import ActiveStatus @patch.multiple( "charm.KubernetesComputeResourcesPatch", _namespace="test-namespace", _is_patched=lambda *a, **kw: True, is_ready=lambda *a, **kw: True, + get_status=lambda _: ActiveStatus(), ) @patch("lightkube.core.client.GenericSyncClient") def setUp(self, *unused): @@ -105,8 +118,9 @@ def setUp(self, *unused): import logging from decimal import Decimal from math import ceil, floor -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import tenacity from lightkube import ApiError, Client # pyright: ignore from lightkube.core import exceptions from lightkube.models.apps_v1 import StatefulSetSpec @@ -120,8 +134,10 @@ def setUp(self, *unused): from lightkube.resources.core_v1 import Pod from lightkube.types import PatchType from lightkube.utils.quantity import equals_canonically, parse_quantity +from ops import ActiveStatus, BlockedStatus, WaitingStatus from ops.charm import CharmBase from ops.framework import BoundEvent, EventBase, EventSource, Object, ObjectEvents +from ops.model import StatusBase logger = logging.getLogger(__name__) @@ -133,14 +149,16 @@ def setUp(self, *unused): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 _Decimal = Union[Decimal, float, str, int] # types that are potentially convertible to Decimal def adjust_resource_requirements( - limits: Optional[dict], requests: Optional[dict], adhere_to_requests: bool = True + limits: Optional[Dict[Any, Any]], + requests: Optional[Dict[Any, Any]], + adhere_to_requests: bool = True, ) -> ResourceRequirements: """Adjust resource limits so that `limits` and `requests` are consistent with each other. @@ -289,6 +307,18 @@ def sanitize_resource_spec_dict(spec: Optional[dict]) -> Optional[dict]: return d +def _retry_on_condition(exception): + """Retry if the exception is an ApiError with a status code != 403. + + Returns: a boolean value to indicate whether to retry or not. + """ + if isinstance(exception, ApiError) and str(exception.status.code) != "403": + return True + if isinstance(exception, exceptions.ConfigError) or isinstance(exception, ValueError): + return True + return False + + class K8sResourcePatchFailedEvent(EventBase): """Emitted when patching fails.""" @@ -385,27 +415,132 @@ def get_actual(self, pod_name: str) -> Optional[ResourceRequirements]: ) return podspec.resources + def is_failed( + self, resource_reqs_func: Callable[[], ResourceRequirements] + ) -> Tuple[bool, str]: + """Returns a tuple indicating whether a patch operation has failed along with a failure message. + + Implementation is based on dry running the patch operation to catch if there would be failures (e.g: Wrong spec and Auth errors). + """ + try: + resource_reqs = resource_reqs_func() + limits = resource_reqs.limits + requests = resource_reqs.requests + except ValueError as e: + msg = f"Failed obtaining resource limit spec: {e}" + logger.error(msg) + return True, msg + + # Dry run does not catch negative values for resource requests and limits. + if not is_valid_spec(limits) or not is_valid_spec(requests): + msg = f"Invalid resource requirements specs: {limits}, {requests}" + logger.error(msg) + return True, msg + + resource_reqs = ResourceRequirements( + limits=sanitize_resource_spec_dict(limits), # type: ignore[arg-type] + requests=sanitize_resource_spec_dict(requests), # type: ignore[arg-type] + ) + + try: + self.apply(resource_reqs, dry_run=True) + except ApiError as e: + if e.status.code == 403: + msg = f"Kubernetes resources patch failed: `juju trust` this application. {e}" + else: + msg = f"Kubernetes resources patch failed: {e}" + return True, msg + except ValueError as e: + msg = f"Kubernetes resources patch failed: {e}" + return True, msg + + return False, "" + + def is_in_progress(self) -> bool: + """Returns a boolean to indicate whether a patch operation is in progress. + + Implementation follows a similar approach to `kubectl rollout status statefulset` to track the progress of a rollout. + Reference: https://github.com/kubernetes/kubectl/blob/kubernetes-1.31.0/pkg/polymorphichelpers/rollout_status.go + """ + try: + sts = self.client.get( + StatefulSet, name=self.statefulset_name, namespace=self.namespace + ) + except (ValueError, ApiError) as e: + # Assumption: if there was a persistent issue, it'd have been caught in `is_failed` + # Wait until next run to try again. + logger.error(f"Failed to fetch statefulset from K8s api: {e}") + return False + + if sts.status is None or sts.spec is None: + logger.debug("status/spec are not yet available") + return False + if sts.status.observedGeneration == 0 or ( + sts.metadata + and sts.status.observedGeneration + and sts.metadata.generation + and sts.metadata.generation > sts.status.observedGeneration + ): + logger.debug("waiting for statefulset spec update to be observed...") + return True + if ( + sts.spec.replicas is not None + and sts.status.readyReplicas is not None + and sts.status.readyReplicas < sts.spec.replicas + ): + logger.debug( + f"Waiting for {sts.spec.replicas-sts.status.readyReplicas} pods to be ready..." + ) + return True + + if ( + sts.spec.updateStrategy + and sts.spec.updateStrategy.type == "rollingUpdate" + and sts.spec.updateStrategy.rollingUpdate is not None + ): + if ( + sts.spec.replicas is not None + and sts.spec.updateStrategy.rollingUpdate.partition is not None + ): + if sts.status.updatedReplicas and sts.status.updatedReplicas < ( + sts.spec.replicas - sts.spec.updateStrategy.rollingUpdate.partition + ): + logger.debug( + f"Waiting for partitioned roll out to finish: {sts.status.updatedReplicas} out of {sts.spec.replicas - sts.spec.updateStrategy.rollingUpdate.partition} new pods have been updated..." + ) + return True + logger.debug( + f"partitioned roll out complete: {sts.status.updatedReplicas} new pods have been updated..." + ) + return False + + if sts.status.updateRevision != sts.status.currentRevision: + logger.debug( + f"waiting for statefulset rolling update to complete {sts.status.updatedReplicas} pods at revision {sts.status.updateRevision}..." + ) + return True + + logger.debug( + f"statefulset rolling update complete pods at revision {sts.status.currentRevision}" + ) + return False + def is_ready(self, pod_name, resource_reqs: ResourceRequirements): """Reports if the resource patch has been applied and is in effect. Returns: bool: A boolean indicating if the service patch has been applied and is in effect. """ - logger.info( - "reqs=%s, templated=%s, actual=%s", - resource_reqs, - self.get_templated(), - self.get_actual(pod_name), - ) return self.is_patched(resource_reqs) and equals_canonically( # pyright: ignore resource_reqs, self.get_actual(pod_name) # pyright: ignore ) - def apply(self, resource_reqs: ResourceRequirements) -> None: + def apply(self, resource_reqs: ResourceRequirements, dry_run=False) -> None: """Patch the Kubernetes resources created by Juju to limit cpu or mem.""" # Need to ignore invalid input, otherwise the StatefulSet gives "FailedCreate" and the # charm would be stuck in unknown/lost. - if self.is_patched(resource_reqs): + if not dry_run and self.is_patched(resource_reqs): + logger.debug(f"Resource requests are already patched: {resource_reqs}") return self.client.patch( @@ -415,6 +550,7 @@ def apply(self, resource_reqs: ResourceRequirements) -> None: namespace=self.namespace, patch_type=PatchType.APPLY, field_manager=self.__class__.__name__, + dry_run=dry_run, ) @@ -422,6 +558,9 @@ class KubernetesComputeResourcesPatch(Object): """A utility for patching the Kubernetes compute resources set up by Juju.""" on = K8sResourcePatchEvents() # pyright: ignore + PATCH_RETRY_STOP = tenacity.stop_after_delay(20) + PATCH_RETRY_WAIT = tenacity.wait_fixed(5) + PATCH_RETRY_IF = tenacity.retry_if_exception(_retry_on_condition) def __init__( self, @@ -468,7 +607,11 @@ def _on_config_changed(self, _): self._patch() def _patch(self) -> None: - """Patch the Kubernetes resources created by Juju to limit cpu or mem.""" + """Patch the Kubernetes resources created by Juju to limit cpu or mem. + + This method will keep on retrying to patch the kubernetes resource for a default duration of 20 seconds + if the patching failure is due to a recoverable error (e.g: Network Latency). + """ try: resource_reqs = self.resource_reqs_func() limits = resource_reqs.limits @@ -492,7 +635,18 @@ def _patch(self) -> None: ) try: - self.patcher.apply(resource_reqs) + for attempt in tenacity.Retrying( + retry=self.PATCH_RETRY_IF, + stop=self.PATCH_RETRY_STOP, + wait=self.PATCH_RETRY_WAIT, + # if you don't succeed raise the last caught exception when you're done + reraise=True, + ): + with attempt: + logger.debug( + f"attempt #{attempt.retry_state.attempt_number} to patch resource limits" + ) + self.patcher.apply(resource_reqs) except exceptions.ConfigError as e: msg = f"Error creating k8s client: {e}" @@ -503,6 +657,7 @@ def _patch(self) -> None: except ApiError as e: if e.status.code == 403: msg = f"Kubernetes resources patch failed: `juju trust` this application. {e}" + else: msg = f"Kubernetes resources patch failed: {e}" @@ -554,6 +709,29 @@ def is_ready(self) -> bool: self.on.patch_failed.emit(message=msg) return False + def get_status(self) -> StatusBase: + """Return the status of patching the resource limits in a `StatusBase` format. + + Returns: + StatusBase: There is a 1:1 mapping between the state of the patching operation and a `StatusBase` value that the charm can be set to. + Possible values are: + - ActiveStatus: The patch was applied successfully. + - BlockedStatus: The patch failed and requires a human intervention. + - WaitingStatus: The patch is still in progress. + + Example: + - ActiveStatus("Patch applied successfully") + - BlockedStatus("Failed due to missing permissions") + - WaitingStatus("Patch is in progress") + """ + failed, msg = self.patcher.is_failed(self.resource_reqs_func) + if failed: + return BlockedStatus(msg) + if self.patcher.is_in_progress(): + return WaitingStatus("waiting for resources patch to apply") + # patch successful or nothing has been patched yet + return ActiveStatus() + @property def _app(self) -> str: """Name of the current Juju application. diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index be967686..e3d35c6f 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -178,7 +178,7 @@ def __init__(self, *args): - `scrape_timeout` - `proxy_url` - `relabel_configs` -- `metrics_relabel_configs` +- `metric_relabel_configs` - `sample_limit` - `label_limit` - `label_name_length_limit` @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 46 +LIBPATCH = 47 PYDEPS = ["cosl"] @@ -377,7 +377,7 @@ def _on_scrape_targets_changed(self, event): "scrape_timeout", "proxy_url", "relabel_configs", - "metrics_relabel_configs", + "metric_relabel_configs", "sample_limit", "label_limit", "label_name_length_limit", diff --git a/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py b/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py index 1e7ff840..cf8def11 100644 --- a/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py +++ b/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py @@ -69,6 +69,9 @@ def my_tracing_endpoint(self) -> Optional[str]: - every event as a span (including custom events) - every charm method call (except dunders) as a span +We recommend that you scale up your tracing provider and relate it to an ingress so that your tracing requests +go through the ingress and get load balanced across all units. Otherwise, if the provider's leader goes down, your tracing goes down. + ## TLS support If your charm integrates with a TLS provider which is also trusted by the tracing provider (the Tempo charm), @@ -114,6 +117,57 @@ def get_tracer(self) -> opentelemetry.trace.Tracer: See the official opentelemetry Python SDK documentation for usage: https://opentelemetry-python.readthedocs.io/en/latest/ + +## Caching traces +The `trace_charm` machinery will buffer any traces collected during charm execution and store them +to a file on the charm container until a tracing backend becomes available. At that point, it will +flush them to the tracing receiver. + +By default, the buffer is configured to start dropping old traces if any of these conditions apply: + +- the storage size exceeds 10 MiB +- the number of buffered events exceeds 100 + +You can configure this by, for example: + +```python +@trace_charm( + tracing_endpoint="my_tracing_endpoint", + server_cert="_server_cert", + # only cache up to 42 events + buffer_max_events=42, + # only cache up to 42 MiB + buffer_max_size_mib=42, # minimum 10! +) +class MyCharm(CharmBase): + ... +``` + +Note that setting `buffer_max_events` to 0 will effectively disable the buffer. + +The path of the buffer file is by default in the charm's execution root, which for k8s charms means +that in case of pod churn, the cache will be lost. The recommended solution is to use an existing storage +(or add a new one) such as: + +```yaml +storage: + data: + type: filesystem + location: /charm-traces +``` + +and then configure the `@trace_charm` decorator to use it as path for storing the buffer: +```python +@trace_charm( + tracing_endpoint="my_tracing_endpoint", + server_cert="_server_cert", + # store traces to a PVC so they're not lost on pod restart. + buffer_path="/charm-traces/buffer.file", +) +class MyCharm(CharmBase): + ... +``` + ## Upgrading from `v0` If you are upgrading from `charm_tracing` v0, you need to take the following steps (assuming you already @@ -171,6 +225,12 @@ def my_tracing_endpoint(self) -> Optional[str]: 3) If you were passing a certificate (str) using `server_cert`, you need to change it to provide an *absolute* path to the certificate file instead. """ +import typing + +from opentelemetry.exporter.otlp.proto.common._internal.trace_encoder import ( + encode_spans, +) +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter def _remove_stale_otel_sdk_packages(): @@ -222,6 +282,9 @@ def _remove_stale_otel_sdk_packages(): otel_logger.debug("Successfully applied _remove_stale_otel_sdk_packages patch. ") +# apply hacky patch to remove stale opentelemetry sdk packages on upgrade-charm. +# it could be trouble if someone ever decides to implement their own tracer parallel to +# ours and before the charm has inited. We assume they won't. _remove_stale_otel_sdk_packages() import functools @@ -235,6 +298,7 @@ def _remove_stale_otel_sdk_packages(): Any, Callable, Generator, + List, Optional, Sequence, Type, @@ -247,8 +311,12 @@ def _remove_stale_otel_sdk_packages(): import ops from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import Span, TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace import ReadableSpan, Span, TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + SpanExporter, + SpanExportResult, +) from opentelemetry.trace import INVALID_SPAN, Tracer from opentelemetry.trace import get_current_span as otlp_get_current_span from opentelemetry.trace import ( @@ -269,7 +337,7 @@ def _remove_stale_otel_sdk_packages(): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 4 PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] @@ -277,7 +345,7 @@ def _remove_stale_otel_sdk_packages(): dev_logger = logging.getLogger("tracing-dev") # set this to 0 if you are debugging/developing this library source -dev_logger.setLevel(logging.CRITICAL) +dev_logger.setLevel(logging.ERROR) _CharmType = Type[CharmBase] # the type CharmBase and any subclass thereof _C = TypeVar("_C", bound=_CharmType) @@ -287,6 +355,186 @@ def _remove_stale_otel_sdk_packages(): _GetterType = Union[Callable[[_CharmType], Optional[str]], property] CHARM_TRACING_ENABLED = "CHARM_TRACING_ENABLED" +BUFFER_DEFAULT_CACHE_FILE_NAME = ".charm_tracing_buffer.raw" +# we store the buffer as raw otlp-native protobuf (bytes) since it's hard to serialize/deserialize it in +# any portable format. Json dumping is supported, but loading isn't. +# cfr: https://github.com/open-telemetry/opentelemetry-python/issues/1003 + +BUFFER_DEFAULT_CACHE_FILE_SIZE_LIMIT_MiB = 10 +_BUFFER_CACHE_FILE_SIZE_LIMIT_MiB_MIN = 10 +BUFFER_DEFAULT_MAX_EVENT_HISTORY_LENGTH = 100 +_MiB_TO_B = 2**20 # megabyte to byte conversion rate +_OTLP_SPAN_EXPORTER_TIMEOUT = 1 +"""Timeout in seconds that the OTLP span exporter has to push traces to the backend.""" + + +class _Buffer: + """Handles buffering for spans emitted while no tracing backend is configured or available. + + Use the max_event_history_length_buffering param of @trace_charm to tune + the amount of memory that this will hog on your units. + + The buffer is formatted as a bespoke byte dump (protobuf limitation). + We cannot store them as json because that is not well-supported by the sdk + (see https://github.com/open-telemetry/opentelemetry-python/issues/3364). + """ + + _SPANSEP = b"__CHARM_TRACING_BUFFER_SPAN_SEP__" + + def __init__(self, db_file: Path, max_event_history_length: int, max_buffer_size_mib: int): + self._db_file = db_file + self._max_event_history_length = max_event_history_length + self._max_buffer_size_mib = max(max_buffer_size_mib, _BUFFER_CACHE_FILE_SIZE_LIMIT_MiB_MIN) + + # set by caller + self.exporter: Optional[OTLPSpanExporter] = None + + def save(self, spans: typing.Sequence[ReadableSpan]): + """Save the spans collected by this exporter to the cache file. + + This method should be as fail-safe as possible. + """ + if self._max_event_history_length < 1: + dev_logger.debug("buffer disabled: max history length < 1") + return + + current_history_length = len(self.load()) + new_history_length = current_history_length + len(spans) + if (diff := self._max_event_history_length - new_history_length) < 0: + self.drop(diff) + self._save(spans) + + def _serialize(self, spans: Sequence[ReadableSpan]) -> bytes: + # encode because otherwise we can't json-dump them + return encode_spans(spans).SerializeToString() + + def _save(self, spans: Sequence[ReadableSpan], replace: bool = False): + dev_logger.debug(f"saving {len(spans)} new spans to buffer") + old = [] if replace else self.load() + new = self._serialize(spans) + + try: + # if the buffer exceeds the size limit, we start dropping old spans until it does + + while len((new + self._SPANSEP.join(old))) > (self._max_buffer_size_mib * _MiB_TO_B): + if not old: + # if we've already dropped all spans and still we can't get under the + # size limit, we can't save this span + logger.error( + f"span exceeds total buffer size limit ({self._max_buffer_size_mib}MiB); " + f"buffering FAILED" + ) + return + + old = old[1:] + logger.warning( + f"buffer size exceeds {self._max_buffer_size_mib}MiB; dropping older spans... " + f"Please increase the buffer size, disable buffering, or ensure the spans can be flushed." + ) + + self._db_file.write_bytes(new + self._SPANSEP.join(old)) + except Exception: + logger.exception("error buffering spans") + + def load(self) -> List[bytes]: + """Load currently buffered spans from the cache file. + + This method should be as fail-safe as possible. + """ + if not self._db_file.exists(): + dev_logger.debug("buffer file not found. buffer empty.") + return [] + try: + spans = self._db_file.read_bytes().split(self._SPANSEP) + except Exception: + logger.exception(f"error parsing {self._db_file}") + return [] + return spans + + def drop(self, n_spans: Optional[int] = None): + """Drop some currently buffered spans from the cache file.""" + current = self.load() + if n_spans: + dev_logger.debug(f"dropping {n_spans} spans from buffer") + new = current[n_spans:] + else: + dev_logger.debug("emptying buffer") + new = [] + + self._db_file.write_bytes(self._SPANSEP.join(new)) + + def flush(self) -> Optional[bool]: + """Export all buffered spans to the given exporter, then clear the buffer. + + Returns whether the flush was successful, and None if there was nothing to flush. + """ + if not self.exporter: + dev_logger.debug("no exporter set; skipping buffer flush") + return False + + buffered_spans = self.load() + if not buffered_spans: + dev_logger.debug("nothing to flush; buffer empty") + return None + + errors = False + for span in buffered_spans: + try: + out = self.exporter._export(span) # type: ignore + if not (200 <= out.status_code < 300): + # take any 2xx status code as a success + errors = True + except ConnectionError: + dev_logger.debug( + "failed exporting buffered span; backend might be down or still starting" + ) + errors = True + except Exception: + logger.exception("unexpected error while flushing span batch from buffer") + errors = True + + if not errors: + self.drop() + else: + logger.error("failed flushing spans; buffer preserved") + return not errors + + @property + def is_empty(self): + """Utility to check whether the buffer has any stored spans. + + This is more efficient than attempting a load() given how large the buffer might be. + """ + return (not self._db_file.exists()) or (self._db_file.stat().st_size == 0) + + +class _OTLPSpanExporter(OTLPSpanExporter): + """Subclass of OTLPSpanExporter to configure the max retry timeout, so that it fails a bit faster.""" + + # The issue we're trying to solve is that the model takes AGES to settle if e.g. tls is misconfigured, + # as every hook of a charm_tracing-instrumented charm takes about a minute to exit, as the charm can't + # flush the traces and keeps retrying for 'too long' + + _MAX_RETRY_TIMEOUT = 4 + # we give the exporter 4 seconds in total to succeed pushing the traces to tempo + # if it fails, we'll be caching the data in the buffer and flush it the next time, so there's no data loss risk. + # this means 2/3 retries (hard to guess from the implementation) and up to ~7 seconds total wait + + +class _BufferedExporter(InMemorySpanExporter): + def __init__(self, buffer: _Buffer) -> None: + super().__init__() + self._buffer = buffer + + def export(self, spans: typing.Sequence[ReadableSpan]) -> SpanExportResult: + self._buffer.save(spans) + return super().export(spans) + + def force_flush(self, timeout_millis: int = 0) -> bool: + # parent implementation is fake, so the timeout_millis arg is not doing anything. + result = super().force_flush(timeout_millis) + self._buffer.save(self.get_finished_spans()) + return result def is_enabled() -> bool: @@ -423,7 +671,10 @@ def _setup_root_span_initializer( charm_type: _CharmType, tracing_endpoint_attr: str, server_cert_attr: Optional[str], - service_name: Optional[str] = None, + service_name: Optional[str], + buffer_path: Optional[Path], + buffer_max_events: int, + buffer_max_size_mib: int, ): """Patch the charm's initializer.""" original_init = charm_type.__init__ @@ -442,18 +693,11 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): logger.info("Tracing DISABLED: skipping root span initialization") return - # already init some attrs that will be reinited later by calling original_init: - # self.framework = framework - # self.handle = Handle(None, self.handle_kind, None) - original_event_context = framework._event_context # default service name isn't just app name because it could conflict with the workload service name _service_name = service_name or f"{self.app.name}-charm" unit_name = self.unit.name - # apply hacky patch to remove stale opentelemetry sdk packages on upgrade-charm. - # it could be trouble if someone ever decides to implement their own tracer parallel to - # ours and before the charm has inited. We assume they won't. resource = Resource.create( attributes={ "service.name": _service_name, @@ -471,33 +715,60 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): # if anything goes wrong with retrieving the endpoint, we let the exception bubble up. tracing_endpoint = _get_tracing_endpoint(tracing_endpoint_attr, self, charm_type) + buffer_only = False + # whether we're only exporting to buffer, or also to the otlp exporter. + if not tracing_endpoint: # tracing is off if tracing_endpoint is None - return + # however we can buffer things until tracing comes online + buffer_only = True server_cert: Optional[Union[str, Path]] = ( _get_server_cert(server_cert_attr, self, charm_type) if server_cert_attr else None ) - if tracing_endpoint.startswith("https://") and not server_cert: + if (tracing_endpoint and tracing_endpoint.startswith("https://")) and not server_cert: logger.error( "Tracing endpoint is https, but no server_cert has been passed." "Please point @trace_charm to a `server_cert` attr. " "This might also mean that the tracing provider is related to a " "certificates provider, but this application is not (yet). " "In that case, you might just have to wait a bit for the certificates " - "integration to settle. " + "integration to settle. This span will be buffered." ) - return + buffer_only = True - exporter = OTLPSpanExporter( - endpoint=tracing_endpoint, - certificate_file=str(Path(server_cert).absolute()) if server_cert else None, - timeout=2, + buffer = _Buffer( + db_file=buffer_path or Path() / BUFFER_DEFAULT_CACHE_FILE_NAME, + max_event_history_length=buffer_max_events, + max_buffer_size_mib=buffer_max_size_mib, ) + previous_spans_buffered = not buffer.is_empty + + exporters: List[SpanExporter] = [] + if buffer_only: + # we have to buffer because we're missing necessary backend configuration + dev_logger.debug("buffering mode: ON") + exporters.append(_BufferedExporter(buffer)) + + else: + dev_logger.debug("buffering mode: FALLBACK") + # in principle, we have the right configuration to be pushing traces, + # but if we fail for whatever reason, we will put everything in the buffer + # and retry the next time + otlp_exporter = _OTLPSpanExporter( + endpoint=tracing_endpoint, + certificate_file=str(Path(server_cert).absolute()) if server_cert else None, + timeout=_OTLP_SPAN_EXPORTER_TIMEOUT, # give individual requests 1 second to succeed + ) + exporters.append(otlp_exporter) + exporters.append(_BufferedExporter(buffer)) + buffer.exporter = otlp_exporter + + for exporter in exporters: + processor = BatchSpanProcessor(exporter) + provider.add_span_processor(processor) - processor = BatchSpanProcessor(exporter) - provider.add_span_processor(processor) set_tracer_provider(provider) _tracer = get_tracer(_service_name) # type: ignore _tracer_token = tracer.set(_tracer) @@ -521,7 +792,7 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): @contextmanager def wrap_event_context(event_name: str): - dev_logger.info(f"entering event context: {event_name}") + dev_logger.debug(f"entering event context: {event_name}") # when the framework enters an event context, we create a span. with _span("event: " + event_name) as event_context_span: if event_context_span: @@ -535,12 +806,50 @@ def wrap_event_context(event_name: str): @functools.wraps(original_close) def wrap_close(): - dev_logger.info("tearing down tracer and flushing traces") + dev_logger.debug("tearing down tracer and flushing traces") span.end() opentelemetry.context.detach(span_token) # type: ignore tracer.reset(_tracer_token) tp = cast(TracerProvider, get_tracer_provider()) - tp.force_flush(timeout_millis=1000) # don't block for too long + flush_successful = tp.force_flush(timeout_millis=1000) # don't block for too long + + if buffer_only: + # if we're in buffer_only mode, it means we couldn't even set up the exporter for + # tempo as we're missing some data. + # so attempting to flush the buffer doesn't make sense + dev_logger.debug("tracing backend unavailable: all spans pushed to buffer") + + else: + dev_logger.debug("tracing backend found: attempting to flush buffer...") + + # if we do have an exporter for tempo, and we could send traces to it, + # we can attempt to flush the buffer as well. + if not flush_successful: + logger.error("flushing FAILED: unable to push traces to backend.") + else: + dev_logger.debug("flush succeeded.") + + # the backend has accepted the spans generated during this event, + if not previous_spans_buffered: + # if the buffer was empty to begin with, any spans we collected now can be discarded + buffer.drop() + dev_logger.debug("buffer dropped: this trace has been sent already") + else: + # if the buffer was nonempty, we can attempt to flush it + dev_logger.debug("attempting buffer flush...") + buffer_flush_successful = buffer.flush() + if buffer_flush_successful: + dev_logger.debug("buffer flush OK") + elif buffer_flush_successful is None: + # TODO is this even possible? + dev_logger.debug("buffer flush OK; empty: nothing to flush") + else: + # this situation is pretty weird, I'm not even sure it can happen, + # because it would mean that we did manage + # to push traces directly to the tempo exporter (flush_successful), + # but the buffer flush failed to push to the same exporter! + logger.error("buffer flush FAILED") + tp.shutdown() original_close() @@ -555,6 +864,9 @@ def trace_charm( server_cert: Optional[str] = None, service_name: Optional[str] = None, extra_types: Sequence[type] = (), + buffer_max_events: int = BUFFER_DEFAULT_MAX_EVENT_HISTORY_LENGTH, + buffer_max_size_mib: int = BUFFER_DEFAULT_CACHE_FILE_SIZE_LIMIT_MiB, + buffer_path: Optional[Union[str, Path]] = None, ) -> Callable[[_T], _T]: """Autoinstrument the decorated charm with tracing telemetry. @@ -596,6 +908,10 @@ def trace_charm( Defaults to the juju application name this charm is deployed under. :param extra_types: pass any number of types that you also wish to autoinstrument. For example, charm libs, relation endpoint wrappers, workload abstractions, ... + :param buffer_max_events: max number of events to save in the buffer. Set to 0 to disable buffering. + :param buffer_max_size_mib: max size of the buffer file. When exceeded, spans will be dropped. + Minimum 10MiB. + :param buffer_path: path to buffer file to use for saving buffered spans. """ def _decorator(charm_type: _T) -> _T: @@ -606,6 +922,9 @@ def _decorator(charm_type: _T) -> _T: server_cert_attr=server_cert, service_name=service_name, extra_types=extra_types, + buffer_path=Path(buffer_path) if buffer_path else None, + buffer_max_size_mib=buffer_max_size_mib, + buffer_max_events=buffer_max_events, ) return charm_type @@ -618,6 +937,9 @@ def _autoinstrument( server_cert_attr: Optional[str] = None, service_name: Optional[str] = None, extra_types: Sequence[type] = (), + buffer_max_events: int = BUFFER_DEFAULT_MAX_EVENT_HISTORY_LENGTH, + buffer_max_size_mib: int = BUFFER_DEFAULT_CACHE_FILE_SIZE_LIMIT_MiB, + buffer_path: Optional[Path] = None, ) -> _T: """Set up tracing on this charm class. @@ -650,13 +972,20 @@ def _autoinstrument( Defaults to the juju application name this charm is deployed under. :param extra_types: pass any number of types that you also wish to autoinstrument. For example, charm libs, relation endpoint wrappers, workload abstractions, ... + :param buffer_max_events: max number of events to save in the buffer. Set to 0 to disable buffering. + :param buffer_max_size_mib: max size of the buffer file. When exceeded, spans will be dropped. + Minimum 10MiB. + :param buffer_path: path to buffer file to use for saving buffered spans. """ - dev_logger.info(f"instrumenting {charm_type}") + dev_logger.debug(f"instrumenting {charm_type}") _setup_root_span_initializer( charm_type, tracing_endpoint_attr, server_cert_attr=server_cert_attr, service_name=service_name, + buffer_path=buffer_path, + buffer_max_events=buffer_max_events, + buffer_max_size_mib=buffer_max_size_mib, ) trace_type(charm_type) for type_ in extra_types: @@ -672,12 +1001,12 @@ def trace_type(cls: _T) -> _T: It assumes that this class is only instantiated after a charm type decorated with `@trace_charm` has been instantiated. """ - dev_logger.info(f"instrumenting {cls}") + dev_logger.debug(f"instrumenting {cls}") for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): - dev_logger.info(f"discovered {method}") + dev_logger.debug(f"discovered {method}") if method.__name__.startswith("__"): - dev_logger.info(f"skipping {method} (dunder)") + dev_logger.debug(f"skipping {method} (dunder)") continue # the span title in the general case should be: @@ -723,7 +1052,7 @@ def trace_function(function: _F, name: Optional[str] = None) -> _F: def _trace_callable(callable: _F, qualifier: str, name: Optional[str] = None) -> _F: - dev_logger.info(f"instrumenting {callable}") + dev_logger.debug(f"instrumenting {callable}") # sig = inspect.signature(callable) @functools.wraps(callable) diff --git a/lib/charms/tempo_coordinator_k8s/v0/tracing.py b/lib/charms/tempo_coordinator_k8s/v0/tracing.py index 1f92867f..2035dffd 100644 --- a/lib/charms/tempo_coordinator_k8s/v0/tracing.py +++ b/lib/charms/tempo_coordinator_k8s/v0/tracing.py @@ -34,7 +34,7 @@ def __init__(self, *args): `TracingEndpointRequirer.request_protocols(*protocol:str, relation:Optional[Relation])` method. Using this method also allows you to use per-relation protocols. -Units of provider charms obtain the tempo endpoint to which they will push their traces by calling +Units of requirer charms obtain the tempo endpoint to which they will push their traces by calling `TracingEndpointRequirer.get_endpoint(protocol: str)`, where `protocol` is, for example: - `otlp_grpc` - `otlp_http` @@ -44,7 +44,10 @@ def __init__(self, *args): If the `protocol` is not in the list of protocols that the charm requested at endpoint set-up time, the library will raise an error. -## Requirer Library Usage +We recommend that you scale up your tracing provider and relate it to an ingress so that your tracing requests +go through the ingress and get load balanced across all units. Otherwise, if the provider's leader goes down, your tracing goes down. + +## Provider Library Usage The `TracingEndpointProvider` object may be used by charms to manage relations with their trace sources. For this purposes a Tempo-like charm needs to do two things @@ -107,7 +110,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 PYDEPS = ["pydantic"] diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py index cbdd80d1..141412b0 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -111,6 +111,7 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, + recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -276,13 +277,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy +import ipaddress import json import logging import uuid from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from ipaddress import IPv4Address from typing import List, Literal, Optional, Union from cryptography import x509 @@ -304,6 +305,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven ModelError, Relation, RelationDataContent, + Secret, SecretNotFoundError, Unit, ) @@ -316,7 +318,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 10 +LIBPATCH = 23 PYDEPS = ["cryptography", "jsonschema"] @@ -453,11 +455,35 @@ class ProviderCertificate: ca: str chain: List[str] revoked: bool + expiry_time: datetime + expiry_notification_time: Optional[datetime] = None def chain_as_pem(self) -> str: """Return full certificate chain as a PEM string.""" return "\n\n".join(reversed(self.chain)) + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "relation_id": self.relation_id, + "application_name": self.application_name, + "csr": self.csr, + "certificate": self.certificate, + "ca": self.ca, + "chain": self.chain, + "revoked": self.revoked, + "expiry_time": self.expiry_time.isoformat(), + "expiry_notification_time": self.expiry_notification_time.isoformat() + if self.expiry_notification_time + else None, + } + ) + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -682,21 +708,49 @@ def _get_closest_future_time( ) -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], + requirer_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider and by the requirer. + Time recommended by the provider is preferred, + then time recommended by the requirer, + then dynamically calculated time. Args: - certificate (str): x509 certificate as a string + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + requirer_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the requirer. Returns: - Optional[datetime]: Expiry datetime or None + datetime: Time to notify the user about the certificate expiry. """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after_utc - except ValueError: - logger.warning("Could not load certificate.") - return None + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = expiry_time - timedelta( + hours=provider_recommended_notification_time + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + + if requirer_recommended_notification_time is not None: + requirer_recommended_notification_time = abs(requirer_recommended_notification_time) + requirer_recommendation_time_delta = expiry_time - timedelta( + hours=requirer_recommended_notification_time + ) + if validity_start_time < requirer_recommendation_time_delta: + return requirer_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) def generate_ca( @@ -965,6 +1019,8 @@ def generate_csr( # noqa: C901 organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -983,6 +1039,8 @@ def generate_csr( # noqa: C901 organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) @@ -1008,13 +1066,19 @@ def generate_csr( # noqa: C901 subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -1030,6 +1094,13 @@ def generate_csr( # noqa: C901 return signed_certificate.public_bytes(serialization.Encoding.PEM) +def get_sha256_hex(data: str) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(data.encode()) + return digest.finalize().hex() + + def csr_matches_certificate(csr: str, cert: str) -> bool: """Check if a CSR matches a certificate. @@ -1039,25 +1110,16 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: Returns: bool: True/False depending on whether the CSR matches the certificate. """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): return False return True @@ -1135,6 +1197,7 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificate to relation data. @@ -1144,6 +1207,8 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain + recommended_expiry_notification_time (int): + Time in hours before the certificate expires to notify the user. Returns: None @@ -1161,6 +1226,7 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, + "recommended_expiry_notification_time": recommended_expiry_notification_time, } provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) @@ -1227,6 +1293,7 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificates to relation data. @@ -1236,6 +1303,8 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID + recommended_expiry_notification_time (int): + Recommended time in hours before the certificate expires to notify the user. Returns: None @@ -1257,6 +1326,7 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], + recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: @@ -1310,6 +1380,13 @@ def get_provider_certificates( provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) for certificate in provider_certificates: + try: + certificate_object = x509.load_pem_x509_certificate( + data=certificate["certificate"].encode() + ) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue provider_certificate = ProviderCertificate( relation_id=relation.id, application_name=relation.app.name, @@ -1318,6 +1395,10 @@ def get_provider_certificates( ca=certificate["ca"], chain=certificate["chain"], revoked=certificate.get("revoked", False), + expiry_time=certificate_object.not_valid_after_utc, + expiry_notification_time=certificate.get( + "recommended_expiry_notification_time" + ), ) certificates.append(provider_certificate) return certificates @@ -1368,18 +1449,31 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None Returns: None """ - provider_certificates = self.get_provider_certificates(relation_id) - requirer_csrs = self.get_requirer_csrs(relation_id) + provider_certificates = self.get_unsolicited_certificates(relation_id=relation_id) + for provider_certificate in provider_certificates: + self.on.certificate_revocation_request.emit( + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + self.remove_certificate(certificate=provider_certificate.certificate) + + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. + """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: if certificate.csr not in list_of_csrs: - self.on.certificate_revocation_request.emit( - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, - ) - self.remove_certificate(certificate=certificate.certificate) + unsolicited_certificates.append(certificate) + return unsolicited_certificates def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None @@ -1475,15 +1569,17 @@ def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: int = 168, + expiry_notification_time: Optional[int] = None, ): """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Time difference between now and expiry (in hours). - Used to trigger the CertificateExpiring event. Default: 7 days. + expiry_notification_time (int): Number of hours prior to certificate expiry. + Used to trigger the CertificateExpiring event. + This value is used as a recommendation only, + The actual value is calculated taking into account the provider's recommendation. """ super().__init__(charm, relationship_name) if not JujuVersion.from_environ().has_secrets: @@ -1544,9 +1640,25 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: if not certificate: logger.warning("No certificate found in relation data - Skipping") continue + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue ca = provider_certificate_dict.get("ca") chain = provider_certificate_dict.get("chain", []) csr = provider_certificate_dict.get("certificate_signing_request") + recommended_expiry_notification_time = provider_certificate_dict.get( + "recommended_expiry_notification_time" + ) + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=validity_start_time, + expiry_time=expiry_time, + provider_recommended_notification_time=recommended_expiry_notification_time, + requirer_recommended_notification_time=self.expiry_notification_time, + ) if not csr: logger.warning("No CSR found in relation data - Skipping") continue @@ -1559,6 +1671,8 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: ca=ca, chain=chain, revoked=revoked, + expiry_time=expiry_time, + expiry_notification_time=expiry_notification_time, ) provider_certificates.append(provider_certificate) return provider_certificates @@ -1708,13 +1822,9 @@ def get_expiring_certificates(self) -> List[ProviderCertificate]: expiring_certificates: List[ProviderCertificate] = [] for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - expiry_time = _get_certificate_expiry_time(cert.certificate) - if not expiry_time: + if not cert.expiry_time or not cert.expiry_notification_time: continue - expiry_notification_time = expiry_time - timedelta( - hours=self.expiry_notification_time - ) - if datetime.now(timezone.utc) > expiry_notification_time: + if datetime.now(timezone.utc) > cert.expiry_notification_time: expiring_certificates.append(cert) return expiring_certificates @@ -1774,9 +1884,14 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ] for certificate in provider_certificates: if certificate.csr in requirer_csrs: + csr_in_sha256_hex = get_sha256_hex(certificate.csr) if certificate.revoked: with suppress(SecretNotFoundError): - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + logger.debug( + "Removing secret with label %s", + f"{LIBID}-{csr_in_sha256_hex}", + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", @@ -1787,16 +1902,34 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ) else: try: - secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") - secret.set_content({"certificate": certificate.certificate}) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + logger.debug( + "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + # Juju < 3.6 will create a new revision even if the content is the same + if ( + secret.get_content(refresh=True).get("certificate", "") + == certificate.certificate + ): + logger.debug( + "Secret %s with correct certificate already exists", + f"{LIBID}-{csr_in_sha256_hex}", + ) + continue + secret.set_content( + {"certificate": certificate.certificate, "csr": certificate.csr} + ) secret.set_info( - expire=self._get_next_secret_expiry_time(certificate.certificate), + expire=self._get_next_secret_expiry_time(certificate), ) except SecretNotFoundError: + logger.debug( + "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) secret = self.charm.unit.add_secret( - {"certificate": certificate.certificate}, - label=f"{LIBID}-{certificate.csr}", - expire=self._get_next_secret_expiry_time(certificate.certificate), + {"certificate": certificate.certificate, "csr": certificate.csr}, + label=f"{LIBID}-{csr_in_sha256_hex}", + expire=self._get_next_secret_expiry_time(certificate), ) self.on.certificate_available.emit( certificate_signing_request=certificate.csr, @@ -1805,7 +1938,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: chain=certificate.chain, ) - def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: """Return the expiry time or expiry notification time. Extracts the expiry time from the provided certificate, calculates the @@ -1813,17 +1946,18 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: the future. Args: - certificate: x509 certificate + certificate: ProviderCertificate object Returns: Optional[datetime]: None if the certificate expiry time cannot be read, next expiry time otherwise. """ - expiry_time = _get_certificate_expiry_time(certificate) - if not expiry_time: + if not certificate.expiry_time or not certificate.expiry_notification_time: return None - expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) - return _get_closest_future_time(expiry_notification_time, expiry_time) + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: """Handle Relation Broken Event. @@ -1855,29 +1989,37 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: Args: event (SecretExpiredEvent): Juju event """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): + csr = self._get_csr_from_secret(event.secret) + if not csr: + logger.error("Failed to get CSR from secret %s", event.secret.label) return - csr = event.secret.label[len(f"{LIBID}-") :] provider_certificate = self._find_certificate_in_relation_data(csr) if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up + logger.warning( + "Failed to find matching certificate for csr, cleaning up secret %s", + event.secret.label, + ) event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(provider_certificate.certificate) - if not expiry_time: + if not provider_certificate.expiry_time: # A secret expired but matching certificate is invalid. Cleaning up + logger.warning( + "Certificate matching csr is invalid, cleaning up secret %s", + event.secret.label, + ) event.secret.remove_all_revisions() return - if datetime.now(timezone.utc) < expiry_time: + if datetime.now(timezone.utc) < provider_certificate.expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( certificate=provider_certificate.certificate, - expiry=expiry_time.isoformat(), + expiry=provider_certificate.expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(provider_certificate.certificate), + expire=provider_certificate.expiry_time, ) else: logger.warning("Certificate is expired") @@ -1898,3 +2040,22 @@ def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCerti continue return provider_certificate return None + + def _get_csr_from_secret(self, secret: Secret) -> Union[str, None]: + """Extract the CSR from the secret label or content. + + This function is a workaround to maintain backwards compatibility + and fix the issue reported in + https://github.com/canonical/tls-certificates-interface/issues/228 + """ + try: + content = secret.get_content(refresh=True) + except SecretNotFoundError: + return None + if not (csr := content.get("csr", None)): + # In versions <14 of the Lib we were storing the CSR in the label of the secret + # The CSR now is stored int the content of the secret, which was a breaking change + # Here we get the CSR if the secret was created by an app using libpatch 14 or lower + if secret.label and secret.label.startswith(f"{LIBID}-"): + csr = secret.label[len(f"{LIBID}-") :] + return csr