make mypy more strict for prototype datasets #4513

Merged: 18 commits, merged on Oct 21, 2021
Changes from 8 commits
16 changes: 16 additions & 0 deletions mypy.ini
@@ -5,6 +5,22 @@ show_error_codes = True
pretty = True
allow_redefinition = True

[mypy-torchvision.prototype.datasets.*]

; untyped definitions and calls
disallow_untyped_defs = True

; None and Optional handling
no_implicit_optional = True

; warnings
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True

; miscellaneous strictness flags
allow_redefinition = True
Comment on lines +11 to +23
Member

Are these the default values for those options?
If they're not the default, do we have a strong reason to use them instead of the defaults? Is this going to be clearly beneficial to the code-base and to us as developers?

Collaborator Author

> Are these the default values for those options?

Nope.

> If they're not the default, do we have a strong reason to use them instead of the defaults? Is this going to be clearly beneficial to the code-base and to us as developers?

Let's go through them one by one:

  • disallow_untyped_defs: By default, mypy simply accepts untyped functions and uses Any for the input and output annotations. If our ultimate goal is to declare torchvision typed, we should make sure we don't miss any functions. This flag enforces that (see the sketch at the end of this comment).

  • no_implicit_optional: By default mypy allows this:

    def foo(bar: int = None) -> int:
        pass

    With this option enabled, it has to be

    def foo(bar: Optional[int] = None) -> int:
        pass

    Given that None is a valid input, we should also explicitly mention it in the annotation.

  • warn_unused_ignores: Sometimes we use # type: ignore directives for things that are actually wrong in other libraries. For example, fix annotation for Demultiplexer (pytorch#65998) will make some ignore directives that are needed now obsolete. Without this flag, we would never know when that happens.

  • warn_return_any: If a function does something with dynamic types, mypy usually falls back to treating the output as Any. This flag warns us when that happens even though we specified a more concrete return type.

  • warn_unreachable: This is more of a testing aid, as mypy will now warn us if some code is unreachable. For example, with this flag set, mypy will warn that the if branch below is unreachable, since bar is already annotated as str.

    def foo(bar: str) -> str:
        if isinstance(bar, int):
            bar = str(bar)
        return bar
  • allow_redefinition: See "Set allow_redefinition = True for mypy" (#4531). If we have this globally, we can of course remove it here.

Apart from warn_return_any and warn_unreachable, I think these flags are clearly beneficial. The other two have been useful for me in the past, but I can see others objecting to them.
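To make this a bit more concrete, here is a small self-contained sketch (hypothetical code, not taken from torchvision) of what the flags without an inline example above would report or allow:

    import json
    from typing import Optional


    def parse(data):  # disallow_untyped_defs: flagged, function is missing type annotations
        return data


    def load_config(path: str) -> dict:
        with open(path) as file:
            # warn_return_any: json.load() is typed as returning Any, so returning it
            # from a function annotated with a concrete return type triggers a warning.
            return json.load(file)


    def first_char(text: Optional[str]) -> str:
        if text is not None:
            # warn_unused_ignores: mypy narrows text to str here, so the ignore
            # directive below is unnecessary and would now be reported as unused.
            return text[0]  # type: ignore[index]
        return ""


    def normalize(value: str) -> int:
        # allow_redefinition: reusing the same name with a different type in the
        # same block is accepted instead of being flagged as an incompatible assignment.
        value = int(value)
        return value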


[mypy-torchvision.io._video_opt.*]

ignore_errors = True
13 changes: 8 additions & 5 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -82,7 +82,10 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
return category, id

def _collate_and_decode_sample(
self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
self,
data: Tuple[Tuple[str, str], Tuple[str, io.IOBase], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
key, image_data, ann_data = data
category, _ = key
@@ -117,11 +120,11 @@ def _make_datapipe(
images_dp, anns_dp = resource_dps

images_dp = TarArchiveReader(images_dp)
images_dp: IterDataPipe = Filter(images_dp, self._is_not_background_image)
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)

anns_dp = TarArchiveReader(anns_dp)
anns_dp: IterDataPipe = Filter(anns_dp, self._is_ann)
anns_dp = Filter(anns_dp, self._is_ann)

dp = KeyZipper(
images_dp,
@@ -136,7 +139,7 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_background_image)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})


@@ -185,7 +188,7 @@ def _make_datapipe(
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_rogue_file)
dp = Filter(dp, self._is_not_rogue_file)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

34 changes: 18 additions & 16 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -1,6 +1,6 @@
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union, Iterator

import torch
from torchdata.datapipes.iter import (
@@ -23,37 +23,39 @@
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor


csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)


class CelebACSVParser(IterDataPipe):
def __init__(
self,
datapipe,
datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
*,
has_header,
):
has_header: bool,
) -> None:
self.datapipe = datapipe
self.has_header = has_header
self._fmtparams = dict(delimiter=" ", skipinitialspace=True)

def __iter__(self):
def __iter__(self) -> Iterator[Tuple[str, Union[Dict[str, str], List[str]]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)

if self.has_header:
# The first row is skipped, because it only contains the number of samples
next(file)

# Empty field names are filtered out, because some files have an extr white space after the header
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")

for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
yield line.pop("image_id"), line
for line_dict in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
yield line_dict.pop("image_id"), line_dict
else:
for line in csv.reader(file, **self._fmtparams):
yield line[0], line[1:]
for line_list in csv.reader(file, dialect="celeba"):
yield line_list[0], line_list[1:]


class CelebA(Dataset):
@@ -104,7 +106,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
"2": "test",
}

def _filter_split(self, data: Tuple[str, str], *, split):
def _filter_split(self, data: Tuple[str, str], *, split: str) -> bool:
_, split_id = data
return self._SPLIT_ID_TO_NAME[split_id[0]] == split

@@ -154,12 +156,12 @@ def _make_datapipe(
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

splits_dp = CelebACSVParser(splits_dp, has_header=False)
splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

images_dp = ZipArchiveReader(images_dp)

anns_dp: IterDataPipe = Zipper(
anns_dp = Zipper(
*[
CelebACSVParser(dp, has_header=has_header)
for dp, has_header in (
Expand All @@ -170,7 +172,7 @@ def _make_datapipe(
)
]
)
anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)
anns_dp = Mapper(anns_dp, self._collate_anns)

dp = KeyZipper(
splits_dp,
20 changes: 10 additions & 10 deletions torchvision/prototype/datasets/_builtin/cifar.py
@@ -3,7 +3,7 @@
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast

import numpy as np
import torch
@@ -56,7 +56,7 @@ def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -

def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return pickle.load(file, encoding="latin1")
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))

def _collate_and_decode(
self,
@@ -86,19 +86,19 @@ def _make_datapipe(
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp: IterDataPipe = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, functools.partial(self._is_data_file, config=config))
dp: IterDataPipe = Mapper(dp, self._unpickle)
dp = TarArchiveReader(dp)
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp: IterDataPipe = Mapper(dp, self._unpickle)
return next(iter(dp))[self._CATEGORIES_KEY]
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


class Cifar10(_CifarBase):
@@ -133,9 +133,9 @@ class Cifar100(_CifarBase):
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"

def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> bool:
def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
path = pathlib.Path(data[0])
return path.name == config.split
return path.name == cast(str, config.split)

@property
def info(self) -> DatasetInfo:
24 changes: 12 additions & 12 deletions torchvision/prototype/datasets/_builtin/mnist.py
@@ -48,13 +48,13 @@ class MNISTFileReader(IterDataPipe):
14: "f8", # float64
}

def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Optional[int]) -> None:
def __init__(self, datapipe: IterDataPipe[Any, io.IOBase], *, start: Optional[int], stop: Optional[int]) -> None:
self.datapipe = datapipe
self.start = start
self.stop = stop

@staticmethod
def _decode(bytes):
def _decode(bytes: bytes) -> int:
return int(codecs.encode(bytes, "hex"), 16)

def __iter__(self) -> Iterator[np.ndarray]:
@@ -107,7 +107,7 @@ def _collate_and_decode(
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
):
) -> Dict[str, Any]:
image_array, label_array = data

image: Union[torch.Tensor, io.BytesIO]
@@ -138,14 +138,14 @@ def _make_datapipe(
labels_dp = Decompressor(labels_dp)
labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)

dp: IterDataPipe = Zipper(images_dp, labels_dp)
dp = Zipper(images_dp, labels_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))


class MNIST(_MNISTBase):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"mnist",
type=DatasetType.RAW,
@@ -176,7 +176,7 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str],

class FashionMNIST(MNIST):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"fashionmnist",
type=DatasetType.RAW,
@@ -209,7 +209,7 @@ def info(self):

class KMNIST(MNIST):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"kmnist",
type=DatasetType.RAW,
@@ -231,7 +231,7 @@ def info(self):

class EMNIST(_MNISTBase):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"emnist",
type=DatasetType.RAW,
@@ -295,7 +295,7 @@ def _collate_and_decode(
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
):
) -> Dict[str, Any]:
image_array, label_array = data
# In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
# That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
@@ -321,7 +321,7 @@ def _make_datapipe(
images_dp, labels_dp = Demultiplexer(
archive_dp,
2,
functools.partial(self._classify_archive, config=config), # type:ignore[arg-type]
functools.partial(self._classify_archive, config=config),
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
@@ -330,7 +330,7 @@ def info(self):

class QMNIST(_MNISTBase):
@property
def info(self):
def info(self) -> DatasetInfo:
return DatasetInfo(
"qmnist",
type=DatasetType.RAW,
@@ -381,7 +381,7 @@ def _collate_and_decode(
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
):
) -> Dict[str, Any]:
image_array, label_array = data
label_parts = label_array.tolist()
sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
26 changes: 16 additions & 10 deletions torchvision/prototype/datasets/_builtin/sbd.py
@@ -1,7 +1,7 @@
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import numpy as np
import torch
@@ -135,7 +135,7 @@ def _make_datapipe(
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive, # type: ignore[arg-type]
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
@@ -159,15 +159,21 @@ def _make_datapipe(
def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
dp = Filter(dp, path_comparator("name", "category_names.m"))
dp = LineReader(dp)
dp: IterDataPipe = Mapper(dp, bytes.decode, input_col=1)
dp = Mapper(dp, bytes.decode, input_col=1)
lines = tuple(zip(*iter(dp)))[1]

pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
categories_and_labels = [
pattern.match(line).groups() # type: ignore[union-attr]
# the first and last line contain no information
for line in lines[1:-1]
]
return tuple(zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1]))))[0]
categories_and_labels = cast(
Member

just wondering why we need to cast anything here?

Collaborator Author

pattern.match(line).groups() returns a Tuple[Optional[str], ...], so we need the cast to tell mypy that this will be a tuple of length 2 and that every group was actually matched.

Member

Do we need to cast because of the Optional bit or because of the exact length of the tuple? Or both?
Would List[Tuple[str, ...]] be enough?

Also can we remove the # type: ignore[union-attr] below now?

Collaborator Author

> Do we need to cast because of the Optional bit or because of the exact length of the tuple? Or both?
> Would List[Tuple[str, ...]] be enough?

List[Tuple[str, ...]] seems to work out. I assumed I needed a two-element tuple due to the assignment in L177.

> Also can we remove the # type: ignore[union-attr] below now?

Nope. re.match returns Optional[Match], and since we don't check for match is None (we are sure that we will always match), mypy complains that None has no attribute groups.
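For reference, here is a minimal standalone sketch (with hypothetical input data, not part of this diff) of the two points above, i.e. the ignore directive for the Optional[Match] and the cast that narrows the groups() result:

    import re
    from typing import List, Tuple, cast

    pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
    lines = ["  'aeroplane'; %1", "  'bicycle'; %2"]

    # pattern.match() returns Optional[Match[str]]; calling .groups() without a None
    # check makes mypy emit [union-attr], hence the ignore directive. .groups() itself
    # is typed as Tuple[Optional[str], ...], so the cast narrows the result to the
    # (category, label) pairs we know we always get here.
    categories_and_labels = cast(
        List[Tuple[str, str]],
        [pattern.match(line).groups() for line in lines],  # type: ignore[union-attr]
    )

    categories_and_labels.sort(key=lambda category_and_label: int(category_and_label[1]))
    categories, _ = zip(*categories_and_labels)
    print(categories)  # ('aeroplane', 'bicycle')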

List[Tuple[str, str]],
[
pattern.match(line).groups() # type: ignore[union-attr]
# the first and last line contain no information
for line in lines[1:-1]
],
)
categories_and_labels.sort(key=lambda category_and_label: int(category_and_label[1]))
categories, _ = zip(*categories_and_labels)

return categories