Skip to content

Commit

Permalink
Merge pull request #770 from nationalarchives/chore/fix-mypy-in-tests
Browse files Browse the repository at this point in the history
Make tests pass mypy type checking
  • Loading branch information
jacksonj04 authored Nov 13, 2024
2 parents 377d867 + f144447 commit d3db2c9
Show file tree
Hide file tree
Showing 26 changed files with 212 additions and 187 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ repos:
- types-python-dateutil
- types-pytz
- ds-caselaw-utils~=2.0.0
files: ^tests/
exclude: ^smoketest/
id: mypy
name: mypy-tests
repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
5 changes: 3 additions & 2 deletions src/caselawclient/models/judgments.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def linked_document(self) -> Optional["PressSummary"]:
"""
try:
uri = self.uri + "/press-summary/1"
PressSummary = importlib.import_module("caselawclient.models.press_summaries").PressSummary
return PressSummary(uri, self.api_client) # type: ignore
if not TYPE_CHECKING: # This isn't nice, but will be cleaned up when we refactor how related documents work
PressSummary = importlib.import_module("caselawclient.models.press_summaries").PressSummary
return PressSummary(uri, self.api_client)
except DocumentNotFoundError:
return None
5 changes: 3 additions & 2 deletions src/caselawclient/models/press_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def linked_document(self) -> Optional[Judgment]:
"""
try:
uri = self.uri.removesuffix("/press-summary/1")
Judgment = importlib.import_module("caselawclient.models.judgments").Judgment
return Judgment(uri, self.api_client) # type: ignore
if not TYPE_CHECKING: # This isn't nice, but will be cleaned up when we refactor how related documents work
Judgment = importlib.import_module("caselawclient.models.judgments").Judgment
return Judgment(uri, self.api_client)
except DocumentNotFoundError:
return None
28 changes: 14 additions & 14 deletions tests/client/test_advanced_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def test_invoke_called_with_default_params_when_optional_parameters_not_provided
parameters and return the response
"""

with patch.object(self.client, "invoke"):
with patch.object(self.client, "invoke") as patched_invoke:
response = self.client.advanced_search(SearchParameters())

self.client.invoke.assert_called_with(
patched_invoke.assert_called_with(
"/judgments/search/search-v2.xqy",
json.dumps(
{
Expand All @@ -50,7 +50,7 @@ def test_invoke_called_with_default_params_when_optional_parameters_not_provided
),
)

assert response == self.client.invoke.return_value
assert response == patched_invoke.return_value

def test_invoke_called_with_all_params_when_all_parameters_provided(self):
"""
Expand All @@ -60,7 +60,7 @@ def test_invoke_called_with_all_params_when_all_parameters_provided(self):
Then it should call the MarkLogic module with all the parameters
and return the response
"""
with patch.object(self.client, "invoke"):
with patch.object(self.client, "invoke") as patched_invoke:
response = self.client.advanced_search(
SearchParameters(
query="test query",
Expand All @@ -80,7 +80,7 @@ def test_invoke_called_with_all_params_when_all_parameters_provided(self):
),
)

self.client.invoke.assert_called_with(
patched_invoke.assert_called_with(
"/judgments/search/search-v2.xqy",
json.dumps(
{
Expand All @@ -103,7 +103,7 @@ def test_invoke_called_with_all_params_when_all_parameters_provided(self):
),
)

assert response == self.client.invoke.return_value
assert response == patched_invoke.return_value

def test_exception_raised_when_invoke_raises_an_exception(self):
"""
Expand All @@ -113,8 +113,8 @@ def test_exception_raised_when_invoke_raises_an_exception(self):
Then it should raise that same exception
"""
exception = Exception("Error message from MarkLogic")
with patch.object(self.client, "invoke"):
self.client.invoke.side_effect = exception
with patch.object(self.client, "invoke") as patched_invoke:
patched_invoke.side_effect = exception
with pytest.raises(Exception) as e:
self.client.advanced_search(SearchParameters(query="test query"))
assert e.value == exception
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_user_can_view_unpublished_and_show_unpublished_is_true(
When the advanced_search method is called with the show_unpublished parameter set to True
Then it should call the MarkLogic module with the expected query parameters
"""
with patch.object(self.client, "invoke"), patch.object(
with patch.object(self.client, "invoke") as patched_invoke, patch.object(
self.client,
"user_can_view_unpublished_judgments",
return_value=True,
Expand All @@ -193,7 +193,7 @@ def test_user_can_view_unpublished_and_show_unpublished_is_true(
show_unpublished=True,
),
)
assert '"show_unpublished": "true"' in self.client.invoke.call_args.args[1]
assert '"show_unpublished": "true"' in patched_invoke.call_args.args[1]

def test_user_cannot_view_unpublished_but_show_unpublished_is_true(
self,
Expand All @@ -204,7 +204,7 @@ def test_user_cannot_view_unpublished_but_show_unpublished_is_true(
When the advanced_search method is called with the show_unpublished parameter set to True
Then it should call the MarkLogic module with the show_unpublished parameter set to False and log a warning
"""
with patch.object(self.client, "invoke"), patch.object(
with patch.object(self.client, "invoke") as patched_invoke, patch.object(
self.client,
"user_can_view_unpublished_judgments",
return_value=False,
Expand All @@ -221,7 +221,7 @@ def test_user_cannot_view_unpublished_but_show_unpublished_is_true(
),
)

assert '"show_unpublished": "false"' in self.client.invoke.call_args.args[1]
assert '"show_unpublished": "false"' in patched_invoke.call_args.args[1]
mock_logging.assert_called()

def test_no_page_0(self):
Expand All @@ -231,11 +231,11 @@ def test_no_page_0(self):
When the advanced_search method is called with the page parameter set to 0
Then it should call the MarkLogic module with the page parameter set to 1
"""
with patch.object(self.client, "invoke"):
with patch.object(self.client, "invoke") as patched_invoke:
self.client.advanced_search(
SearchParameters(
page=0,
),
)

assert ', "page": 1,' in self.client.invoke.call_args.args[1]
assert ', "page": 1,' in patched_invoke.call_args.args[1]
17 changes: 9 additions & 8 deletions tests/client/test_checkout_checkin_judgment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest.mock import patch

from caselawclient.Client import ROOT_DIR, MarklogicApiClient
from caselawclient.models.documents import DocumentURIString


class TestGetCheckoutStatus(unittest.TestCase):
Expand All @@ -13,7 +14,7 @@ def setUp(self):

def test_checkout_judgment(self):
with patch.object(self.client, "eval") as mock_eval:
uri = "/ewca/civ/2004/632"
uri = DocumentURIString("/ewca/civ/2004/632")
annotation = "locked by A KITTEN"
expected_vars = {
"uri": "/ewca/civ/2004/632.xml",
Expand All @@ -34,7 +35,7 @@ def test_checkout_judgment_with_midnight_timeout(self):
"calculate_seconds_until_midnight",
return_value=3600,
):
uri = "/ewca/civ/2004/632"
uri = DocumentURIString("/ewca/civ/2004/632")
annotation = "locked by A KITTEN"
expires_at_midnight = True
expected_vars = {
Expand All @@ -49,7 +50,7 @@ def test_checkout_judgment_with_midnight_timeout(self):

def test_checkout_judgment_with_timeout_seconds(self):
with patch.object(self.client, "eval") as mock_eval:
uri = "/ewca/civ/2004/632"
uri = DocumentURIString("/ewca/civ/2004/632")
annotation = "locked by A KITTEN"
timeout_seconds = 1234
expected_vars = {
Expand All @@ -64,7 +65,7 @@ def test_checkout_judgment_with_timeout_seconds(self):

def test_checkin_judgment(self):
with patch.object(self.client, "eval") as mock_eval:
uri = "/ewca/civ/2004/632"
uri = DocumentURIString("/ewca/civ/2004/632")
expected_vars = {"uri": "/ewca/civ/2004/632.xml"}
self.client.checkin_judgment(uri)

Expand All @@ -73,7 +74,7 @@ def test_checkin_judgment(self):

def test_get_checkout_status(self):
with patch.object(self.client, "eval") as mock_eval:
uri = "judgment/uri"
uri = DocumentURIString("judgment/uri")
expected_vars = {"uri": "/judgment/uri.xml"}
self.client.get_judgment_checkout_status(uri)

Expand All @@ -100,7 +101,7 @@ def test_get_checkout_status_message(self):
b"</dls:checkout>\r\n"
b"--595658fa1db1aa98--\r\n"
)
result = self.client.get_judgment_checkout_status_message("/ewca/2002/2")
result = self.client.get_judgment_checkout_status_message(DocumentURIString("/ewca/2002/2"))
assert result == "locked by a kitten"

def test_get_checkout_status_message_empty(self):
Expand All @@ -117,7 +118,7 @@ def test_get_checkout_status_message_empty(self):
b"\r\n"
b"--595658fa1db1aa98--\r\n"
)
result = self.client.get_judgment_checkout_status_message("/ewca/2002/2")
result = self.client.get_judgment_checkout_status_message(DocumentURIString("/ewca/2002/2"))
assert result is None

def test_calculate_seconds_until_midnight(self):
Expand All @@ -133,7 +134,7 @@ def test_break_judgment_checkout(self):
client = MarklogicApiClient("", "", "", False)

with patch.object(client, "eval") as mock_eval:
uri = "judgment/uri"
uri = DocumentURIString("judgment/uri")
expected_vars = {
"uri": "/judgment/uri.xml",
}
Expand Down
65 changes: 32 additions & 33 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_single_string_from_marklogic_response,
)
from caselawclient.errors import GatewayTimeoutError
from caselawclient.models.documents import DocumentURIString


class TestErrors(unittest.TestCase):
Expand Down Expand Up @@ -100,12 +101,12 @@ def test_eval_and_decode(self, mock_eval):
b"true\r\n"
b"--595658fa1db1aa98--\r\n"
)
assert self.client._eval_and_decode({"url": "/2029/eat/1"}, "myfile.xqy") == "true"
assert self.client._eval_and_decode({}, "myfile.xqy") == "true"

@patch("caselawclient.Client.MarklogicApiClient._eval_and_decode")
def test_document_exists(self, mock_decode):
mock_decode.return_value = "true"
assert self.client.document_exists("/2029/eat/1") is True
assert self.client.document_exists(DocumentURIString("/2029/eat/1")) is True
mock_decode.assert_called_with(
{"uri": "/2029/eat/1.xml"},
"document_exists.xqy",
Expand All @@ -114,7 +115,7 @@ def test_document_exists(self, mock_decode):
@patch("caselawclient.Client.MarklogicApiClient._eval_and_decode")
def test_document_not_exists(self, mock_decode):
mock_decode.return_value = "false"
assert self.client.document_exists("/2029/eat/1") is False
assert self.client.document_exists(DocumentURIString("/2029/eat/1")) is False
mock_decode.assert_called_with(
{"uri": "/2029/eat/1.xml"},
"document_exists.xqy",
Expand All @@ -125,53 +126,51 @@ def test_eval_calls_request(self, MockPath):
mock_path_instance = MockPath.return_value
mock_path_instance.read_text.return_value = "mock-query"

self.client.session.request = MagicMock()

self.client.eval("mock-query-path.xqy", vars='{{"testvar":"test"}}')

self.client.session.request.assert_called_with(
"POST",
url=self.client._path_to_request_url("LATEST/eval"),
headers={
"Content-type": "application/x-www-form-urlencoded",
"Accept": "multipart/mixed",
},
data={"xquery": "mock-query", "vars": '{{"testvar":"test"}}'},
)
with patch.object(self.client.session, "request") as patched_request:
self.client.eval("mock-query-path.xqy", vars='{{"testvar":"test"}}')

patched_request.assert_called_with(
"POST",
url=self.client._path_to_request_url("LATEST/eval"),
headers={
"Content-type": "application/x-www-form-urlencoded",
"Accept": "multipart/mixed",
},
data={"xquery": "mock-query", "vars": '{{"testvar":"test"}}'},
)

@patch("caselawclient.Client.Path")
def test_invoke_calls_request(self, MockPath):
mock_path_instance = MockPath.return_value
mock_path_instance.read_text.return_value = "mock-query"

self.client.session.request = MagicMock()

self.client.invoke("mock-query-path.xqy", vars='{{"testvar":"test"}}')

self.client.session.request.assert_called_with(
"POST",
url=self.client._path_to_request_url("LATEST/invoke"),
headers={
"Content-type": "application/x-www-form-urlencoded",
"Accept": "multipart/mixed",
},
data={"module": "mock-query-path.xqy", "vars": '{{"testvar":"test"}}'},
)
with patch.object(self.client.session, "request") as patched_request:
self.client.invoke("mock-query-path.xqy", vars='{{"testvar":"test"}}')

patched_request.assert_called_with(
"POST",
url=self.client._path_to_request_url("LATEST/invoke"),
headers={
"Content-type": "application/x-www-form-urlencoded",
"Accept": "multipart/mixed",
},
data={"module": "mock-query-path.xqy", "vars": '{{"testvar":"test"}}'},
)

def test_format_uri(self):
uri = "/ewca/2022/123"
uri = DocumentURIString("/ewca/2022/123")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_format_uri_no_leading_slash(self):
uri = "ewca/2022/123"
uri = DocumentURIString("ewca/2022/123")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_format_uri_trailing_slash(self):
uri = "ewca/2022/123/"
uri = DocumentURIString("ewca/2022/123/")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_format_uri_all_the_slashes(self):
uri = "/ewca/2022/123/"
uri = DocumentURIString("/ewca/2022/123/")
assert self.client._format_uri_for_marklogic(uri) == "/ewca/2022/123.xml"

def test_user_agent(self):
Expand Down
19 changes: 10 additions & 9 deletions tests/client/test_eval_xslt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import unittest
from unittest.mock import patch

from caselawclient.Client import ROOT_DIR, MarklogicApiClient
from caselawclient.Client import ROOT_DIR, MarklogicApiClient, MarkLogicDocumentURIString
from caselawclient.models.documents import DocumentURIString
from caselawclient.xquery_type_dicts import XsltTransformDict


Expand All @@ -19,9 +20,9 @@ def test_eval_xslt_user_can_view_unpublished(self):
"user_can_view_unpublished_judgments",
return_value=True,
):
uri = "/judgment/uri"
uri = DocumentURIString("/judgment/uri")
expected_vars: XsltTransformDict = {
"uri": "/judgment/uri.xml",
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
"show_unpublished": True,
"img_location": "imagepath",
Expand All @@ -42,9 +43,9 @@ def test_eval_xslt_user_cannot_view_unpublished(self):
"user_can_view_unpublished_judgments",
return_value=False,
), patch.object(logging, "warning") as mock_logging:
uri = "/judgment/uri"
uri = DocumentURIString("/judgment/uri")
expected_vars: XsltTransformDict = {
"uri": "/judgment/uri.xml",
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
"show_unpublished": False,
"img_location": "imagepath",
Expand All @@ -66,9 +67,9 @@ def test_eval_xslt_with_filename(self):
"user_can_view_unpublished_judgments",
return_value=True,
):
uri = "/judgment/uri"
uri = DocumentURIString("/judgment/uri")
expected_vars: XsltTransformDict = {
"uri": "/judgment/uri.xml",
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
"show_unpublished": True,
"img_location": "imagepath",
Expand All @@ -91,10 +92,10 @@ def test_eval_xslt_with_query(self):
"user_can_view_unpublished_judgments",
return_value=True,
):
uri = "/judgment/uri"
uri = DocumentURIString("/judgment/uri")
query = "the query string"
expected_vars: XsltTransformDict = {
"uri": "/judgment/uri.xml",
"uri": MarkLogicDocumentURIString("/judgment/uri.xml"),
"version_uri": None,
"show_unpublished": True,
"img_location": "imagepath",
Expand Down
Loading

0 comments on commit d3db2c9

Please sign in to comment.