make mypy more strict for prototype datasets (#4513)
* make mypy more strict for prototype datasets

* fix code format

* apply strictness only to datasets

* fix more mypy issues

* cleanup

* fix mnist annotations

* refactor celeba

* warn on redundant casts

* remove redundant cast

* simplify annotation

* fix import
pmeier authored Oct 21, 2021
1 parent 9407b45 commit 4ba91bf
Showing 16 changed files with 146 additions and 119 deletions.
17 changes: 17 additions & 0 deletions mypy.ini
@@ -4,6 +4,23 @@ files = torchvision
show_error_codes = True
pretty = True
allow_redefinition = True
warn_redundant_casts = True

[mypy-torchvision.prototype.datasets.*]

; untyped definitions and calls
disallow_untyped_defs = True

; None and Optional handling
no_implicit_optional = True

; warnings
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True

; miscellaneous strictness flags
allow_redefinition = True

[mypy-torchvision.io._video_opt.*]

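As a rough illustration of what the new per-module flags catch (a hypothetical module, not code from this commit):

```python
# hypothetical module, not from torchvision -- shows what the new flags reject
from typing import Optional


def parse(line):  # disallow_untyped_defs: "Function is missing a type annotation"
    return line.split()


def load(path: str = None) -> Optional[str]:
    # no_implicit_optional: the default None requires `path` to be Optional[str]
    return path
```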
4 changes: 2 additions & 2 deletions torchvision/datasets/usps.py
@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, cast, Optional, Tuple
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image
@@ -63,7 +63,7 @@ def __init__(
raw_data = [line.decode().split() for line in fp.readlines()]
tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]

self.data = imgs
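The usps.py change removes a cast that the newly enabled `warn_redundant_casts = True` reports; a minimal standalone sketch of that warning (invented function, not from the diff):

```python
from typing import cast


def double(x: int) -> int:
    # `x` is already an int, so with warn_redundant_casts = True mypy reports
    # 'Redundant cast to "int"'; dropping the cast is the fix.
    return cast(int, x) * 2
```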
13 changes: 8 additions & 5 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -82,7 +82,10 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
return category, id

def _collate_and_decode_sample(
self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
self,
data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
key, image_data, ann_data = data
category, _ = key
@@ -117,11 +120,11 @@ def _make_datapipe(
images_dp, anns_dp = resource_dps

images_dp = TarArchiveReader(images_dp)
images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image)
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)

anns_dp = TarArchiveReader(anns_dp)
anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann)
anns_dp = Filter(anns_dp, self._is_ann)

dp = KeyZipper(
images_dp,
@@ -136,7 +139,7 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_background_image)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})


@@ -185,7 +188,7 @@ def _make_datapipe(
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_rogue_file)
dp = Filter(dp, self._is_not_rogue_file)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

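The Caltech changes drop the inline `: IterDataPipe` annotations that were used when rebinding `dp`; with `allow_redefinition = True` mypy accepts such rebinds on its own, assuming the pipe constructors are properly typed. A self-contained sketch of the pattern with stand-in classes (none of these names come from torchdata):

```python
from typing import Iterator, List


class Source:
    def __iter__(self) -> Iterator[str]:
        yield from ("airplanes", "BACKGROUND_Google")


class Filtered:
    def __init__(self, source: Source) -> None:
        self.source = source

    def __iter__(self) -> Iterator[str]:
        return (name for name in self.source if name != "BACKGROUND_Google")


def build() -> List[str]:
    dp = Source()
    # Rebinding `dp` to a different type; with allow_redefinition = True mypy
    # accepts this without an explicit `dp: ... = ...` annotation.
    dp = Filtered(dp)
    return list(dp)


print(build())  # ['airplanes']
```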
64 changes: 31 additions & 33 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -1,6 +1,6 @@
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence

import torch
from torchdata.datapipes.iter import (
@@ -23,37 +23,38 @@
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor


class CelebACSVParser(IterDataPipe):
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)


class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __init__(
self,
datapipe,
datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
*,
has_header,
):
fieldnames: Optional[Sequence[str]] = None,
) -> None:
self.datapipe = datapipe
self.has_header = has_header
self._fmtparams = dict(delimiter=" ", skipinitialspace=True)
self.fieldnames = fieldnames

def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)

if self.has_header:
if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(file)

# Empty field names are filtered out, because some files have an extr white space after the header
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")

for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
yield line.pop("image_id"), line
else:
for line in csv.reader(file, **self._fmtparams):
yield line[0], line[1:]
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line


class CelebA(Dataset):
@@ -104,13 +105,10 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
"2": "test",
}

def _filter_split(self, data: Tuple[str, str], *, split):
_, split_id = data
return self._SPLIT_ID_TO_NAME[split_id[0]] == split
def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split

def _collate_anns(
self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
(image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)

@@ -127,7 +125,7 @@ def _collate_and_decode_sample(

image = decoder(buffer) if decoder else buffer

identity = torch.tensor(int(ann["identity"][0]))
identity = int(ann["identity"]["identity"])
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = {
@@ -153,24 +151,24 @@ def _make_datapipe(
) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

splits_dp = CelebACSVParser(splits_dp, has_header=False)
splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

images_dp = ZipArchiveReader(images_dp)

anns_dp: IterDataPipe = Zipper(
anns_dp = Zipper(
*[
CelebACSVParser(dp, has_header=has_header)
for dp, has_header in (
(identities_dp, False),
(attributes_dp, True),
(bboxes_dp, True),
(landmarks_dp, True),
CelebACSVParser(dp, fieldnames=fieldnames)
for dp, fieldnames in (
(identities_dp, ("image_id", "identity")),
(attributes_dp, None),
(bboxes_dp, None),
(landmarks_dp, None),
)
]
)
anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)
anns_dp = Mapper(anns_dp, self._collate_anns)

dp = KeyZipper(
splits_dp,
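The CelebA parser now registers a module-level csv dialect instead of threading format parameters through the class; a minimal standalone usage sketch with invented sample rows:

```python
import csv
import io

# Register the dialect once; readers refer to it by name afterwards.
csv.register_dialect("celeba_demo", delimiter=" ", skipinitialspace=True)

raw = io.StringIO("000001.jpg 1  -1\n000002.jpg -1 1\n")
reader = csv.DictReader(
    raw, fieldnames=("image_id", "smiling", "young"), dialect="celeba_demo"
)
for row in reader:
    image_id = row.pop("image_id")
    print(image_id, row)
```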
20 changes: 10 additions & 10 deletions torchvision/prototype/datasets/_builtin/cifar.py
@@ -3,7 +3,7 @@
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast

import numpy as np
import torch
@@ -56,7 +56,7 @@ def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -

def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return pickle.load(file, encoding="latin1")
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))

def _collate_and_decode(
self,
@@ -86,19 +86,19 @@ def _make_datapipe(
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp: IterDataPipe = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
dp: IterDataPipe = Mapper(dp, self._unpickle)
dp = TarArchiveReader(dp)
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp: IterDataPipe = Mapper(dp, self._unpickle)
return next(iter(dp))[self._CATEGORIES_KEY]
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


class Cifar10(_CifarBase):
Expand Down Expand Up @@ -133,9 +133,9 @@ class Cifar100(_CifarBase):
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"

def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
path = pathlib.Path(data[0])
return path.name == config.split
return path.name == cast(str, config.split)

@property
def info(self) -> DatasetInfo:
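The CIFAR changes wrap values mypy only knows as `Any` (such as the result of `pickle.load`) in `typing.cast` so that `warn_return_any = True` stays quiet; a rough standalone version (hypothetical `unpickle` helper, not the dataset code):

```python
import pickle
from typing import Any, Dict, cast


def unpickle(path: str) -> Dict[str, Any]:
    with open(path, "rb") as file:
        # pickle.load is typed as returning Any; without the cast,
        # warn_return_any = True reports "Returning Any from function
        # declared to return Dict[str, Any]".
        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
```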
23 changes: 13 additions & 10 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,7 +1,7 @@
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
@@ -44,11 +44,11 @@ def info(self) -> DatasetInfo:

@property
def category_to_wnid(self) -> Dict[str, str]:
return self.info.extra.category_to_wnid
return cast(Dict[str, str], self.info.extra.category_to_wnid)

@property
def wnid_to_category(self) -> Dict[str, str]:
return self.info.extra.wnid_to_category
return cast(Dict[str, str], self.info.extra.wnid_to_category)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.split == "train":
@@ -152,20 +152,23 @@ def _make_datapipe(
"n03710721": "tank suit",
}

def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
resources = self.resources(self.default_config)
devkit_dp = resources[1].to_datapipe(root / self.name)
devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))

meta = next(iter(devkit_dp))[1]
synsets = read_mat(meta, squeeze_me=True)["synsets"]
categories_and_wnids = [
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
]
categories_and_wnids = cast(
List[Tuple[str, ...]],
[
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
],
)
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])

return categories_and_wnids
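The ImageNet changes cast values pulled from loosely typed containers (`info.extra`, the parsed `synsets`) for the same reason; a small stand-in illustration (the `extra` dict and its contents here are invented sample data):

```python
from typing import Any, Dict, cast

# Stand-in for the loosely typed ``info.extra`` container used in the diff.
extra: Dict[str, Any] = {"category_to_wnid": {"tench": "n01440764"}}


def category_to_wnid() -> Dict[str, str]:
    # The lookup is typed as Any, so without the cast warn_return_any = True
    # would flag this return; the cast states the intended concrete type.
    return cast(Dict[str, str], extra["category_to_wnid"])


print(category_to_wnid()["tench"])
```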