Skip to content

Commit

Permalink
Add type casting to satisfy mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed Apr 22, 2024
1 parent 93e52b9 commit 56166d3
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 25 deletions.
45 changes: 33 additions & 12 deletions src/pydvl/influence/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@
"""
import logging
from abc import ABC, abstractmethod
from typing import Callable, Generator, Generic, List, Optional, Tuple, Union
from typing import (
Callable,
Generator,
Generic,
Iterator,
List,
Optional,
Tuple,
Union,
cast,
)

import zarr
from numpy.typing import NDArray
Expand Down Expand Up @@ -67,10 +77,13 @@ def __call__(
Returns:
A list containing all the tensors provided by the tensor_generator.
"""
gen: Union[tqdm[TensorType], ] = tensor_generator

gen = cast(Iterator[TensorType], tensor_generator)

if len_generator is not None:
gen = tqdm(gen, total=len_generator, desc="Blocks")
gen = cast(
Iterator[TensorType], tqdm(gen, total=len_generator, desc="Blocks")
)

return [t for t in gen]

Expand Down Expand Up @@ -117,15 +130,18 @@ def __call__(
A list of lists, where each inner list contains tensors returned from one
of the inner generators.
"""
outer_gen = nested_generators_of_tensors
outer_gen = cast(Iterator[Iterator[TensorType]], nested_generators_of_tensors)

if len_outer_generator is not None:
outer_gen = tqdm(outer_gen, total=len_outer_generator, desc="Row blocks")
outer_gen = cast(
Iterator[Iterator[TensorType]],
tqdm(outer_gen, total=len_outer_generator, desc="Row blocks"),
)

return [list(tensor_gen) for tensor_gen in outer_gen]


class LazyChunkSequence:
class LazyChunkSequence(Generic[TensorType]):
"""
A class representing a chunked, and lazily evaluated array,
where the chunking is restricted to the first dimension
Expand Down Expand Up @@ -202,10 +218,12 @@ def to_zarr(
row_idx = 0
z = None

gen = self.generator_factory()
gen = cast(Iterator[TensorType], self.generator_factory())

if self.len_generator is not None:
gen = tqdm(gen, total=self.len_generator, desc="Blocks")
gen = cast(
Iterator[TensorType], tqdm(gen, total=self.len_generator, desc="Blocks")
)

for block in gen:
numpy_block = converter.to_numpy(block)
Expand Down Expand Up @@ -240,7 +258,7 @@ def _initialize_zarr_array(block: NDArray, path_or_url: str, overwrite: bool):
)


class NestedLazyChunkSequence:
class NestedLazyChunkSequence(Generic[TensorType]):
"""
A class representing chunked, and lazily evaluated array, where the chunking is
restricted to the first two dimensions.
Expand Down Expand Up @@ -323,11 +341,14 @@ def to_zarr(
row_idx = 0
z = None
numpy_block = None
block_generator = self.generator_factory()
block_generator = cast(Iterator[Iterator[TensorType]], self.generator_factory())

if self.len_outer_generator is not None:
block_generator = tqdm(
block_generator, total=self.len_outer_generator, desc="Row blocks"
block_generator = cast(
Iterator[Iterator[TensorType]],
tqdm(
block_generator, total=self.len_outer_generator, desc="Row blocks"
),
)

for row_blocks in block_generator:
Expand Down
2 changes: 2 additions & 0 deletions src/pydvl/influence/base_influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from enum import Enum
from typing import Collection, Generic, Iterable, Optional, Type, TypeVar

# Explicit public API: only InfluenceMode is re-exported from this module;
# other names here are internal to the influence subpackage.
__all__ = ["InfluenceMode"]


class InfluenceMode(str, Enum):
"""
Expand Down
8 changes: 4 additions & 4 deletions src/pydvl/influence/influence_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
from functools import partial
from typing import Generator, Iterable, Optional, Tuple, Type, Union
from typing import Generator, Iterable, Optional, Sized, Tuple, Type, Union, cast

import distributed
from dask import array as da
Expand Down Expand Up @@ -620,7 +620,7 @@ def influence_factors(
A lazy data structure representing the chunks of the resulting tensor
"""
try:
len_iterable = len(data_iterable)
len_iterable = len(cast(Sized, data_iterable))
except Exception as e:
logger.debug(f"Failed to retrieve len of data iterable: {e}")
len_iterable = None
Expand Down Expand Up @@ -684,7 +684,7 @@ def influences(
)

try:
len_iterable = len(test_data_iterable)
len_iterable = len(cast(Sized, test_data_iterable))
except Exception as e:
logger.debug(f"Failed to retrieve len of test data iterable: {e}")
len_iterable = None
Expand Down Expand Up @@ -751,7 +751,7 @@ def influences_from_factors(
)

try:
len_iterable = len(z_test_factors)
len_iterable = len(cast(Sized, z_test_factors))
except Exception as e:
logger.debug(f"Failed to retrieve len of factors iterable: {e}")
len_iterable = None
Expand Down
9 changes: 9 additions & 0 deletions src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@
flatten_dimensions,
)

# Explicit public API of this module: the concrete influence function model
# implementations exported for external use (presumably the full set of
# user-facing implementations — confirm against package docs).
__all__ = [
    "DirectInfluence",
    "CgInfluence",
    "LissaInfluence",
    "ArnoldiInfluence",
    "EkfacInfluence",
    "NystroemSketchInfluence",
]

logger = logging.getLogger(__name__)


Expand Down
27 changes: 18 additions & 9 deletions src/pydvl/influence/torch/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
Dict,
Generator,
Iterable,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)

import dask
Expand All @@ -37,6 +39,8 @@
"TorchCatAggregator",
"NestedTorchCatAggregator",
"torch_dataset_to_dask_array",
"EkfacRepresentation",
"empirical_cross_entropy_loss_fn",
]


Expand Down Expand Up @@ -297,11 +301,11 @@ def _infer_data_len(d_set: Dataset):
return total_size
else:
logger.warning(
err_msg + f" Infer the number of samples from the dataset, "
f"via iterating the dataset once. "
f"This might induce severe overhead, so consider"
f"providing total_size, if you know the number of samples "
f"beforehand."
err_msg + " Infer the number of samples from the dataset, "
"via iterating the dataset once. "
"This might induce severe overhead, so consider"
"providing total_size, if you know the number of samples "
"beforehand."
)
idx = 0
while True:
Expand Down Expand Up @@ -419,10 +423,12 @@ def __call__(
A single tensor formed by concatenating all tensors from the generator.
The concatenation is performed along the default dimension (0).
"""
t_gen = tensor_generator
t_gen = cast(Iterator[torch.Tensor], tensor_generator)

if len_generator is not None:
t_gen = tqdm(t_gen, total=len_generator, desc="Blocks")
t_gen = cast(
Iterator[torch.Tensor], tqdm(t_gen, total=len_generator, desc="Blocks")
)

return torch.cat(list(t_gen))

Expand Down Expand Up @@ -459,10 +465,13 @@ def __call__(
"""

outer_gen = nested_generators_of_tensors
outer_gen = cast(Iterator[Iterator[torch.Tensor]], nested_generators_of_tensors)

if len_outer_generator is not None:
outer_gen = tqdm(outer_gen, total=len_outer_generator, desc="Row blocks")
outer_gen = cast(
Iterator[Iterator[torch.Tensor]],
tqdm(outer_gen, total=len_outer_generator, desc="Row blocks"),
)

return torch.cat(
list(
Expand Down

0 comments on commit 56166d3

Please sign in to comment.