From 912ebffdfb143cb7b7642cd24b51a8771e9adc09 Mon Sep 17 00:00:00 2001
From: reivilibre <oliverw@matrix.org>
Date: Wed, 13 Nov 2024 21:01:49 +0000
Subject: [PATCH] Fix bug where the thumbnail endpoint was not used for
 downloading thumbnails (#9)

Closes #8
---
 .../scanner/file_downloader.py                | 19 ++++++++++++++++---
 tests/scanner/test_file_downloader.py         |  8 +++++---
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/matrix_content_scanner/scanner/file_downloader.py b/src/matrix_content_scanner/scanner/file_downloader.py
index b1757f8..389030d 100644
--- a/src/matrix_content_scanner/scanner/file_downloader.py
+++ b/src/matrix_content_scanner/scanner/file_downloader.py
@@ -60,7 +60,9 @@ async def download_file(
             ContentScannerRestError: The file was not found or could not be downloaded due
                 to an error on the remote homeserver's side.
         """
-        url = await self._build_https_url(media_path)
+        url = await self._build_https_url(
+            media_path, for_thumbnail=thumbnail_params is not None
+        )
 
         # Attempt to retrieve the file at the generated URL.
         try:
@@ -71,7 +73,11 @@ async def download_file(
             # again with an r0 endpoint.
             logger.info("File not found, trying legacy r0 path")
 
-            url = await self._build_https_url(media_path, endpoint_version="r0")
+            url = await self._build_https_url(
+                media_path,
+                endpoint_version="r0",
+                for_thumbnail=thumbnail_params is not None,
+            )
 
             try:
                 file = await self._get_file_content(url, thumbnail_params)
@@ -89,6 +95,8 @@ async def _build_https_url(
         self,
         media_path: str,
         endpoint_version: str = "v3",
+        *,
+        for_thumbnail: bool,
     ) -> str:
         """Turn a `server_name/media_id` path into an https:// one we can use to fetch
         the media.
@@ -100,6 +108,9 @@ async def _build_https_url(
             media_path: The media path to translate.
             endpoint_version: The version of the download endpoint to use. As of Matrix
                 v1.1, this is either "v3" or "r0".
+            for_thumbnail: True if a server-side thumbnail is desired instead of the full
+                media. In that case, the URL for the `/thumbnail` endpoint is returned
+                instead of the `/download` endpoint.
 
         Returns:
             An https URL to use. If `base_homeserver_url` is set in the config, this
@@ -129,7 +140,9 @@ async def _build_https_url(
                 # didn't find a .well-known file.
                 base_url = "https://" + server_name
 
-        prefix = self.MEDIA_DOWNLOAD_PREFIX
+        prefix = (
+            self.MEDIA_THUMBNAIL_PREFIX if for_thumbnail else self.MEDIA_DOWNLOAD_PREFIX
+        )
 
         # Build the full URL.
         path_prefix = prefix % endpoint_version
diff --git a/tests/scanner/test_file_downloader.py b/tests/scanner/test_file_downloader.py
index 784f813..ce709d4 100644
--- a/tests/scanner/test_file_downloader.py
+++ b/tests/scanner/test_file_downloader.py
@@ -69,8 +69,8 @@ async def test_download(self) -> None:
         self.assertEqual(media.content_type, "image/png")
 
         # Check that we tried downloading from the set base URL.
-        args = self.get_mock.call_args
-        self.assertTrue(args[0][0].startswith("http://my-site.com/"))
+        args = self.get_mock.call_args.args
+        self.assertTrue(args[0].startswith("http://my-site.com/"))
 
     async def test_no_base_url(self) -> None:
         """Tests that configuring a base homeserver URL means files are downloaded from
@@ -146,7 +146,9 @@ async def test_thumbnail(self) -> None:
             MEDIA_PATH, to_thumbnail_params({"height": "50"})
         )
 
-        query: CIMultiDictProxy[str] = self.get_mock.call_args[1]["query"]
+        url: str = self.get_mock.call_args.args[0]
+        query: CIMultiDictProxy[str] = self.get_mock.call_args.kwargs["query"]
+        self.assertIn("/thumbnail/", url)
         self.assertIn("height", query)
         self.assertEqual(query.get("height"), "50", query.getall("height"))