From 4ba91bff35e98967150ea7a4e92745ab7c0294ca Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 21 Oct 2021 17:24:18 +0200 Subject: [PATCH] make mypy more strict for prototype datasets (#4513) * make mypy more strict for prototype datasets * fix code format * apply strictness only to datasets * fix more mypy issues * cleanup * fix mnist annotations * refactor celeba * warn on redundant casts * remove redundant cast * simplify annotation * fix import --- mypy.ini | 17 +++++ torchvision/datasets/usps.py | 4 +- .../prototype/datasets/_builtin/caltech.py | 13 ++-- .../prototype/datasets/_builtin/celeba.py | 64 +++++++++---------- .../prototype/datasets/_builtin/cifar.py | 20 +++--- .../prototype/datasets/_builtin/imagenet.py | 23 ++++--- .../prototype/datasets/_builtin/mnist.py | 28 ++++---- .../prototype/datasets/_builtin/sbd.py | 26 +++++--- .../prototype/datasets/_builtin/voc.py | 16 +++-- torchvision/prototype/datasets/_folder.py | 2 +- torchvision/prototype/datasets/benchmark.py | 2 + torchvision/prototype/datasets/decoder.py | 3 +- .../datasets/generate_category_files.py | 4 +- .../prototype/datasets/utils/_dataset.py | 11 +--- .../prototype/datasets/utils/_internal.py | 30 ++++----- .../prototype/datasets/utils/_resource.py | 2 +- 16 files changed, 146 insertions(+), 119 deletions(-) diff --git a/mypy.ini b/mypy.ini index a9d62f38e7b..d2bbe22614f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,6 +4,23 @@ files = torchvision show_error_codes = True pretty = True allow_redefinition = True +warn_redundant_casts = True + +[mypy-torchvision.prototype.datasets.*] + +; untyped definitions and calls +disallow_untyped_defs = True + +; None and Optional handling +no_implicit_optional = True + +; warnings +warn_unused_ignores = True +warn_return_any = True +warn_unreachable = True + +; miscellaneous strictness flags +allow_redefinition = True [mypy-torchvision.io._video_opt.*] diff --git a/torchvision/datasets/usps.py b/torchvision/datasets/usps.py index c90ebfa7e6f..c09ec282e9d 100644 --- a/torchvision/datasets/usps.py +++ b/torchvision/datasets/usps.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, cast, Optional, Tuple +from typing import Any, Callable, Optional, Tuple import numpy as np from PIL import Image @@ -63,7 +63,7 @@ def __init__( raw_data = [line.decode().split() for line in fp.readlines()] tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data] imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) - imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8) + imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) targets = [int(d[0]) - 1 for d in raw_data] self.data = imgs diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 521dac1b814..818ef3ec20c 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -82,7 +82,10 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: return category, id def _collate_and_decode_sample( - self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] + self, + data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]], + *, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: key, image_data, ann_data = data category, _ = key @@ -117,11 +120,11 @@ def _make_datapipe( images_dp, anns_dp = resource_dps images_dp = TarArchiveReader(images_dp) - images_dp: IterDataPipe = Filter(images_dp, 
self._is_not_background_image) + images_dp = Filter(images_dp, self._is_not_background_image) images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE) anns_dp = TarArchiveReader(anns_dp) - anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann) + anns_dp = Filter(anns_dp, self._is_ann) dp = KeyZipper( images_dp, @@ -136,7 +139,7 @@ def _make_datapipe( def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp: IterDataPipe = Filter(dp, self._is_not_background_image) + dp = Filter(dp, self._is_not_background_image) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) @@ -185,7 +188,7 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = TarArchiveReader(dp) - dp: IterDataPipe = Filter(dp, self._is_not_rogue_file) + dp = Filter(dp, self._is_not_rogue_file) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index d86eaf27fab..ebfce4b652d 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,6 +1,6 @@ import csv import io -from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence import torch from torchdata.datapipes.iter import ( @@ -23,37 +23,38 @@ from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor -class CelebACSVParser(IterDataPipe): +csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) + + +class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): def __init__( self, - datapipe, + datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, - has_header, - ): + fieldnames: Optional[Sequence[str]] = None, + ) -> None: self.datapipe = datapipe - self.has_header = has_header - self._fmtparams = dict(delimiter=" ", skipinitialspace=True) + self.fieldnames = fieldnames - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: for _, file in self.datapipe: file = (line.decode() for line in file) - if self.has_header: + if self.fieldnames: + fieldnames = self.fieldnames + else: # The first row is skipped, because it only contains the number of samples next(file) - # Empty field names are filtered out, because some files have an extr white space after the header + # Empty field names are filtered out, because some files have an extra white space after the header # line, which is recognized as extra column - fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name] + fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name] # Some files do not include a label for the image ID column if fieldnames[0] != "image_id": fieldnames.insert(0, "image_id") - for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams): - yield line.pop("image_id"), line - else: - for line in csv.reader(file, **self._fmtparams): - yield line[0], line[1:] + for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"): + yield line.pop("image_id"), line class CelebA(Dataset): @@ -104,13 +105,10 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: "2": "test", } - def _filter_split(self, data: 
Tuple[str, str], *, split): - _, split_id = data - return self._SPLIT_ID_TO_NAME[split_id[0]] == split + def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: + return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split - def _collate_anns( - self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...] - ) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]: + def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]: (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks) @@ -127,7 +125,7 @@ def _collate_and_decode_sample( image = decoder(buffer) if decoder else buffer - identity = torch.tensor(int(ann["identity"][0])) + identity = int(ann["identity"]["identity"]) attributes = {attr: value == "1" for attr, value in ann["attributes"].items()} bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")]) landmarks = { @@ -153,24 +151,24 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps - splits_dp = CelebACSVParser(splits_dp, has_header=False) - splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split)) + splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) + splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split)) splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE) images_dp = ZipArchiveReader(images_dp) - anns_dp: IterDataPipe = Zipper( + anns_dp = Zipper( *[ - CelebACSVParser(dp, has_header=has_header) - for dp, has_header in ( - (identities_dp, False), - (attributes_dp, True), - (bboxes_dp, True), - (landmarks_dp, True), + CelebACSVParser(dp, fieldnames=fieldnames) + for dp, fieldnames in ( + (identities_dp, ("image_id", "identity")), + (attributes_dp, None), + (bboxes_dp, None), + (landmarks_dp, None), ) ] ) - anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns) + anns_dp = Mapper(anns_dp, self._collate_anns) dp = KeyZipper( splits_dp, diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 4fbd993d311..edb1c7f88d5 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -3,7 +3,7 @@ import io import pathlib import pickle -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast import numpy as np import torch @@ -56,7 +56,7 @@ def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) - def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: _, file = data - return pickle.load(file, encoding="latin1") + return cast(Dict[str, Any], pickle.load(file, encoding="latin1")) def _collate_and_decode( self, @@ -86,9 +86,9 @@ def _make_datapipe( decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp: IterDataPipe = TarArchiveReader(dp) - dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config)) - dp: IterDataPipe = Mapper(dp, self._unpickle) + dp = TarArchiveReader(dp) + dp = Filter(dp, functools.partial(self._is_data_file, config=config)) + dp = Mapper(dp, self._unpickle) dp = CifarFileReader(dp, 
labels_key=self._LABELS_KEY) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder)) @@ -96,9 +96,9 @@ def _make_datapipe( def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME)) - dp: IterDataPipe = Mapper(dp, self._unpickle) - return next(iter(dp))[self._CATEGORIES_KEY] + dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) + dp = Mapper(dp, self._unpickle) + return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) class Cifar10(_CifarBase): @@ -133,9 +133,9 @@ class Cifar100(_CifarBase): _META_FILE_NAME = "meta" _CATEGORIES_KEY = "fine_label_names" - def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool: + def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: path = pathlib.Path(data[0]) - return path.name == config.split + return path.name == cast(str, config.split) @property def info(self) -> DatasetInfo: diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index d568fa2fcfc..28b52d6985b 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,7 +1,7 @@ import io import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, cast import torch from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler @@ -44,11 +44,11 @@ def info(self) -> DatasetInfo: @property def category_to_wnid(self) -> Dict[str, str]: - return self.info.extra.category_to_wnid + return cast(Dict[str, str], self.info.extra.category_to_wnid) @property def wnid_to_category(self) -> Dict[str, str]: - return self.info.extra.wnid_to_category + return cast(Dict[str, str], self.info.extra.wnid_to_category) def resources(self, config: DatasetConfig) -> List[OnlineResource]: if config.split == "train": @@ -152,7 +152,7 @@ def _make_datapipe( "n03710721": "tank suit", } - def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]: + def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: resources = self.resources(self.default_config) devkit_dp = resources[1].to_datapipe(root / self.name) devkit_dp = TarArchiveReader(devkit_dp) @@ -160,12 +160,15 @@ def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]: meta = next(iter(devkit_dp))[1] synsets = read_mat(meta, squeeze_me=True)["synsets"] - categories_and_wnids = [ - (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid) - for _, wnid, category, _, num_children, *_ in synsets - # if num_children > 0, we are looking at a superclass that has no direct instance - if num_children == 0 - ] + categories_and_wnids = cast( + List[Tuple[str, ...]], + [ + (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid) + for _, wnid, category, _, num_children, *_ in synsets + # if num_children > 0, we are looking at a superclass that has no direct instance + if num_children == 0 + ], + ) categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) return categories_and_wnids diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py 
index 2413a2fb084..af22199ce39 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -38,7 +38,7 @@ prod = functools.partial(functools.reduce, operator.mul) -class MNISTFileReader(IterDataPipe): +class MNISTFileReader(IterDataPipe[np.ndarray]): _DTYPE_MAP = { 8: "u1", # uint8 9: "i1", # int8 @@ -48,13 +48,15 @@ class MNISTFileReader(IterDataPipe): 14: "f8", # float64 } - def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Optional[int]) -> None: + def __init__( + self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int] + ) -> None: self.datapipe = datapipe self.start = start self.stop = stop @staticmethod - def _decode(bytes): + def _decode(bytes: bytes) -> int: return int(codecs.encode(bytes, "hex"), 16) def __iter__(self) -> Iterator[np.ndarray]: @@ -107,7 +109,7 @@ def _collate_and_decode( *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ): + ) -> Dict[str, Any]: image_array, label_array = data image: Union[torch.Tensor, io.BytesIO] @@ -138,14 +140,14 @@ def _make_datapipe( labels_dp = Decompressor(labels_dp) labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop) - dp: IterDataPipe = Zipper(images_dp, labels_dp) + dp = Zipper(images_dp, labels_dp) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder)) class MNIST(_MNISTBase): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "mnist", type=DatasetType.RAW, @@ -176,7 +178,7 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], class FashionMNIST(MNIST): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "fashionmnist", type=DatasetType.RAW, @@ -209,7 +211,7 @@ def info(self): class KMNIST(MNIST): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "kmnist", type=DatasetType.RAW, @@ -231,7 +233,7 @@ def info(self): class EMNIST(_MNISTBase): @property - def info(self): + def info(self) -> DatasetInfo: return DatasetInfo( "emnist", type=DatasetType.RAW, @@ -295,7 +297,7 @@ def _collate_and_decode( *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ): + ) -> Dict[str, Any]: image_array, label_array = data # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. 
Since the labels are nevertheless dense,
@@ -321,7 +323,7 @@ def _make_datapipe(
         images_dp, labels_dp = Demultiplexer(
             archive_dp,
             2,
-            functools.partial(self._classify_archive, config=config),  # type:ignore[arg-type]
+            functools.partial(self._classify_archive, config=config),
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
@@ -330,7 +332,7 @@ def _make_datapipe(
 
 class QMNIST(_MNISTBase):
     @property
-    def info(self):
+    def info(self) -> DatasetInfo:
         return DatasetInfo(
             "qmnist",
             type=DatasetType.RAW,
@@ -381,7 +383,7 @@ def _collate_and_decode(
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ):
+    ) -> Dict[str, Any]:
         image_array, label_array = data
         label_parts = label_array.tolist()
         sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index c0244aa534a..3f57f488795 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -1,7 +1,7 @@
 import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 import torch
@@ -135,7 +135,7 @@ def _make_datapipe(
         split_dp, images_dp, anns_dp = Demultiplexer(
             archive_dp,
             3,
-            self._classify_archive,  # type: ignore[arg-type]
+            self._classify_archive,
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )
@@ -159,15 +159,21 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
-        dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
+        dp = Filter(dp, path_comparator("name", "category_names.m"))
         dp = LineReader(dp)
-        dp: IterDataPipe = Mapper(dp, bytes.decode, input_col=1)
+        dp = Mapper(dp, bytes.decode, input_col=1)
         lines = tuple(zip(*iter(dp)))[1]
 
         pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P