Skip to content

Commit

Permalink
feat: Validate strings when creating a new DocumentURIString
Browse files Browse the repository at this point in the history
This improves the robustness of `DocumentURIString`` handling by ensuring that they conform to the expected rules (no trailing or leading slashes, no dots). Since initialising a document now requires an instance of this class, and we can guarantee the string contained within meets our requirements, we can safely lose string sanitisation in the `Document` class.

As part of this, `Judgment` and `PressSummary` have had their signatures updated to reinforce the expectation of `DocumentURIString` being provided as the `uri` argument. This led to identifying unsafe behaviour around the handling of press summary URIs and version URIs, which have now been made more explicit.

BREAKING CHANGE: Code which provided unsanitised URIs when initialising `DocumentURIStrings` will now cause `InvalidDocumentURIException`s to be raised.
  • Loading branch information
jacksonj04 committed Nov 13, 2024
1 parent 6bfc155 commit e65e3ae
Show file tree
Hide file tree
Showing 23 changed files with 140 additions and 108 deletions.
6 changes: 4 additions & 2 deletions src/caselawclient/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,14 @@ def get_press_summaries_for_document_uri(
Returns a list of PressSummary objects associated with a given Document URI
"""
vars: query_dicts.GetComponentsForDocumentDict = {
"parent_uri": DocumentURIString(uri if uri.startswith("/") else "/" + uri),
"parent_uri": uri,
"component": "pressSummary",
}
response = self._send_to_eval(vars, "get_components_for_document.xqy")
uris = get_multipart_strings_from_marklogic_response(response)
return [PressSummary(uri.strip(".xml"), self) for uri in uris]
return [
PressSummary(DocumentURIString(uri.strip("/").strip(".xml")), self) for uri in uris
] # TODO: Migrate this strip behaviour into proper manipulation of a MarkLogicURIString

def get_document_by_uri(
self,
Expand Down
25 changes: 22 additions & 3 deletions src/caselawclient/models/documents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, NewType, Optional
from typing import TYPE_CHECKING, Any, Optional

from ds_caselaw_utils import courts
from ds_caselaw_utils.courts import CourtNotFoundException
Expand Down Expand Up @@ -30,7 +30,7 @@
)

from .body import DocumentBody
from .exceptions import CannotPublishUnpublishableDocument, DocumentNotSafeForDeletion
from .exceptions import CannotPublishUnpublishableDocument, DocumentNotSafeForDeletion, InvalidDocumentURIException
from .statuses import DOCUMENT_STATUS_HOLD, DOCUMENT_STATUS_IN_PROGRESS, DOCUMENT_STATUS_NEW, DOCUMENT_STATUS_PUBLISHED

MINIMUM_ENRICHMENT_TIME = datetime.timedelta(minutes=20)
Expand All @@ -47,7 +47,26 @@ class GatewayTimeoutGettingHTMLWithQuery(RuntimeWarning):
from caselawclient.Client import MarklogicApiClient


DocumentURIString = NewType("DocumentURIString", str)
class DocumentURIString(str):
"""
This class checks that the string is actually a valid Document URI on creation. It does _not_ manipulate the string.
"""

def __new__(cls, content: str) -> "DocumentURIString":
# Check that the URI doesn't begin or end with a slash
if content[0] == "/" or content[-1] == "/":
raise InvalidDocumentURIException(
f'"{content}" is not a valid document URI; URIs cannot begin or end with slashes.'
)

# Check that the URI doesn't contain a full stop
if "." in content:
raise InvalidDocumentURIException(
f'"{content}" is not a valid document URI; URIs cannot contain full stops.'
)

# If everything is good, return as usual
return str.__new__(cls, content)


class Document:
Expand Down
4 changes: 4 additions & 0 deletions src/caselawclient/models/documents/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ class CannotPublishUnpublishableDocument(Exception):

class DocumentNotSafeForDeletion(Exception):
"""A document which is not safe for deletion cannot be deleted."""


class InvalidDocumentURIException(Exception):
"""The document URI is not valid."""
8 changes: 4 additions & 4 deletions src/caselawclient/models/judgments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
if TYPE_CHECKING:
from caselawclient.models.press_summaries import PressSummary

from .documents import Document
from .documents import Document, DocumentURIString


class Judgment(NeutralCitationMixin, Document):
Expand All @@ -21,8 +21,8 @@ class Judgment(NeutralCitationMixin, Document):
document_noun = "judgment"
document_noun_plural = "judgments"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(self.document_noun, *args, **kwargs)
def __init__(self, uri: DocumentURIString, *args: Any, **kwargs: Any) -> None:
super().__init__(self.document_noun, uri, *args, **kwargs)

@cached_property
def neutral_citation(self) -> NeutralCitationString:
Expand All @@ -46,7 +46,7 @@ def linked_document(self) -> Optional["PressSummary"]:
Attempt to fetch a linked press summary, and return it, if it exists
"""
try:
uri = self.uri + "/press-summary/1"
uri = DocumentURIString(self.uri + "/press-summary/1")
if not TYPE_CHECKING: # This isn't nice, but will be cleaned up when we refactor how related documents work
PressSummary = importlib.import_module("caselawclient.models.press_summaries").PressSummary
return PressSummary(uri, self.api_client)
Expand Down
8 changes: 4 additions & 4 deletions src/caselawclient/models/press_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from caselawclient.errors import DocumentNotFoundError
from caselawclient.models.neutral_citation_mixin import NeutralCitationMixin

from .documents import Document
from .documents import Document, DocumentURIString

if TYPE_CHECKING:
from caselawclient.models.judgments import Judgment
Expand All @@ -23,8 +23,8 @@ class PressSummary(NeutralCitationMixin, Document):
document_noun = "press summary"
document_noun_plural = "press summaries"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(self.document_noun, *args, **kwargs)
def __init__(self, uri: DocumentURIString, *args: Any, **kwargs: Any) -> None:
super().__init__(self.document_noun, uri, *args, **kwargs)

@cached_property
def neutral_citation(self) -> NeutralCitationString:
Expand All @@ -47,7 +47,7 @@ def linked_document(self) -> Optional[Judgment]:
Attempt to fetch a linked judgement, and return it, if it exists
"""
try:
uri = self.uri.removesuffix("/press-summary/1")
uri = DocumentURIString(self.uri.removesuffix("/press-summary/1"))
if not TYPE_CHECKING: # This isn't nice, but will be cleaned up when we refactor how related documents work
Judgment = importlib.import_module("caselawclient.models.judgments").Judgment
return Judgment(uri, self.api_client)
Expand Down
4 changes: 2 additions & 2 deletions src/caselawclient/models/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@


class VersionsDict(TypedDict):
uri: str
uri: str ## TODO: This should be either a MarkLogicDocumentURIString (raw from ML) or a DocumentURIString (and we parse it out). Just a str is too vague.
version: int


def render_versions(decoded_versions: list[BodyPart]) -> list[VersionsDict]:
versions: list[VersionsDict] = [
{
"uri": part.text.rstrip(".xml"),
"uri": part.text.strip("/").rstrip(".xml"),
"version": extract_version(part.text),
}
for part in decoded_versions
Expand Down
2 changes: 1 addition & 1 deletion src/caselawclient/xquery/get_components_for_document.xqy
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ let $docTypeQuery := cts:element-attribute-value-query(
)
let $refQuery := cts:element-query(
xs:QName("uk:summaryOf"),
concat("https://caselaw.nationalarchives.gov.uk/id", $parent_uri)
concat("https://caselaw.nationalarchives.gov.uk/id/", $parent_uri)
)

return xdmp:node-uri(cts:search(//akn:akomaNtoso, cts:and-query(($refQuery, $collectionQuery, $docTypeQuery))))
12 changes: 6 additions & 6 deletions tests/client/test_checkout_checkin_judgment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setUp(self):

def test_checkout_judgment(self):
with patch.object(self.client, "eval") as mock_eval:
uri = DocumentURIString("/ewca/civ/2004/632")
uri = DocumentURIString("ewca/civ/2004/632")
annotation = "locked by A KITTEN"
expected_vars = {
"uri": "/ewca/civ/2004/632.xml",
Expand All @@ -35,7 +35,7 @@ def test_checkout_judgment_with_midnight_timeout(self):
"calculate_seconds_until_midnight",
return_value=3600,
):
uri = DocumentURIString("/ewca/civ/2004/632")
uri = DocumentURIString("ewca/civ/2004/632")
annotation = "locked by A KITTEN"
expires_at_midnight = True
expected_vars = {
Expand All @@ -50,7 +50,7 @@ def test_checkout_judgment_with_midnight_timeout(self):

def test_checkout_judgment_with_timeout_seconds(self):
with patch.object(self.client, "eval") as mock_eval:
uri = DocumentURIString("/ewca/civ/2004/632")
uri = DocumentURIString("ewca/civ/2004/632")
annotation = "locked by A KITTEN"
timeout_seconds = 1234
expected_vars = {
Expand All @@ -65,7 +65,7 @@ def test_checkout_judgment_with_timeout_seconds(self):

def test_checkin_judgment(self):
with patch.object(self.client, "eval") as mock_eval:
uri = DocumentURIString("/ewca/civ/2004/632")
uri = DocumentURIString("ewca/civ/2004/632")
expected_vars = {"uri": "/ewca/civ/2004/632.xml"}
self.client.checkin_judgment(uri)

Expand Down Expand Up @@ -101,7 +101,7 @@ def test_get_checkout_status_message(self):
b"</dls:checkout>\r\n"
b"--595658fa1db1aa98--\r\n"
)
result = self.client.get_judgment_checkout_status_message(DocumentURIString("/ewca/2002/2"))
result = self.client.get_judgment_checkout_status_message(DocumentURIString("ewca/2002/2"))
assert result == "locked by a kitten"

def test_get_checkout_status_message_empty(self):
Expand All @@ -118,7 +118,7 @@ def test_get_checkout_status_message_empty(self):
b"\r\n"
b"--595658fa1db1aa98--\r\n"
)
result = self.client.get_judgment_checkout_status_message(DocumentURIString("/ewca/2002/2"))
result = self.client.get_judgment_checkout_status_message(DocumentURIString("ewca/2002/2"))
assert result is None

def test_calculate_seconds_until_midnight(self):
Expand Down
16 changes: 2 additions & 14 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_eval_and_decode(self, mock_eval):
@patch("caselawclient.Client.MarklogicApiClient._eval_and_decode")
def test_document_exists(self, mock_decode):
mock_decode.return_value = "true"
assert self.client.document_exists(DocumentURIString("/2029/eat/1")) is True
assert self.client.document_exists(DocumentURIString("2029/eat/1")) is True
mock_decode.assert_called_with(
{"uri": "/2029/eat/1.xml"},
"document_exists.xqy",
Expand All @@ -115,7 +115,7 @@ def test_document_exists(self, mock_decode):
@patch("caselawclient.Client.MarklogicApiClient._eval_and_decode")
def test_document_not_exists(self, mock_decode):
mock_decode.return_value = "false"
assert self.client.document_exists(DocumentURIString("/2029/eat/1")) is False
assert self.client.document_exists(DocumentURIString("2029/eat/1")) is False
mock_decode.assert_called_with(
{"uri": "/2029/eat/1.xml"},
"document_exists.xqy",
Expand Down Expand Up @@ -158,21 +158,9 @@ def test_invoke_calls_request(self, MockPath):
)

def test_format_uri(self):
uri = DocumentURIString("/ewca/2022/123")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_format_uri_no_leading_slash(self):
uri = DocumentURIString("ewca/2022/123")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_format_uri_trailing_slash(self):
uri = DocumentURIString("ewca/2022/123/")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_format_uri_all_the_slashes(self):
uri = DocumentURIString("/ewca/2022/123/")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_user_agent(self):
user_agent = self.client.session.prepare_request(
Request("GET", "http://example.invalid"),
Expand Down
8 changes: 4 additions & 4 deletions tests/client/test_eval_xslt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_eval_xslt_user_can_view_unpublished(self):
"user_can_view_unpublished_judgments",
return_value=True,
):
uri = DocumentURIString("/judgment/uri")
uri = DocumentURIString("judgment/uri")
expected_vars: XsltTransformDict = {
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
Expand All @@ -43,7 +43,7 @@ def test_eval_xslt_user_cannot_view_unpublished(self):
"user_can_view_unpublished_judgments",
return_value=False,
), patch.object(logging, "warning") as mock_logging:
uri = DocumentURIString("/judgment/uri")
uri = DocumentURIString("judgment/uri")
expected_vars: XsltTransformDict = {
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
Expand All @@ -67,7 +67,7 @@ def test_eval_xslt_with_filename(self):
"user_can_view_unpublished_judgments",
return_value=True,
):
uri = DocumentURIString("/judgment/uri")
uri = DocumentURIString("judgment/uri")
expected_vars: XsltTransformDict = {
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
Expand All @@ -92,7 +92,7 @@ def test_eval_xslt_with_query(self):
"user_can_view_unpublished_judgments",
return_value=True,
):
uri = DocumentURIString("/judgment/uri")
uri = DocumentURIString("judgment/uri")
query = "the query string"
expected_vars: XsltTransformDict = {
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
Expand Down
6 changes: 3 additions & 3 deletions tests/client/test_get_judgment_and_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_get_judgment_xml(self):
b"</akomaNtoso>"
)

result = self.client.get_judgment_xml(DocumentURIString("/judgment/uri"))
result = self.client.get_judgment_xml(DocumentURIString("judgment/uri"))

expected = (
'<?xml version="1.0" encoding="UTF-8"?>\n'
Expand All @@ -42,7 +42,7 @@ def test_get_judgment_xml(self):

def test_get_judgment_version(self):
with patch.object(self.client, "eval") as mock_eval:
uri = DocumentURIString("/ewca/civ/2004/632")
uri = DocumentURIString("ewca/civ/2004/632")
version = 3
expected_vars = {"uri": "/ewca/civ/2004/632.xml", "version": "3"}
self.client.get_judgment_version(uri, version)
Expand All @@ -52,7 +52,7 @@ def test_get_judgment_version(self):

def test_list_judgment_versions(self):
with patch.object(self.client, "eval") as mock_eval:
uri = DocumentURIString("/ewca/civ/2004/632")
uri = DocumentURIString("ewca/civ/2004/632")
expected_vars = {"uri": "/ewca/civ/2004/632.xml"}
self.client.list_judgment_versions(uri)

Expand Down
36 changes: 17 additions & 19 deletions tests/client/test_get_press_summaries_for_document_uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,23 @@ def test_get_press_summaries_for_document_uri(
mock_press_summary,
):
mock_eval.return_value = "EVAL"
mock_get_marklogic_response.return_value = ["/foo/bar/baz/1", "/foo/bar/baz/2"]
mock_get_marklogic_response.return_value = ["foo/bar/baz/1", "foo/bar/baz/2"]

for uri in ["foo/bar", "/foo/bar"]:
with self.subTest(uri=uri):
self.client.get_press_summaries_for_document_uri(DocumentURIString(uri))
self.client.get_press_summaries_for_document_uri(DocumentURIString("foo/bar"))

mock_get_marklogic_response.assert_called_with("EVAL")
mock_eval.assert_called_with(
{
"parent_uri": "/foo/bar",
"component": "pressSummary",
},
"get_components_for_document.xqy",
)
mock_get_marklogic_response.assert_called_with("EVAL")
mock_eval.assert_called_with(
{
"parent_uri": "foo/bar",
"component": "pressSummary",
},
"get_components_for_document.xqy",
)

mock_press_summary.assert_has_calls(
[
call("/foo/bar/baz/1", self.client),
call("/foo/bar/baz/2", self.client),
],
any_order=True,
)
mock_press_summary.assert_has_calls(
[
call("foo/bar/baz/1", self.client),
call("foo/bar/baz/2", self.client),
],
any_order=True,
)
2 changes: 1 addition & 1 deletion tests/client/test_get_set_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_set_judgment_date_warn(self):

def test_set_internal_uri_leading_slash(self):
with patch.object(self.client, "eval") as mock_eval:
uri = DocumentURIString("/judgment/uri")
uri = DocumentURIString("judgment/uri")
expected_vars = {
"uri": "/judgment/uri.xml",
"content_with_id": "https://caselaw.nationalarchives.gov.uk/id/judgment/uri",
Expand Down
Loading

0 comments on commit e65e3ae

Please sign in to comment.