diff --git a/src/caselawclient/models/documents/__init__.py b/src/caselawclient/models/documents/__init__.py index 79898916..7ef4fbf6 100644 --- a/src/caselawclient/models/documents/__init__.py +++ b/src/caselawclient/models/documents/__init__.py @@ -6,7 +6,6 @@ from ds_caselaw_utils import courts from ds_caselaw_utils.courts import CourtNotFoundException from ds_caselaw_utils.types import NeutralCitationString -from lxml import etree from lxml import html as html_parser from requests_toolbelt.multipart import decoder @@ -16,8 +15,7 @@ NotSupportedOnVersion, OnlySupportedOnVersion, ) -from caselawclient.models.identifiers import Identifier -from caselawclient.models.identifiers.unpacker import unpack_identifier_from_etree +from caselawclient.models.identifiers.unpacker import unpack_all_identifiers_from_etree from caselawclient.models.utilities import VersionsDict, extract_version, render_versions from caselawclient.models.utilities.aws import ( ParserInstructionsDict, @@ -127,8 +125,6 @@ class Document: Individual document classes should extend this list where necessary to validate document type-specific attributes. """ - _identifiers: dict[str, Identifier] - def __init__(self, uri: DocumentURIString, api_client: "MarklogicApiClient", search_query: Optional[str] = None): """ :param uri: The URI of the document to retrieve from MarkLogic. @@ -170,39 +166,8 @@ def docx_exists(self) -> bool: def _initialise_identifiers(self) -> None: """Load this document's identifiers from MarkLogic.""" - self._identifiers = {} - identifiers_element_as_etree = self.api_client.get_property_as_node(self.uri, "identifiers") - - if identifiers_element_as_etree is not None: - for identifier_etree in identifiers_element_as_etree.findall("identifier"): - identifier = unpack_identifier_from_etree(identifier_etree) - self.add_identifier(identifier) - - @property - def identifiers(self) -> list[Identifier]: - """Return a list of Identifier objects for easy display and interaction.""" - return list(self._identifiers.values()) - - def add_identifier(self, identifier: Identifier) -> None: - """Add an Identifier object to this Document's list of identifiers.""" - - self._identifiers[identifier.uuid] = identifier - - @property - def identifiers_as_etree(self) -> etree._Element: - """Return an etree representation of all the Document's identifiers.""" - identifiers_root = etree.Element("identifiers") - - for identifier in self.identifiers: - identifiers_root.append(identifier.as_xml_tree) - - return identifiers_root - - def save_identifiers(self) -> None: - """Save the current state of this Document's identifiers to MarkLogic.""" - - self.api_client.set_property_as_node(self.uri, "identifiers", self.identifiers_as_etree) + self.identifiers = unpack_all_identifiers_from_etree(identifiers_element_as_etree) @property def best_human_identifier(self) -> Optional[str]: diff --git a/src/caselawclient/models/identifiers/__init__.py b/src/caselawclient/models/identifiers/__init__.py index 67bace25..16d30212 100644 --- a/src/caselawclient/models/identifiers/__init__.py +++ b/src/caselawclient/models/identifiers/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Union from uuid import uuid4 from lxml import etree @@ -70,7 +70,7 @@ def __init_subclass__(cls: type["Identifier"], **kwargs: Any) -> None: super().__init_subclass__(**kwargs) def __repr__(self) -> str: - return f"{self.uuid} ({self.schema.name}): {self.value}" + return f"<{self.schema.name} {self.value}: {self.uuid}>" def __init__(self, value: str, uuid: Optional[str] = None) -> None: self.value = value @@ -96,3 +96,29 @@ def as_xml_tree(self) -> etree._Element: @property def url_slug(self) -> str: return self.schema.compile_identifier_url_slug(self.value) + + +class Identifiers(dict[str, Identifier]): + def add(self, identifier: Identifier) -> None: + self[identifier.uuid] = identifier + + def __delitem__(self, key: Union[Identifier, str]) -> None: + if isinstance(key, Identifier): + super().__delitem__(key.uuid) + else: + super().__delitem__(key) + + @property + def as_etree(self) -> etree._Element: + """Return an etree representation of all the Document's identifiers.""" + identifiers_root = etree.Element("identifiers") + + for identifier in self.values(): + identifiers_root.append(identifier.as_xml_tree) + + return identifiers_root + + def save(self, document) -> None: # type: ignore[no-untyped-def, unused-ignore] + """Save the current state of this Document's identifiers to MarkLogic.""" + + document.api_client.set_property_as_node(document.uri, "identifiers", self.as_etree) diff --git a/src/caselawclient/models/identifiers/unpacker.py b/src/caselawclient/models/identifiers/unpacker.py index 5f2faa7c..101b2e50 100644 --- a/src/caselawclient/models/identifiers/unpacker.py +++ b/src/caselawclient/models/identifiers/unpacker.py @@ -1,6 +1,8 @@ +from typing import Optional + from lxml import etree -from . import IDENTIFIER_UNPACKABLE_ATTRIBUTES, Identifier, InvalidIdentifierXMLRepresentationException +from . import IDENTIFIER_UNPACKABLE_ATTRIBUTES, Identifier, Identifiers, InvalidIdentifierXMLRepresentationException from .neutral_citation import NeutralCitationNumber IDENTIFIER_NAMESPACE_MAP: dict[str, type[Identifier]] = { @@ -8,8 +10,19 @@ } -def unpack_identifier_from_etree(identifier_xml: etree._Element) -> Identifier: - """Given an etree representation of an identifier, unpack it into an appropriate instance of an Identifier.""" +def unpack_all_identifiers_from_etree(identifiers_etree: Optional[etree._Element]) -> Identifiers: + """This expects the entire tag, and unpacks all Identifiers inside it""" + identifiers = Identifiers() + if identifiers_etree is None: + return identifiers + for identifier_etree in identifiers_etree.findall("identifier"): + identifier = unpack_an_identifier_from_etree(identifier_etree) + identifiers.add(identifier) + return identifiers + + +def unpack_an_identifier_from_etree(identifier_xml: etree._Element) -> Identifier: + """Given an etree representation of a single identifier, unpack it into an appropriate instance of an Identifier.""" namespace_element = identifier_xml.find("namespace") diff --git a/tests/models/documents/test_document_identifiers.py b/tests/models/documents/test_document_identifiers.py index 379f3c41..8c90f7c5 100644 --- a/tests/models/documents/test_document_identifiers.py +++ b/tests/models/documents/test_document_identifiers.py @@ -8,23 +8,23 @@ class TestDocumentIdentifiers: def test_add_identifiers(self): document = DocumentFactory.build() - identifier_1 = TestIdentifier(uuid="e28e3ef1-85ed-4997-87ee-e7428a6cc02e", value="TEST-123") - identifier_2 = TestIdentifier(uuid="14ce4b3b-03c8-44f9-a29e-e02ce35fe136", value="TEST-456") - document.add_identifier(identifier_1) - document.add_identifier(identifier_2) + identifier_1 = TestIdentifier(uuid="id-1", value="TEST-123") + identifier_2 = TestIdentifier(uuid="id-2", value="TEST-456") + document.identifiers.add(identifier_1) + document.identifiers.add(identifier_2) - assert document.identifiers == [ - identifier_1, - identifier_2, - ] + assert document.identifiers == { + "id-1": identifier_1, + "id-2": identifier_2, + } def test_identifiers_as_etree(self): document = DocumentFactory.build() identifier_1 = TestIdentifier(uuid="e28e3ef1-85ed-4997-87ee-e7428a6cc02e", value="TEST-123") identifier_2 = TestIdentifier(uuid="14ce4b3b-03c8-44f9-a29e-e02ce35fe136", value="TEST-456") - document.add_identifier(identifier_1) - document.add_identifier(identifier_2) + document.identifiers.add(identifier_1) + document.identifiers.add(identifier_2) expected_xml = """ @@ -43,6 +43,6 @@ def test_identifiers_as_etree(self): """ - assert etree.canonicalize(document.identifiers_as_etree, strip_text=True) == etree.canonicalize( + assert etree.canonicalize(document.identifiers.as_etree, strip_text=True) == etree.canonicalize( etree.fromstring(expected_xml), strip_text=True ) diff --git a/tests/models/identifiers/test_identifer_unpacking.py b/tests/models/identifiers/test_identifer_unpacking.py index 10837445..0911d37e 100644 --- a/tests/models/identifiers/test_identifer_unpacking.py +++ b/tests/models/identifiers/test_identifer_unpacking.py @@ -3,7 +3,8 @@ from lxml import etree from test_identifiers import TestIdentifier -from caselawclient.models.identifiers.unpacker import unpack_identifier_from_etree +from caselawclient.models.identifiers import Identifiers +from caselawclient.models.identifiers.unpacker import unpack_all_identifiers_from_etree, unpack_an_identifier_from_etree class TestIdentifierUnpacking: @@ -17,7 +18,7 @@ def test_unpack_identifier(self): """) - unpacked_identifier = unpack_identifier_from_etree(xml_tree) + unpacked_identifier = unpack_an_identifier_from_etree(xml_tree) assert type(unpacked_identifier) is TestIdentifier assert unpacked_identifier.uuid == "2d80bf1d-e3ea-452f-965c-041f4399f2dd" @@ -26,15 +27,31 @@ def test_unpack_identifier(self): class TestIdentifierPackUnpackRoundTrip: @patch("caselawclient.models.identifiers.unpacker.IDENTIFIER_NAMESPACE_MAP", {"test": TestIdentifier}) - def test_unpack_identifier(self): + def test_roundtrip_identifier(self): """Check that if we convert an Identifier to XML and back again we get the same thing out at the far end.""" original_identifier = TestIdentifier(value="TEST-123") xml_tree = original_identifier.as_xml_tree - unpacked_identifier = unpack_identifier_from_etree(xml_tree) + unpacked_identifier = unpack_an_identifier_from_etree(xml_tree) assert type(unpacked_identifier) is TestIdentifier assert unpacked_identifier.uuid == original_identifier.uuid assert unpacked_identifier.value == "TEST-123" + + @patch("caselawclient.models.identifiers.unpacker.IDENTIFIER_NAMESPACE_MAP", {"test": TestIdentifier}) + def test_roundtrip_identifiers(self): + """Check that if we convert an Identifier to XML and back again we get the same thing out at the far end.""" + uuid = "id-1" + original_identifiers = Identifiers() + original_identifiers.add(TestIdentifier(uuid=uuid, value="TEST-123")) + + xml_tree = original_identifiers.as_etree + + unpacked_identifiers = unpack_all_identifiers_from_etree(xml_tree) + unpacked_identifier = unpacked_identifiers[uuid] + + assert type(unpacked_identifier) is TestIdentifier + assert unpacked_identifier.uuid == uuid + assert unpacked_identifier.value == "TEST-123" diff --git a/tests/models/identifiers/test_identifiers.py b/tests/models/identifiers/test_identifiers.py index 0089b311..cbc8d3c2 100644 --- a/tests/models/identifiers/test_identifiers.py +++ b/tests/models/identifiers/test_identifiers.py @@ -1,6 +1,19 @@ +import pytest from lxml import etree -from caselawclient.models.identifiers import Identifier, IdentifierSchema +from caselawclient.models.identifiers import Identifier, Identifiers, IdentifierSchema + + +@pytest.fixture +def identifiers(): + return Identifiers( + {"id-1": Identifier(uuid="id-1", value="TEST-111"), "id-2": Identifier(uuid="id-2", value="TEST-222")} + ) + + +@pytest.fixture +def id3(): + return Identifier(uuid="id-3", value="TEST-333") class TestIdentifierSchema(IdentifierSchema): @@ -41,3 +54,21 @@ def test_xml_representation(self): assert etree.canonicalize(identifier.as_xml_tree, strip_text=True) == etree.canonicalize( etree.fromstring(expected_xml), strip_text=True ) + + +class TestIdentifiersCRUD: + def test_delete(self, identifiers): + del identifiers["id-1"] + assert len(identifiers) == 1 + assert "id-2" in identifiers + + def test_delete_identifier(self, identifiers): + id1 = identifiers["id-1"] + del identifiers[id1] + assert len(identifiers) == 1 + assert "id-2" in identifiers + + def test_add_identifier(self, identifiers, id3): + identifiers.add(id3) + assert identifiers["id-3"] == id3 + assert len(identifiers) == 3