make mypy more strict for prototype datasets #4513

Merged · 18 commits · Oct 21, 2021
17 changes: 17 additions & 0 deletions mypy.ini
@@ -4,6 +4,23 @@ files = torchvision
show_error_codes = True
pretty = True
allow_redefinition = True
warn_redundant_casts = True

[mypy-torchvision.prototype.datasets.*]

; untyped definitions and calls
disallow_untyped_defs = True

; None and Optional handling
no_implicit_optional = True

; warnings
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True

; miscellaneous strictness flags
allow_redefinition = True
Comment on lines +11 to +23
Member

Are these the default values for those options?
If they're not the default, do we have a strong reason to use them instead of the defaults? Is this going to be clearly beneficial to the code-base and to us as developers?

Collaborator Author

Are these the default values for those options?

Nope.

If they're not the default, do we have a strong reason to use them instead of the defaults? Is this going to be clearly beneficial to the code-base and to us as developers?

Let's go through them one by one:

  • disallow_untyped_defs: by default mypy simply accepts untyped functions and uses Any for the input and output annotations. If our ultimate goal is to declare torchvision typed, we should make sure that we don't miss any functions. This flag enforces that (see the first sketch after this list).

  • no_implicit_optional: By default mypy allows this:

    def foo(bar: int = None) -> int:
        pass

    With this option enabled, it has to be

    def foo(bar: Optional[int] = None) -> int:
        pass

    Given that None is a valid input, we should also explicitly mention it in the annotation.

  • warn_unused_ignores: Sometimes we use # type: ignore directives for things that are actually wrong in other libraries. For example, fix annotation for Demultiplexer pytorch#65998 will make some ignore directives that are currently needed obsolete. Without this flag, we would never notice (see the second sketch below).

  • warn_return_any: If a function does something with dynamic types, mypy usually falls back to treating the output as Any. This flag warns us when that happens even though we annotated a more concrete return type (see the second sketch below).

  • warn_unreachable: This is more of a testing aid, since mypy will now warn us if some code is unreachable. For example, with this flag set, mypy will warn that the if branch below is unreachable.

    def foo(bar: str) -> str:
        if isinstance(bar, int):
            bar = str(bar)
        return bar
  • allow_redefinition: See Set allow_redefinition = True for mypy #4531. If we have this globally, we can of course remove it here.
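
For illustration, here is a minimal hypothetical sketch (not from the torchvision code base) of what disallow_untyped_defs catches; without the flag, mypy would silently treat the first function as taking and returning Any:

    def scale(value, factor=2):  # error: Function is missing a type annotation  [no-untyped-def]
        return value * factor

    def scale_typed(value: int, factor: int = 2) -> int:  # accepted
        return value * factor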

Apart from warn_return_any and warn_unreachable, I think these flags are clearly beneficial. The other two were beneficial for me in the past, but I can see others objecting to them.
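
Similarly, minimal hypothetical sketches of what the two warning flags report (the names and the json-based example are made up for illustration):

    from typing import Any, Dict
    import json

    # warn_unused_ignores: an ignore comment that no longer suppresses any error is itself reported
    version: int = 5  # type: ignore  # error: Unused "type: ignore" comment

    # warn_return_any: the returned expression is Any, but the signature promises something concrete
    def load_config(raw: str) -> Dict[str, Any]:
        config = json.loads(raw)  # json.loads(...) is annotated as returning Any
        return config  # error: Returning Any from function declared to return "Dict[str, Any]"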


[mypy-torchvision.io._video_opt.*]

4 changes: 2 additions & 2 deletions torchvision/datasets/usps.py
@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, cast, Optional, Tuple
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image
@@ -63,7 +63,7 @@ def __init__(
raw_data = [line.decode().split() for line in fp.readlines()]
tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]

self.data = imgs
13 changes: 8 additions & 5 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -82,7 +82,10 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
return category, id

def _collate_and_decode_sample(
self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
self,
data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
key, image_data, ann_data = data
category, _ = key
@@ -117,11 +120,11 @@ def _make_datapipe(
images_dp, anns_dp = resource_dps

images_dp = TarArchiveReader(images_dp)
images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image)
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)

anns_dp = TarArchiveReader(anns_dp)
anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann)
anns_dp = Filter(anns_dp, self._is_ann)

dp = KeyZipper(
images_dp,
@@ -136,7 +139,7 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_background_image)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})


@@ -185,7 +188,7 @@ def _make_datapipe(
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_rogue_file)
dp = Filter(dp, self._is_not_rogue_file)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

64 changes: 31 additions & 33 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -1,6 +1,6 @@
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence

import torch
from torchdata.datapipes.iter import (
@@ -23,37 +23,38 @@
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor


class CelebACSVParser(IterDataPipe):
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)


class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __init__(
self,
datapipe,
datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
*,
has_header,
):
fieldnames: Optional[Sequence[str]] = None,
) -> None:
self.datapipe = datapipe
self.has_header = has_header
self._fmtparams = dict(delimiter=" ", skipinitialspace=True)
self.fieldnames = fieldnames

def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)

if self.has_header:
if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(file)

# Empty field names are filtered out, because some files have an extr white space after the header
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")

for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
yield line.pop("image_id"), line
else:
for line in csv.reader(file, **self._fmtparams):
yield line[0], line[1:]
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line


class CelebA(Dataset):
@@ -104,13 +105,10 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
"2": "test",
}

def _filter_split(self, data: Tuple[str, str], *, split):
_, split_id = data
return self._SPLIT_ID_TO_NAME[split_id[0]] == split
def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split

def _collate_anns(
self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
(image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)

@@ -127,7 +125,7 @@ def _collate_and_decode_sample(

image = decoder(buffer) if decoder else buffer

identity = torch.tensor(int(ann["identity"][0]))
identity = int(ann["identity"]["identity"])
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = {
@@ -153,24 +151,24 @@
) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

splits_dp = CelebACSVParser(splits_dp, has_header=False)
splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

images_dp = ZipArchiveReader(images_dp)

anns_dp: IterDataPipe = Zipper(
anns_dp = Zipper(
*[
CelebACSVParser(dp, has_header=has_header)
for dp, has_header in (
(identities_dp, False),
(attributes_dp, True),
(bboxes_dp, True),
(landmarks_dp, True),
CelebACSVParser(dp, fieldnames=fieldnames)
for dp, fieldnames in (
(identities_dp, ("image_id", "identity")),
(attributes_dp, None),
(bboxes_dp, None),
(landmarks_dp, None),
)
]
)
anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)
anns_dp = Mapper(anns_dp, self._collate_anns)

dp = KeyZipper(
splits_dp,
20 changes: 10 additions & 10 deletions torchvision/prototype/datasets/_builtin/cifar.py
@@ -3,7 +3,7 @@
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast

import numpy as np
import torch
@@ -56,7 +56,7 @@ def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -

def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return pickle.load(file, encoding="latin1")
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))

def _collate_and_decode(
self,
@@ -86,19 +86,19 @@ def _make_datapipe(
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp: IterDataPipe = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
dp: IterDataPipe = Mapper(dp, self._unpickle)
dp = TarArchiveReader(dp)
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp: IterDataPipe = Mapper(dp, self._unpickle)
return next(iter(dp))[self._CATEGORIES_KEY]
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


class Cifar10(_CifarBase):
@@ -133,9 +133,9 @@ class Cifar100(_CifarBase):
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"

def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
path = pathlib.Path(data[0])
return path.name == config.split
return path.name == cast(str, config.split)

@property
def info(self) -> DatasetInfo:
23 changes: 13 additions & 10 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,7 +1,7 @@
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
@@ -44,11 +44,11 @@ def info(self) -> DatasetInfo:

@property
def category_to_wnid(self) -> Dict[str, str]:
return self.info.extra.category_to_wnid
return cast(Dict[str, str], self.info.extra.category_to_wnid)

@property
def wnid_to_category(self) -> Dict[str, str]:
return self.info.extra.wnid_to_category
return cast(Dict[str, str], self.info.extra.wnid_to_category)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.split == "train":
@@ -152,20 +152,23 @@ def _make_datapipe(
"n03710721": "tank suit",
}

def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, str]]:
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
resources = self.resources(self.default_config)
devkit_dp = resources[1].to_datapipe(root / self.name)
devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))

meta = next(iter(devkit_dp))[1]
synsets = read_mat(meta, squeeze_me=True)["synsets"]
categories_and_wnids = [
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
]
categories_and_wnids = cast(
List[Tuple[str, ...]],
[
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
],
)
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])

return categories_and_wnids