From 62013610c78c93d2b1f66868b2d9ac488deca0f2 Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Fri, 15 Nov 2024 12:17:21 -0800 Subject: [PATCH 1/5] fix phototour code --- torchvision/datasets/phototour.py | 124 +++++++++--------------------- 1 file changed, 38 insertions(+), 86 deletions(-) diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index 9511f0626b4..8a5efcca935 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -11,18 +11,7 @@ class PhotoTour(VisionDataset): - """`Multi-view Stereo Correspondence `_ Dataset. - - .. note:: - - We only provide the newer version of the dataset, since the authors state that it - - is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the - patches are centred on real interest point detections, rather than being projections of 3D points as is the - case in the old dataset. - - The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm. - + """`Multi-view Stereo Correspondence `_ Dataset. Args: root (str or ``pathlib.Path``): Root directory where images are. @@ -32,60 +21,42 @@ class PhotoTour(VisionDataset): download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. - """ urls = { - "notredame_harris": [ - "http://matthewalunbrown.com/patchdata/notredame_harris.zip", - "notredame_harris.zip", - "69f8c90f78e171349abdf0307afefe4d", - ], - "yosemite_harris": [ - "http://matthewalunbrown.com/patchdata/yosemite_harris.zip", - "yosemite_harris.zip", - "a73253d1c6fbd3ba2613c45065c00d46", - ], - "liberty_harris": [ - "http://matthewalunbrown.com/patchdata/liberty_harris.zip", - "liberty_harris.zip", - "c731fcfb3abb4091110d0ae8c7ba182c", + "trevi": [ + "https://phototour.cs.washington.edu/patches/trevi.zip", + "trevi.zip", + "d49ab428f154554856f83dba8aa76539", ], "notredame": [ - "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip", + "https://phototour.cs.washington.edu/patches/notredame.zip", "notredame.zip", - "509eda8535847b8c0a90bbb210c83484", + "0f801127085e405a61465605ea80c595", + ], + "halfdome": [ + "https://phototour.cs.washington.edu/patches/halfdome.zip", + "halfdome.zip", + "db871c5a86f4878c6754d0d12146440b", ], - "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"], - "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"], } means = { - "notredame": 0.4854, - "yosemite": 0.4844, - "liberty": 0.4437, - "notredame_harris": 0.4854, - "yosemite_harris": 0.4844, - "liberty_harris": 0.4437, + "trevi": 0.4832, + "notredame": 0.4757, + "halfdome": 0.4718, } stds = { - "notredame": 0.1864, - "yosemite": 0.1818, - "liberty": 0.2019, - "notredame_harris": 0.1864, - "yosemite_harris": 0.1818, - "liberty_harris": 0.2019, + "trevi": 0.1913, + "notredame": 0.1931, + "halfdome": 0.1791, } lens = { - "notredame": 468159, - "yosemite": 633587, - "liberty": 450092, - "liberty_harris": 379587, - "yosemite_harris": 450912, - "notredame_harris": 325295, + "trevi": 101120, + "notredame": 104196, + "halfdome": 107776, } image_ext = "bmp" info_file = "info.txt" - matches_files = "m50_100000_100000_0.txt" def __init__( self, @@ -112,30 +83,23 @@ def __init__( self.cache() # load the serialized data - self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True) + self.data, self.labels = torch.load(self.data_file) - def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]: + def __getitem__(self, index: int) -> torch.Tensor: """ Args: index (int): Index Returns: - tuple: (data1, data2, matches) + torch.Tensor: The image patch. """ - if self.train: - data = self.data[index] - if self.transform is not None: - data = self.transform(data) - return data - m = self.matches[index] - data1, data2 = self.data[m[0]], self.data[m[1]] + data = self.data[index] if self.transform is not None: - data1 = self.transform(data1) - data2 = self.transform(data2) - return data1, data2, m[2] + data = self.transform(data) + return data def __len__(self) -> int: - return len(self.data if self.train else self.matches) + return len(self.data) def _check_datafile_exists(self) -> bool: return os.path.exists(self.data_file) @@ -165,19 +129,16 @@ def download(self) -> None: def cache(self) -> None: # process and save as torch files - dataset = ( read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), read_info_file(self.data_dir, self.info_file), - read_matches_files(self.data_dir, self.matches_files), ) with open(self.data_file, "wb") as f: torch.save(dataset, f) def extra_repr(self) -> str: - split = "Train" if self.train is True else "Test" - return f"Split: {split}" + return f"Dataset: {self.name}" def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: @@ -185,16 +146,18 @@ def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: def PIL2array(_img: Image.Image) -> np.ndarray: """Convert PIL image type to numpy 2D array""" + # Ensure the patch size is exactly 64x64 + if _img.size != (64, 64): + raise ValueError(f"Invalid patch size: {_img.size}") return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64) def find_files(_data_dir: str, _image_ext: str) -> List[str]: """Return a list with the file names of the images containing the patches""" files = [] - # find those files with the specified extension for file_dir in os.listdir(_data_dir): if file_dir.endswith(_image_ext): files.append(os.path.join(_data_dir, file_dir)) - return sorted(files) # sort files in ascend order to keep relations + return sorted(files) patches = [] list_files = find_files(data_dir, image_ext) @@ -204,27 +167,16 @@ def find_files(_data_dir: str, _image_ext: str) -> List[str]: for y in range(0, img.height, 64): for x in range(0, img.width, 64): patch = img.crop((x, y, x + 64, y + 64)) - patches.append(PIL2array(patch)) + try: + patches.append(PIL2array(patch)) + except ValueError as e: + print(f"Skipping invalid patch at ({x}, {y}) in {fpath}: {e}") return torch.ByteTensor(np.array(patches[:n])) def read_info_file(data_dir: str, info_file: str) -> torch.Tensor: - """Return a Tensor containing the list of labels - Read the file and keep only the ID of the 3D point. - """ + """Return a Tensor containing the list of labels.""" with open(os.path.join(data_dir, info_file)) as f: labels = [int(line.split()[0]) for line in f] return torch.LongTensor(labels) - -def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor: - """Return a Tensor containing the ground truth matches - Read the file and keep only 3D point ID. - Matches are represented with a 1, non matches with a 0. - """ - matches = [] - with open(os.path.join(data_dir, matches_file)) as f: - for line in f: - line_split = line.split() - matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])]) - return torch.LongTensor(matches) From c013c399c874cf3eb4d802fde648897efcf86052 Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Fri, 15 Nov 2024 12:25:14 -0800 Subject: [PATCH 2/5] lint phototour code --- torchvision/datasets/phototour.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index 8a5efcca935..d4368e840eb 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -179,4 +179,3 @@ def read_info_file(data_dir: str, info_file: str) -> torch.Tensor: with open(os.path.join(data_dir, info_file)) as f: labels = [int(line.split()[0]) for line in f] return torch.LongTensor(labels) - From 736d2d2afba336d8facc3121c0a84b9c5d023fe8 Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Sat, 16 Nov 2024 11:58:49 -0800 Subject: [PATCH 3/5] fix test dataset download --- test/test_datasets_download.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 856a02b9d44..b974c49b229 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -262,9 +262,7 @@ def phototour(): return itertools.chain.from_iterable( [ collect_urls(datasets.PhotoTour, ROOT, name=name, download=True) - # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all - # requests timeout from within CI. They are disabled until this is resolved. - for name in ("notredame", "yosemite", "liberty") # "notredame_harris", "yosemite_harris", "liberty_harris" + for name in ("notredame", "trevi", "halfdome") ] ) From e15317b45d421ca86d9ac6faf6698d36ab320784 Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Sat, 16 Nov 2024 12:14:16 -0800 Subject: [PATCH 4/5] fix dataset test --- test/test_datasets.py | 46 +++++++++++++------------------------------ 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 7e91571744a..088d93e1c09 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1305,16 +1305,10 @@ def test_not_found_or_corrupted(self): class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.PhotoTour - # The PhotoTour dataset returns examples with different features with respect to the 'train' parameter. Thus, - # we overwrite 'FEATURE_TYPES' with a dummy value to satisfy the initial checks of the base class. Furthermore, we - # overwrite the 'test_feature_types()' method to select the correct feature types before the test is run. - FEATURE_TYPES = () - _TRAIN_FEATURE_TYPES = (torch.Tensor,) - _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor) - - combinations_grid(train=(True, False)) + # The PhotoTour dataset returns only a single feature type. + FEATURE_TYPES = (torch.Tensor,) - _NAME = "liberty" + _NAME = "notredame" def dataset_args(self, tmpdir, config): return tmpdir, self._NAME @@ -1322,21 +1316,18 @@ def dataset_args(self, tmpdir, config): def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) - # In contrast to the original data, the fake images injected here comprise only a single patch. Thus, - # num_images == num_patches. + # Simulate fake data num_patches = 5 image_files = self._create_images(tmpdir, self._NAME, num_patches) point_ids, info_file = self._create_info_file(tmpdir / self._NAME, num_patches) - num_matches, matches_file = self._create_matches_file(tmpdir / self._NAME, num_patches, point_ids) - self._create_archive(tmpdir, self._NAME, *image_files, info_file, matches_file) + self._create_archive(tmpdir, self._NAME, *image_files, info_file) - return num_patches if config["train"] else num_matches + return num_patches def _create_images(self, root, name, num_images): - # The images in the PhotoTour dataset comprises of multiple grayscale patches of 64 x 64 pixels. Thus, the - # smallest fake image is 64 x 64 pixels and comprises a single patch. + # Generate images return datasets_utils.create_image_folder( root, name, lambda idx: f"patches{idx:04d}.bmp", num_images, size=(1, 64, 64) ) @@ -1350,18 +1341,6 @@ def _create_info_file(self, root, num_images): return point_ids, file - def _create_matches_file(self, root, num_patches, point_ids): - lines = [ - f"{patch_id1} {point_ids[patch_id1]} 0 {patch_id2} {point_ids[patch_id2]} 0\n" - for patch_id1, patch_id2 in itertools.combinations(range(num_patches), 2) - ] - - file = root / "m50_100000_100000_0.txt" - with open(file, "w") as fh: - fh.writelines(lines) - - return len(lines), file - def _create_archive(self, root, name, *files): archive = root / f"{name}.zip" with zipfile.ZipFile(archive, "w") as zip: @@ -1372,12 +1351,10 @@ def _create_archive(self, root, name, *files): @datasets_utils.test_all_configs def test_feature_types(self, config): - feature_types = self.FEATURE_TYPES - self.FEATURE_TYPES = self._TRAIN_FEATURE_TYPES if config["train"] else self._TEST_FEATURE_TYPES try: super().test_feature_types.__wrapped__(self, config) - finally: - self.FEATURE_TYPES = feature_types + except KeyError as e: + pytest.fail(f"KeyError during test_feature_types: {e}") class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase): @@ -1869,6 +1846,11 @@ def test_class_to_idx(self): with self.create_dataset() as (dataset, _): assert dataset.class_to_idx == class_to_idx + def test_images_download_preexisting(self): + with pytest.raises(RuntimeError): + with self.create_dataset({"download": True}): + pass + class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.INaturalist From 201d74cc825894c37964e1fe9689564d11d379ab Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Sat, 16 Nov 2024 12:26:51 -0800 Subject: [PATCH 5/5] add weights only to remove warning --- torchvision/datasets/phototour.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index d4368e840eb..24d67f9f543 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -83,7 +83,7 @@ def __init__( self.cache() # load the serialized data - self.data, self.labels = torch.load(self.data_file) + self.data, self.labels = torch.load(self.data_file, weights_only=True) def __getitem__(self, index: int) -> torch.Tensor: """