diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index e0cb1e8a7..c60e11693 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -8,7 +8,17 @@ """ import logging from abc import ABC, abstractmethod -from typing import Callable, Generator, Generic, List, Optional, Tuple, Union +from typing import ( + Callable, + Generator, + Generic, + Iterator, + List, + Optional, + Tuple, + Union, + cast, +) import zarr from numpy.typing import NDArray @@ -67,10 +77,13 @@ def __call__( Returns: A list containing all the tensors provided by the tensor_generator. """ - gen: Union[tqdm[TensorType], ] = tensor_generator + + gen = cast(Iterator[TensorType], tensor_generator) if len_generator is not None: - gen = tqdm(gen, total=len_generator, desc="Blocks") + gen = cast( + Iterator[TensorType], tqdm(gen, total=len_generator, desc="Blocks") + ) return [t for t in gen] @@ -117,15 +130,18 @@ def __call__( A list of lists, where each inner list contains tensors returned from one of the inner generators. """ - outer_gen = nested_generators_of_tensors + outer_gen = cast(Iterator[Iterator[TensorType]], nested_generators_of_tensors) if len_outer_generator is not None: - outer_gen = tqdm(outer_gen, total=len_outer_generator, desc="Row blocks") + outer_gen = cast( + Iterator[Iterator[TensorType]], + tqdm(outer_gen, total=len_outer_generator, desc="Row blocks"), + ) return [list(tensor_gen) for tensor_gen in outer_gen] -class LazyChunkSequence: +class LazyChunkSequence(Generic[TensorType]): """ A class representing a chunked, and lazily evaluated array, where the chunking is restricted to the first dimension @@ -202,10 +218,12 @@ def to_zarr( row_idx = 0 z = None - gen = self.generator_factory() + gen = cast(Iterator[TensorType], self.generator_factory()) if self.len_generator is not None: - gen = tqdm(gen, total=self.len_generator, desc="Blocks") + gen = cast( + Iterator[TensorType], tqdm(gen, total=self.len_generator, desc="Blocks") + ) for block in gen: numpy_block = converter.to_numpy(block) @@ -240,7 +258,7 @@ def _initialize_zarr_array(block: NDArray, path_or_url: str, overwrite: bool): ) -class NestedLazyChunkSequence: +class NestedLazyChunkSequence(Generic[TensorType]): """ A class representing chunked, and lazily evaluated array, where the chunking is restricted to the first two dimensions. @@ -323,11 +341,14 @@ def to_zarr( row_idx = 0 z = None numpy_block = None - block_generator = self.generator_factory() + block_generator = cast(Iterator[Iterator[TensorType]], self.generator_factory()) if self.len_outer_generator is not None: - block_generator = tqdm( - block_generator, total=self.len_outer_generator, desc="Row blocks" + block_generator = cast( + Iterator[Iterator[TensorType]], + tqdm( + block_generator, total=self.len_outer_generator, desc="Row blocks" + ), ) for row_blocks in block_generator: diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 73fe53d8f..541fbedf0 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -4,6 +4,8 @@ from enum import Enum from typing import Collection, Generic, Iterable, Optional, Type, TypeVar +__all__ = ["InfluenceMode"] + class InfluenceMode(str, Enum): """ diff --git a/src/pydvl/influence/influence_calculator.py b/src/pydvl/influence/influence_calculator.py index 1a40bdd5a..7c48e8636 100644 --- a/src/pydvl/influence/influence_calculator.py +++ b/src/pydvl/influence/influence_calculator.py @@ -7,7 +7,7 @@ import logging from functools import partial -from typing import Generator, Iterable, Optional, Tuple, Type, Union +from typing import Generator, Iterable, Optional, Sized, Tuple, Type, Union, cast import distributed from dask import array as da @@ -620,7 +620,7 @@ def influence_factors( A lazy data structure representing the chunks of the resulting tensor """ try: - len_iterable = len(data_iterable) + len_iterable = len(cast(Sized, data_iterable)) except Exception as e: logger.debug(f"Failed to retrieve len of data iterable: {e}") len_iterable = None @@ -684,7 +684,7 @@ def influences( ) try: - len_iterable = len(test_data_iterable) + len_iterable = len(cast(Sized, test_data_iterable)) except Exception as e: logger.debug(f"Failed to retrieve len of test data iterable: {e}") len_iterable = None @@ -751,7 +751,7 @@ def influences_from_factors( ) try: - len_iterable = len(z_test_factors) + len_iterable = len(cast(Sized, z_test_factors)) except Exception as e: logger.debug(f"Failed to retrieve len of factors iterable: {e}") len_iterable = None diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 46a5fa16e..fe3290195 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -41,6 +41,15 @@ flatten_dimensions, ) +__all__ = [ + "DirectInfluence", + "CgInfluence", + "LissaInfluence", + "ArnoldiInfluence", + "EkfacInfluence", + "NystroemSketchInfluence", +] + logger = logging.getLogger(__name__) diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index 58385df5b..581894af2 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -7,12 +7,14 @@ Dict, Generator, Iterable, + Iterator, List, Mapping, Optional, Tuple, Type, Union, + cast, ) import dask @@ -37,6 +39,8 @@ "TorchCatAggregator", "NestedTorchCatAggregator", "torch_dataset_to_dask_array", + "EkfacRepresentation", + "empirical_cross_entropy_loss_fn", ] @@ -297,11 +301,11 @@ def _infer_data_len(d_set: Dataset): return total_size else: logger.warning( - err_msg + f" Infer the number of samples from the dataset, " - f"via iterating the dataset once. " - f"This might induce severe overhead, so consider" - f"providing total_size, if you know the number of samples " - f"beforehand." + err_msg + " Infer the number of samples from the dataset, " + "via iterating the dataset once. " + "This might induce severe overhead, so consider" + "providing total_size, if you know the number of samples " + "beforehand." ) idx = 0 while True: @@ -419,10 +423,12 @@ def __call__( A single tensor formed by concatenating all tensors from the generator. The concatenation is performed along the default dimension (0). """ - t_gen = tensor_generator + t_gen = cast(Iterator[torch.Tensor], tensor_generator) if len_generator is not None: - t_gen = tqdm(t_gen, total=len_generator, desc="Blocks") + t_gen = cast( + Iterator[torch.Tensor], tqdm(t_gen, total=len_generator, desc="Blocks") + ) return torch.cat(list(t_gen)) @@ -459,10 +465,13 @@ def __call__( """ - outer_gen = nested_generators_of_tensors + outer_gen = cast(Iterator[Iterator[torch.Tensor]], nested_generators_of_tensors) if len_outer_generator is not None: - outer_gen = tqdm(outer_gen, total=len_outer_generator, desc="Row blocks") + outer_gen = cast( + Iterator[Iterator[torch.Tensor]], + tqdm(outer_gen, total=len_outer_generator, desc="Row blocks"), + ) return torch.cat( list(