diff --git a/digid_eherkenning/saml2/eherkenning.py b/digid_eherkenning/saml2/eherkenning.py index 57f48da..5f5f96a 100644 --- a/digid_eherkenning/saml2/eherkenning.py +++ b/digid_eherkenning/saml2/eherkenning.py @@ -1,7 +1,7 @@ import binascii from base64 import b64encode from io import BytesIO -from typing import Union +from typing import Union, no_type_check from uuid import uuid4 from django.urls import reverse @@ -9,7 +9,6 @@ from cryptography.hazmat.primitives import serialization from cryptography.x509 import load_pem_x509_certificate -from furl.furl import furl from lxml.builder import ElementMaker from lxml.etree import Element, tostring from onelogin.saml2.settings import OneLogin_Saml2_Settings @@ -465,35 +464,17 @@ def conf(self) -> EHerkenningConfig: self._conf.setdefault("acs_path", reverse("eherkenning:acs")) return self._conf + @no_type_check # my editor has more red than the red wedding in GOT def create_config_dict(self, conf: EHerkenningConfig) -> EHerkenningSAMLConfig: config_dict: EHerkenningSAMLConfig = super().create_config_dict(conf) + sp_config = config_dict["sp"] + + # we have multiple services, so delete the config for the "single service" variant attribute_consuming_services = create_attribute_consuming_services(conf) - with ( - conf["cert_file"].open("r") as cert_file, - conf["key_file"].open("r") as key_file, - ): - certificate = cert_file.read() - privkey = key_file.read() - acs_url = furl(conf["base_url"]) / conf["acs_path"] - config_dict.update( - { - "sp": { - # Identifier of the SP entity (must be a URI) - "entityId": conf["entity_id"], - # Specifies info about where and how the message MUST be - # returned to the requester, in this case our SP. - "assertionConsumerService": { - "url": acs_url.url, - "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact", - }, - "attributeConsumingServices": attribute_consuming_services, - "NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified", - "x509cert": certificate, - "privateKey": privkey, - }, - } - ) + del sp_config["attributeConsumingService"] + sp_config["attributeConsumingServices"] = attribute_consuming_services + return config_dict def create_config( diff --git a/tests/test_eherkenning_metadata.py b/tests/test_eherkenning_metadata.py index 1abfba6..c2fe4d9 100644 --- a/tests/test_eherkenning_metadata.py +++ b/tests/test_eherkenning_metadata.py @@ -2,8 +2,10 @@ import pytest from lxml import etree +from simple_certmanager.models import Certificate -from digid_eherkenning.models import EherkenningConfiguration +from digid_eherkenning.choices import ConfigTypes +from digid_eherkenning.models import ConfigCertificate, EherkenningConfiguration from digid_eherkenning.saml2.eherkenning import ( eHerkenningClient, generate_eherkenning_metadata, @@ -370,3 +372,55 @@ def test_no_eidas_service(self): ".//md:ServiceDescription", namespaces=NAME_SPACES ).text, ) + + +@pytest.mark.django_db +def test_current_and_next_certificate_in_metadata( + temp_private_root, + eherkenning_config: EherkenningConfiguration, + eherkenning_certificate: Certificate, + next_certificate: Certificate, +): + ConfigCertificate.objects.create( + config_type=ConfigTypes.eherkenning, + certificate=next_certificate, + ) + assert ConfigCertificate.objects.count() == 2 # expect current and next + + eh_metadata = generate_eherkenning_metadata() + + entity_descriptor_node = etree.XML(eh_metadata) + + metadata_node = entity_descriptor_node.find( + "md:SPSSODescriptor", namespaces=NAME_SPACES + ) + assert metadata_node is not None + key_nodes = metadata_node.findall("md:KeyDescriptor", namespaces=NAME_SPACES) + assert len(key_nodes) == 2 # we expect current + next key + key1_node, key2_node = key_nodes + assert key1_node.attrib["use"] == "signing" + assert key2_node.attrib["use"] == "signing" + + with ( + eherkenning_certificate.public_certificate.open("r") as _current, + next_certificate.public_certificate.open("r") as _next, + ): + current_base64 = _current.read().replace("\n", "") + next_base64 = _next.read().replace("\n", "") + + # certificate nodes include only the base64 encoded PEM data, without header/footer + cert1_node = key1_node.find( + "ds:KeyInfo/ds:X509Data/ds:X509Certificate", namespaces=NAME_SPACES + ) + assert cert1_node is not None + assert cert1_node.text is not None + assert (cert_data_1 := cert1_node.text.strip()) in current_base64 + + cert2_node = key2_node.find( + "ds:KeyInfo/ds:X509Data/ds:X509Certificate", namespaces=NAME_SPACES + ) + assert cert2_node is not None + assert cert2_node.text is not None + assert (cert_data_2 := cert2_node.text.strip()) in next_base64 + # they should not be the same + assert cert_data_1 != cert_data_2